# code based on github snippet https://gist.github.com/rossant/6046463
import geometry
import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
#import vis_scene_graph

#w,h = [64]*2
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed both RNGs so scene generation is reproducible.
# TODO: for a fixed set of scenes, give each scene its own seed.
seed = 46
np.random.seed(seed)
torch.manual_seed(seed)

# Identity 4x4 camera with a leading batch axis, shape (1, 4, 4); only
# referenced from commented-out code in render().
base_cam = torch.eye(4).unsqueeze(0)

def normalize(x):
    """Scale ``x`` to unit L2 length along its last axis.

    Thin wrapper over ``torch.nn.functional.normalize``; because that function
    divides by ``max(norm, eps)``, all-zero vectors stay zero instead of NaN.
    """
    return torch.nn.functional.normalize(x, p=2.0, dim=-1)

def look_at(cam_pos, look_at, up):
    """Build the 4x4 camera matrix for a camera at ``cam_pos`` aimed at ``look_at``.

    The parameter ``look_at`` shadows this function's name inside the body; it
    is kept unchanged because callers may pass it by keyword.

    The camera axes are:
      * ``forward``: unit vector from ``cam_pos`` toward ``look_at``;
      * ``right``: ``up x forward``;
      * ``true_up``: ``forward x right``, re-orthogonalised.

    They are stacked with ``dim=1``, i.e. as the COLUMNS of the upper-left 3x3
    block. The translation column is ``-rotation @ cam_pos``. Consequently
    ``view_matrix.inverse()[:3, 3]`` equals ``cam_pos``, which is how render()
    recovers the ray origin.

    NOTE(review): columns-as-axes combined with ``-R @ cam_pos`` is neither a
    pure world-to-camera nor a pure camera-to-world matrix. The generated
    dataset stores these matrices, so confirm what geometry.get_world_rays_
    expects before changing the convention.

    Args:
        cam_pos: camera position, expected as a 1-D tensor of length 3 (the
            slice assignments below need a (3, 3) rotation).
        look_at: point the camera faces, same shape as ``cam_pos``.
        up: approximate world up direction, same shape as ``cam_pos``.

    Returns:
        (4, 4) tensor in torch's default dtype, on ``cam_pos``'s device.
    """
    forward = normalize(look_at - cam_pos)
    # Pass dim explicitly: calling torch.cross without dim is deprecated, and
    # the last axis is the xyz axis (matching normalize()).
    right = normalize(torch.cross(up, forward, dim=-1))
    true_up = torch.cross(forward, right, dim=-1)

    # dim=1 -> the three axes become columns (see NOTE above).
    rotation = torch.stack([right, true_up, forward], dim=1)
    translation = -torch.matmul(rotation, cam_pos)

    # Allocate on the input's device instead of always on the CPU.
    view_matrix = torch.eye(4, device=cam_pos.device)
    view_matrix[:3, :3] = rotation
    view_matrix[:3, 3] = translation
    return view_matrix

def intersect_plane(O, D, P, N):
    """Distance along each ray (O, D) to the plane through P with normal N.

    O and P are 3D points; D and N are unit vectors. Dot products are taken
    over the last axis, so O and D may carry leading batch dimensions.

    Any miss returns the finite sentinel distance 10000 rather than +inf. A miss
    is either:
      * the hit lies behind the ray origin (d < 0), or
      * the ray is parallel to the plane (D.N == 0, giving d = +/-inf or NaN).

    Previously only ``-inf``/negative values were mapped. A parallel ray could
    therefore leak ``+inf`` or NaN, which the renderer then turns into zero
    depth / infinite inverse depth instead of the far "sky" value.
    """
    denom = (D * N).sum(-1)
    d = ((P - O) * N).sum(-1) / denom
    miss = (d < 0) | ~torch.isfinite(d)
    return torch.where(miss, torch.full_like(d, 10000), d)

