import torch
import numpy as np
import torch.nn.functional as F
from einops import rearrange
def ch_sec(x):
    """Move the channel axis last and flatten the two trailing spatial axes.

    The input layout is ``... c x y`` and the output layout is ``... (x y) c``:
    one row per spatial position, holding that position's channel values.
    Any leading axes (``...``) are left untouched.
    """
    pattern = "... c x y -> ... (x y) c"
    return rearrange(x, pattern)
def _pixel_grid(res):
    """Return an (H, W, 2) tensor of normalized (x, y) pixel coordinates.

    ``res`` is ``[H, W]``. Channel 0 is the column index divided by ``W - 1``
    and channel 1 is the row index divided by ``H - 1``, so both lie in [0, 1].
    A size-1 axis maps to 0 instead of dividing by zero (which produced NaN).
    """
    h, w = res[0], res[1]
    rows, cols = np.mgrid[0:h, 0:w]
    # Stack as (col, row) == (x, y); .long() normalizes platform int widths.
    grid = torch.from_numpy(np.stack([cols, rows], axis=-1)).long()
    return grid / torch.tensor([max(w - 1, 1), max(h - 1, 1)])


def make_sample(sample,aspect,budget=192*640/4,hires_factor=2,med_factor=1,low_res=None,hi_res=None):
    """Build model inputs and ground truth at a low and a high resolution.

    Args:
        sample: dict holding at least ``"rgb"``, a 4-D image batch accepted by
            ``F.interpolate`` (assumed (B, C, H, W) in [-1, 1], given the
            ``* .5 + .5`` rescaling below -- TODO confirm). Optional keys
            ``"intrinsics"``, ``"depth_inp"`` (unsqueezed to one channel before
            resizing, so presumably (B, H, W)), ``"c2w"`` and ``"org_ratio"``
            are forwarded when present.
        aspect: width / height ratio used to derive the low resolution from
            ``budget``.
        budget: approximate pixel count of the low-resolution grid before
            rounding to multiples of 32.
        hires_factor: scale of the high-resolution grid relative to the
            unrounded low-resolution size.
        med_factor: currently unused; kept so existing callers keep working.
        low_res: optional explicit ``[H, W]`` for the low-resolution grid.
        hi_res: optional explicit ``[H, W]`` for the high-resolution grid.

    Returns:
        ``(model_input, gt)`` dicts.
        ``model_input["rgb"]`` is ``sample["rgb"]`` resized to ``low_res``.
        ``model_input["rgb_large"]`` is the rescaled (``* .5 + .5``) image
        resized to ``hi_res`` and multiplied by 255.
        ``model_input["x_pix"]`` and ``model_input["x_pix_large"]`` hold the
        normalized pixel coordinates, shaped (B, H*W, 2) and expanded over the
        batch.
        ``gt["rgb"]`` is the low-res image flattened to (B, H*W, C) and
        rescaled by ``* .5 + .5``.
        Optional depth, intrinsics and pose entries are added when present.
    """
    # Pick h, w with h * w == budget and w / h == aspect.
    y = np.sqrt(budget / aspect)
    x = budget / y
    low_res_ = [int(y), int(x)]

    def mult32(v):
        # Rounds strictly up to a multiple of 32, so an exact multiple is still
        # bumped by 32 (64 -> 96). NOTE(review): possibly unintended, but kept
        # because changing it would alter resolutions callers already rely on.
        return v - (v % 32) + 32

    if low_res is None:
        low_res = [mult32(v) for v in low_res_]
    if hi_res is None:
        # Derived from the unrounded size, not from the rounded low_res.
        hi_res = [mult32(int(hires_factor * v)) for v in low_res_]

    uv = _pixel_grid(low_res)
    uv_hires = _pixel_grid(hi_res)

    rgb = sample["rgb"]
    model_input, gt = {}, {}
    model_input["rgb"] = F.interpolate(rgb, low_res, antialias=True, mode="bilinear")
    model_input["rgb_large"] = F.interpolate(rgb * .5 + .5, hi_res, antialias=True, mode="bilinear") * 255

    batch = len(model_input["rgb"])
    model_input["x_pix"] = uv[None].flatten(1, 2).expand(batch, -1, -1)
    model_input["x_pix_large"] = uv_hires[None].flatten(1, 2).expand(batch, -1, -1)
    gt["rgb"] = ch_sec(model_input["rgb"]) * .5 + .5

    if "intrinsics" in sample:
        # Both keys intentionally share the same object.
        model_input["gt_intrinsics"] = model_input["intrinsics"] = sample["intrinsics"]
    if "depth_inp" in sample:
        depth = sample["depth_inp"][:, None]  # add a channel axis for interpolate
        gt["depth_inp"] = model_input["depth_inp"] = ch_sec(F.interpolate(depth, low_res))
        model_input["depth_inp_large"] = ch_sec(F.interpolate(depth, hi_res))
    if "c2w" in sample:
        model_input["c2w"] = sample["c2w"]
    if "org_ratio" in sample:
        model_input["org_ratio"] = sample["org_ratio"]
    return model_input, gt
