import numpy as np
import torch


def check_broadcastable(x, y):
    assert len(x.shape) == len(y.shape)
    for n, m in zip(x.shape[:-1], y.shape[:-1]):
        assert n == m or n == 1 or m == 1


def broadcast_inputs(x, y):
    """Automatic broadcasting of missing dimensions"""
    if y is None:
        xs, xd = x.shape[:-1], x.shape[-1]
        return (x.view(-1, xd).contiguous(),), x.shape[:-1]

    check_broadcastable(x, y)

    xs, xd = x.shape[:-1], x.shape[-1]
    ys, yd = y.shape[:-1], y.shape[-1]
    out_shape = [max(n, m) for (n, m) in zip(xs, ys)]

    if x.shape[:-1] == y.shape[-1]:
        x1 = x.view(-1, xd)
        y1 = y.view(-1, yd)

    else:
        x_expand = [m if n == 1 else 1 for (n, m) in zip(xs, ys)]
        y_expand = [n if m == 1 else 1 for (n, m) in zip(xs, ys)]
        x1 = x.repeat(x_expand + [1]).reshape(-1, xd).contiguous()
        y1 = y.repeat(y_expand + [1]).reshape(-1, yd).contiguous()

    return (x1, y1), tuple(out_shape)