def intersect_sphere(O, D, S, R):
    """Distance along each ray (O, D) to sphere (S, R), or +inf on a miss.

    O and S are 3D points, D is a (unit) direction, R the radius; dot products
    are taken over the last axis, so O and D may be batched. When the ray
    origin is inside the sphere the far (exit) intersection is returned, as in
    the scalar reference implementation this code is ported from.
    """
    # Small epsilon on D.D, presumably to avoid dividing by zero.
    a = (D * D).sum(dim=-1) + 1e-5
    OS = O - S
    b = 2 * (D * OS).sum(dim=-1)
    c = (OS * OS).sum(dim=-1) - R * R
    disc = b * b - 4 * a * c
    sqrt_disc = disc.sqrt()  # NaN where disc < 0; those rays are masked below
    # Numerically stable quadratic roots: choose the sign avoiding cancellation.
    q = torch.where(b < 0, (-b - sqrt_disc) / 2.0, (-b + sqrt_disc) / 2.0)
    t0 = q / a
    t1 = c / q
    # Order the roots: t0 = near hit, t1 = far hit. (Previously both were
    # minimum(), so rays starting inside a sphere were reported as misses.)
    t0, t1 = torch.minimum(t0, t1), torch.maximum(t0, t1)

    hit = torch.where(t0 < 0, t1, t0)  # origin inside sphere -> exit point
    return torch.where((disc > 0) & (t1 >= 0), hit, torch.inf)

def intersect(O, D, obj):
    """Dispatch the ray/object intersection test on ``obj['type']``.

    Returns the per-ray hit distance produced by intersect_plane or
    intersect_sphere.

    Raises:
        ValueError: if ``obj['type']`` is neither 'plane' nor 'sphere'. Returning
            None would only fail later with a confusing error.
    """
    kind = obj['type']
    if kind == 'plane':
        return intersect_plane(O, D, obj['position'], obj['normal'])
    if kind == 'sphere':
        return intersect_sphere(O, D, obj['position'], obj['radius'])
    raise ValueError(f"unsupported object type: {kind!r}")

def get_normal(obj, M):
    """Return the surface normal of ``obj`` at hit points ``M``.

    * Spheres: the unit vector from the centre to each hit point.
    * Planes: the stored ``obj['normal']`` unchanged (a single vector that
      broadcasts against per-point tensors).

    Raises:
        ValueError: for an unsupported ``obj['type']``. Without this check,
            ``N`` would be unbound and raise an UnboundLocalError.
    """
    kind = obj['type']
    if kind == 'sphere':
        return normalize(M - obj['position'])
    if kind == 'plane':
        return obj['normal']
    raise ValueError(f"unsupported object type: {kind!r}")
    
def get_color(obj, M):
    """Return the surface colour of ``obj`` at hit points ``M``.

    Objects that carry a ``color_fn`` entry (the checkerboard plane) are
    coloured procedurally from ``M``. All others return their flat ``color``.

    Only the absence of ``color_fn`` selects the fallback. The previous bare
    ``except:`` also hid genuine errors raised inside ``color_fn`` (e.g. a
    device mismatch) and silently rendered a flat colour instead.
    """
    color_fn = obj.get("color_fn")
    if color_fn is None:
        return obj["color"]
    return color_fn(M)

def add_sphere(position, radius, color):
    """Create a sphere scene object.

    ``position``, ``radius`` and ``color`` are each passed through
    ``torch.tensor`` and moved to the module-level ``device``. The sphere uses
    a fixed reflection coefficient of 0.5 and the default shading coefficients.
    """
    def on_device(value):
        return torch.tensor(value).to(device)

    return {
        "type": "sphere",
        "position": on_device(position),
        "radius": on_device(radius),
        "color": on_device(color),
        "reflection": .5,
    }
    
def add_plane(position, normal, color):
    """Create a checkerboard plane scene object.

    ``position`` and ``normal`` are converted with ``torch.tensor`` and moved
    to ``device``. ``color`` is expected to be a tensor: ``color_fn`` calls
    ``color.to(device)``. It colours one set of checker squares, while the
    other set is ``ones_like(color)``. ``color`` is also stored unchanged under
    ``"color"``. The plane overrides the default diffuse and specular
    coefficients.
    """
    def checker(M):
        # Square parity in x and z; cells are 0.5 units wide and .int()
        # truncates toward zero.
        parity_x = (M[..., [0]] * 2).int() % 2
        parity_z = (M[..., [2]] * 2).int() % 2
        return torch.where(parity_x == parity_z, color.to(device), torch.ones_like(color).to(device))

    return {
        "type": "plane",
        "position": torch.tensor(position).to(device),
        "normal": torch.tensor(normal).to(device),
        "color_fn": checker,
        "color": color,
        "diffuse_c": .75,
        "specular_c": .5,
        "reflection": .25,
    }
    
# Scene objects are created per scene in the __main__ block.

# Four point lights at the corners of a square (half-width `lr`) at height
# `lh`. Each light's specular colour is white divided by the light count.
lh = 5
lr = 5
Ls = [
    torch.tensor([sign_x * lr, lh, sign_z * lr]).to(device)
    for sign_x, sign_z in ((1, 1), (-1, -1), (1, -1), (-1, 1))
]
color_light = torch.ones(3) / len(Ls)

# Default light and material parameters (objects may override diffuse_c and
# specular_c with their own dict entries).
ambient = .05
diffuse_c = .5
specular_c = 1.
specular_k = 50  # Blinn-Phong exponent

depth_max = 1  # Maximum number of light reflections (not read in this file).
col = torch.zeros(3)  # Current color (not read in this file).

#img = torch.zeros((h, w, 3))

# Roadmap: first parallelize, then convert from numpy to pytorch, then add
# custom rendering properties, then add custom objects, then randomize scenes.

# Number of scenes, frames per scene, and maximum spheres per scene.
n_scene, n_frame, n_obj = 10000, 16, 5
fixed_cam, gen_3d = True, True

# Rendering function
def render(frame_i,scene,cam,w=64):
    """Ray-trace one square w x w frame of `scene` as seen from camera `cam`.

    Every object in `scene` is intersected and shaded for every pixel. Per
    pixel, the object with the smallest hit distance wins.

    Args:
        frame_i: frame index; only referenced by commented-out debug code.
        scene: list of object dicts as built by add_plane / add_sphere.
            scene[0] is treated specially by the segmentation mask (in the
            __main__ block it is the ground plane).
        cam: camera matrix with a leading batch axis (it is indexed as
            [:, :3, -1] after inversion); presumably look_at(...)[None] --
            TODO confirm for other callers.
        w: image width and height in pixels.

    Returns:
        A list, with every tensor moved to the CPU:
          * inverse depth, shape (w, w, 1);
          * segmentation mask, shape (w, w, 1);
          * colour, shape (w, w, 3), mapped through x*2-1 and not clamped;
          * ``cam.squeeze(0)``.
    """
    h=w  # frames are always square

    r = float(w) / h
    # Screen coordinates: x0, y0, x1, y1.
    S = (-1., -1. / r + .25, 1., 1. / r + .25)

    # Ray origin = translation column of the inverted camera matrix (for
    # matrices built by look_at this is the camera position).
    rayO=O=cam.inverse()[:,:3,-1]

    # (w, w, 2) grid of screen coordinates. torch.meshgrid without indexing=
    # uses 'ij' and emits a deprecation warning in recent PyTorch.
    uv = torch.stack(torch.meshgrid(torch.linspace(S[0],S[2],w),torch.linspace(S[1],S[3],w)),-1).to(device)#.flip([1,0])#.transpose(0,1)
    # NOTE(review): D is never used -- rayD is overwritten below.
    D = torch.nn.functional.normalize((torch.cat((uv,torch.zeros_like(uv[...,:1])),-1) - O),dim=-1).to(device)
    rayO,rayD=O,D

    
    #if frame_i==0: base_cam=cam

    # Identity intrinsics with a batch axis.
    K=torch.eye(3)[None].to(device)
    # NOTE(review): geometry.get_world_rays_ is external. Element [1] of its
    # result (batch 0) is used as the per-pixel ray direction -- presumably it
    # returns (origins, directions); TODO confirm.
    rayD=geometry.get_world_rays_(uv.flatten(0,1)[None],K,cam)[1][0].unflatten(0,(h,w))

    def trace_ray_v(rayO, rayD, obj):
        """Intersect and shade one object for all (flattened) rays.

        Returns (obj, M, N, col_ray, t, M): hit points, normals, shaded colour
        per ray and hit distances. M is returned twice; callers index t[-1].
        """

        # Find first point of intersection with the scene.
        t = intersect(rayO, rayD, obj)
        M = rayO + rayD * t[...,None]

        # Find properties of the object.
        N = get_normal(obj, M)
        color = get_color(obj, M)
        toO = normalize(O - M)

        col_ray = torch.zeros(h*w,3).to(device)    

        for L in Ls:
            toL = normalize(L - M)
            # Lambert shading (diffuse).
            col_ray += obj.get('diffuse_c', diffuse_c) * torch.maximum((N*toL).sum(-1,keepdim=True), torch.zeros_like(N).to(device)) * color 
            # Blinn-Phong shading (specular).
            col_ray += obj.get('specular_c', specular_c) * torch.maximum((N*normalize(toL + toO)).sum(-1,keepdim=True), torch.zeros_like(N).to(device)) ** specular_k * color_light.to(device)
            # NOTE(review): ambient is added inside the light loop, i.e.
            # len(Ls) times in total -- confirm this is intended.
            col_ray += ambient
        return obj, M, N, col_ray,t,M

    traces=[]
    for obj in scene:
        traced = trace_ray_v(rayO, rayD.flatten(0,1),obj)
        traces.append(traced)
    # Stack per-object results along a new leading "object" axis.
    ts=torch.stack([t[-2] for t in traces])
    Ms=torch.stack([t[-1] for t in traces])
    colors=torch.stack([t[-3] for t in traces])
    colors=torch.where(torch.isnan(colors),torch.zeros_like(colors),colors)
    rand_num=100  # fill value for invalid hit points below

    # Depth of the nearest object per pixel; NaN/inf distances become 0.
    ts_=torch.where(torch.isnan(ts)|torch.isinf(ts),torch.zeros_like(ts).to(device),ts)
    ts_=((torch.arange(len(scene)).to(device)[:,None]==ts.min(dim=0)[1])[...,None]*ts_.unsqueeze(-1)).sum(0)

    # Replace NaN, inf and far (|x| > 100) hit points with rand_num.
    Ms=torch.where(torch.isnan(Ms)|torch.isinf(Ms)|(Ms.abs()>100),torch.zeros_like(Ms).to(device)+rand_num,Ms)
    # Composite scene: keep the hit point of the nearest object per pixel.
    Ms=((torch.arange(len(scene)).to(device)[:,None]==ts.min(dim=0)[1])[...,None]*Ms).sum(0)
    # NOTE(review): despite the intent "world -> camera coordinates", no
    # camera transform is applied: a homogeneous 1 is appended and then
    # dropped, leaving Ms unchanged.
    Ms = (torch.cat((Ms,torch.ones_like(Ms[...,:1])),-1).T)[:3].T
    Ms = torch.where(ts_>64,Ms*1000,Ms) # sky mask
    Ms=(1/(1+Ms.view(h,w,3))).tanh() # clipping to [-1,1]
    # NOTE(review): Ms is not used after this point; the xyz map is not returned.

    # Colour of the nearest object per pixel, mapped via x*2-1.
    img=((torch.arange(len(scene)).to(device)[:,None]==ts.min(dim=0)[1])[...,None]*colors).sum(0)*2-1
    # 1 where the nearest object is not scene[0], else 0.
    seg = (ts.min(dim=0)[1]>0).float()

    # Reshape to (h, w, C), swap the two image axes and flip both, move to CPU.
    exp_imgs = [x.view(h,w,-1).transpose(1,0).flip([0,1]).cpu() for x in  [1/ts_.squeeze(-1),seg,img]]

    #if frame_i==3:
    #if frame_i==0:
    #    import importlib
    #    vis_scene_graph=importlib.import_module("vis_scene_graph")
    #    vis_scene_graph.plot_3d_bounding_boxes(scene_graph,[x[0].inverse() for x in cams]) 

    #return exp_imgs+[(base_cam.inverse()@cam).squeeze(0).cpu()]
    return exp_imgs+[cam.squeeze(0).cpu()]

def drop_intersecting_spheres(scene):
    """Return a new list with overlapping spheres removed from ``scene``.

    ``scene[0]`` (the ground plane) is always kept. The remaining entries are
    spheres with ``position``/``radius`` tensors. They are visited in order,
    and a sphere is kept only if it does not overlap any sphere already kept,
    so the earliest sphere of each overlapping group survives.

    This replaces an in-place loop whose inner check lacked ``i != j``. That
    loop compared every sphere with itself and therefore always deleted
    ``scene[1]``, even when that sphere intersected nothing.
    """
    def overlaps(x, y):
        return bool((x["position"] - y["position"]).norm() < (x["radius"] + y["radius"]))

    kept = list(scene[:1])
    for sphere in scene[1:]:
        if not any(overlaps(sphere, other) for other in kept[1:]):
            kept.append(sphere)
    return kept


if __name__ == "__main__":
    import os

    # torch.save / plt.imsave below do not create missing directories.
    os.makedirs("dataset", exist_ok=True)
    os.makedirs("vis", exist_ok=True)

    # Create N scenes and render them
    for scene_i in tqdm(range(n_scene)):
        sx = 1.75
        n_obj_ = np.random.randint(1, n_obj + 1)
        # scene[0] is the ground plane, followed by n_obj_ spheres at random x/z.
        scene = [add_plane([0., -.5, 0.], [0., 1., 0.], torch.rand(3))] + [
            add_sphere([(torch.rand(1) * sx * 2) - sx, .1, (torch.rand(1)) * 2 * sx], .4, torch.rand(3))
            for _ in range(n_obj_)
        ]

        # Remove spheres which intersect another; n_obj_ counts the survivors.
        scene = drop_intersecting_spheres(scene)
        n_obj_ = len(scene) - 1

        # Write scene graph
        scene_graph = [{"obj_code": torch.tensor([int(obj["type"] == "plane")]), "center": obj["position"].cpu(), "color": obj["color"],
                        "dimensions": (torch.tensor([obj["radius"] * 2] * 3).cpu() if obj["type"] != "plane" else torch.tensor([10, .5, 10])) / 10,
                        "is_present": torch.tensor([1.0])} for obj in scene]
        # Pad with all-zero entries so every scene graph has n_obj + 1 objects.
        scene_graph = scene_graph + [{k: v * 0 for k, v in scene_graph[-1].items()} for _ in range(n_obj - n_obj_)]

        # Two random camera endpoints; the camera moves linearly between them.
        Os = []
        for _ in range(2):
            r = (torch.rand(1) + 1) * 2
            x, y = (torch.rand(2) * 2 - 1) * r
            z = np.sqrt(np.abs(1 - x**2 - (y + 2)**2) + .1) / 2
            if fixed_cam: print("fix cam")
            Os.append(torch.tensor([x, z, y]).to(device))  # Camera.

        look_at_point = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32).to(device)
        up_vector = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float32).to(device)
        cams = []
        for a in tqdm(torch.linspace(0, 1, n_frame), leave=False):
            cam_pos = Os[0] * a + Os[1] * (1 - a)
            cams.append(look_at(cam_pos, look_at_point, up_vector)[None].to(device))

        frame_exps = []
        for frame_i, cam in tqdm(enumerate(cams), leave=False):
            frame_exps.append(render(frame_i, scene, cam, w=128))
            if scene_i == 0:
                # Debug preview: rgb on top, inverse depth and segmentation below.
                invdepth, seg, rgb = frame_exps[-1][:3]
                preview = torch.cat((.5 + .5 * rgb, torch.cat((invdepth, seg)).expand(-1, -1, 3)))
                plt.imsave("vis/%02d.png" % frame_i, preview.clip(0, 1).numpy())
                print("saving vis")

        # Stack each render output over frames (leading axis = frame).
        stacked = [torch.stack(per_frame) for per_frame in zip(*frame_exps)]
        data = dict(zip(["invdepth", "seg", "rgb", "cameras"], stacked))
        data["scene_graph"] = scene_graph
        torch.save(data, "dataset/scene_%04d.pt" % scene_i)
