import torch, torchvision
from torch import nn
import kornia
import functools
from einops import rearrange, repeat
import torchvision.ops as ops
from torch.nn import functional as F
import numpy as np
import sys,random,time,os
from copy import deepcopy 
from matplotlib import cm
import wandb
from tqdm import tqdm
from typing import Callable, List, Optional, Tuple, Generator, Dict
from collections import defaultdict
import torchvision.transforms as T

sys.path.append("./third_party/co-tracker/")
from cotracker.utils.visualizer import Visualizer, read_video_from_path

import geometry

import torch_kmeans
from torch_kmeans import KMeans

# Tensor layout / geometry helpers


def ch_sec(x):
    """Channels-last with flattened spatial axes: ``... c x y -> ... (x y) c``."""
    return rearrange(x, "... c x y -> ... (x y) c")


def ch_fst(src, x=None):
    """Inverse of ``ch_sec``: ``... (x y) c -> ... c x y``.

    When ``x`` is omitted the spatial grid is taken to be square, with side
    ``int(src.size(-2) ** 0.5)``.
    """
    side = x if x is not None else int(src.size(-2) ** (.5))
    return rearrange(src, "... (x y) c -> ... c x y", x=side)


def hom(x):
    """Append a coordinate of ones along the last axis (homogeneous form)."""
    ones = torch.ones_like(x[..., [0]])
    return torch.cat((x, ones), -1)


def unhom(x):
    """Divide by the last coordinate and drop it.

    1e-5 is added to the divisor; this guards an exact zero but not a divisor
    equal to -1e-5.
    """
    divisor = 1e-5 + x[..., -1:]
    return x[..., :-1] / divisor


def grid_samp_(x, y):
    """Bilinear, border-padded ``F.grid_sample``; ``y`` is given in [0, 1] and rescaled to [-1, 1]."""
    return F.grid_sample(x, y * 2 - 1, mode="bilinear", padding_mode="border")


def grid_samp(x, y):
    """``grid_samp_`` that also accepts non-4-D ``x`` by folding its first two axes into the batch."""
    if len(x.shape) == 4:
        return grid_samp_(x, y)
    # TODO: use a more general flatten recipe
    folded = grid_samp_(x.flatten(0, 1), y.flatten(0, 1))
    return folded.unflatten(0, x.shape[:2])


def project(crds, K):
    """Multiply ``crds`` by ``K`` (batched over the leading axes) and dehomogenise with ``unhom``."""
    return unhom(torch.einsum("b...cij,b...ckj->b...cki", K, crds))


def warp(crds, poses, K):
    """Transform ``crds`` by ``poses`` in homogeneous form, keep xyz, then ``project`` with ``K``."""
    moved = torch.einsum("b...cij,b...ckj->b...cki", poses, hom(crds))
    return project(moved[..., :3], K)

def make_net(dims):
    """Build a ReLU MLP whose layer widths are given by ``dims``.

    One ``nn.Linear`` is created for every consecutive pair in ``dims``, with a
    ``nn.ReLU`` between layers (none after the last one).  Linear weights are
    re-initialised with Kaiming-normal (fan_in, relu); biases keep PyTorch's
    default initialisation.  ``dims`` of length 0 or 1 yields an empty
    ``nn.Sequential``.
    """
    def _kaiming_init(module):
        if isinstance(module, nn.Linear) and hasattr(module, 'weight'):
            nn.init.kaiming_normal_(module.weight, a=0.0, nonlinearity='relu', mode='fan_in')

    modules = []
    for fan_in, fan_out in zip(dims[:-1], dims[1:]):
        modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
    # drop the trailing activation so the output layer is linear
    mlp = nn.Sequential(*modules[:-1])
    mlp.apply(_kaiming_init)
    return mlp

class FlowMap(nn.Module):
    """Point-track-based pose model.

    A ResNet backbone (``PixelNeRFEncoder`` + conv head) produces per-frame
    features from which three conv heads predict a depth residual (added to
    ``model_input["depth_inp"]``), L2-normalised affinity embeddings, and a
    sigmoid confidence map.  ``forward`` lifts the 2D point tracks to 3D with
    the predicted depth, solves weighted Procrustes problems between adjacent
    frames for a set of query tracks (``geometry.efficient_procrustes``),
    chains them into per-frame poses and scores them with a track
    reprojection loss.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        # Number of consecutive frames stacked channel-wise into one backbone input.
        self.time_stride = args.time_stride
        fdim = 32  # per-frame feature channels produced by the backbone head
        # NOTE(review): depth_est is not used by forward(); kept so the
        # parameter set / state-dict layout stays unchanged.
        self.depth_est = make_net([fdim, fdim, 1])
        self.resnet_enc = nn.Sequential(
            PixelNeRFEncoder(in_ch=3 * self.time_stride, use_first_pool=True),
            nn.Conv2d(512, fdim * self.time_stride, 3, padding=1),
        )
        self.depth_conv = nn.Conv2d(fdim, 1, 3, padding=1)
        affinity_dim = 8
        self.affinities_conv = nn.Conv2d(fdim, affinity_dim, 3, padding=1)
        self.general_confidence_conv = nn.Conv2d(fdim, 1, 3, padding=1)

    def forward_allpts(self, model_input, sample_pts=None, out=None):
        """Run ``forward`` over every point track, in chunks, without gradients.

        Args:
            model_input: same dict as for ``forward``.
            sample_pts: unused; kept for call compatibility.
            out: unused (the result is built from the first chunk's output);
                kept for call compatibility.  Defaults to ``None`` rather than
                a shared mutable ``{}``.

        Returns:
            The first chunk's output dict with ``point_track_reproj`` (dim 2),
            ``aff_sim`` (dim 1), ``poses_all`` (dim 0) and
            ``worldcrds_pertrack`` (dim 1) concatenated over all chunks, plus
            ``pose_clusters`` from
            ``geometry.cluster_and_represent(poses_all, n_clusters=15)``.
            All other keys describe the first chunk only.
        """
        n_tracks = model_input["pred_tracks"].size(-2)
        track_idxs_all = torch.arange(n_tracks)
        # n // 300 chunks gives chunk sizes in [300, 600).  torch.chunk needs
        # chunks > 0, so clamp: with < 300 tracks n // 300 == 0 and chunk raised.
        n_chunks = max(1, n_tracks // 300)

        outs = []
        print("collecting perpoint queries")
        with torch.no_grad():
            for track_idxs in track_idxs_all.chunk(n_chunks):
                outs.append(self(model_input, track_idxs=track_idxs))
        print("done collecting perpoint queries")

        # Aggregate per-chunk outputs along their track/query axis.
        out = outs[0]
        out["point_track_reproj"] = torch.cat([x["point_track_reproj"] for x in outs], 2)
        out["aff_sim"] = torch.cat([x["aff_sim"] for x in outs], 1)
        out["poses_all"] = torch.cat([x["poses_all"] for x in outs], 0)
        out["worldcrds_pertrack"] = torch.cat([x["worldcrds_pertrack"] for x in outs], 1)

        out["pose_clusters"] = geometry.cluster_and_represent(out["poses_all"], n_clusters=15)
        return out

    def forward(self, model_input, track_idxs=None, out=None):
        """Solve per-query-track poses and the track reprojection loss.

        Args:
            model_input: dict.  Keys read: ``"rgb"``, ``"depth_inp"``,
                ``"rig_flow_masks"``, ``"pred_tracks"``, ``"pred_visibility"``,
                ``"intrinsics"``.  ``"fmap"`` is reused if present while grad
                is disabled, and is (re)written otherwise.
                NOTE(review): a cached ``"fmap"`` is reused under no_grad even
                if ``"rgb"`` changed -- callers must drop it themselves.
            track_idxs: indices of the query tracks to solve poses for.  When
                ``None``, 300 random tracks are drawn (1 if ``static_solve``).
            out: optional dict merged underneath the results (result keys
                win).  ``None`` (the default) means an empty dict.

        Returns:
            A new dict (``out`` is not mutated) with depth, pose, affinity,
            confidence and reprojection entries; see the return statement.
        """
        out = {} if out is None else out

        imsize = model_input["rgb"].shape[-2:]
        b, n_trgt = model_input["rgb"].shape[:2]
        depth_inp = ch_fst(model_input["depth_inp"], imsize[0])

        static_solve = False

        # Sample the rigid-flow masks at each track and aggregate over time
        # (dynamic in any visible frame -> dynamic) to estimate per-track rigidity.
        rig_samp = grid_samp(model_input["rig_flow_masks"][:,:,[0]],model_input["pred_tracks"][:,1:].unsqueeze(-2)).squeeze(2).squeeze(-1).round()
        rig_samp = torch.where(model_input["pred_visibility"][:,1:],rig_samp,torch.ones_like(rig_samp))
        rig_samp = rig_samp.min(dim=1)[0]
        # NOTE(review): debug override -- every track is forced to rigid, so the
        # estimate above is currently discarded.
        rig_samp = rig_samp*0+1
        print("rig sanity")

        # Pick random query tracks if none were provided.
        if track_idxs is None:
            track_idxs = torch.randperm(model_input["pred_tracks"].size(-2))[:1 if static_solve else 300]

        # Backbone feature map: frames are stacked time_stride at a time along channels.
        if "fmap" not in model_input or torch.is_grad_enabled():
            img_inp = rearrange(model_input["rgb"],"b (t s) c x y -> b t (s c) x y",s=self.time_stride)
            fmap_out = F.interpolate(self.resnet_enc(img_inp.flatten(0,1)*.5+.5),imsize,mode="bilinear")
            model_input["fmap"] = rearrange(fmap_out,"(b t) (s c) x y -> b (t s) c x y",s=self.time_stride,b=b)
        fmap_flat = model_input["fmap"].flatten(0,1)

        res_depth = F.softplus(self.depth_conv(fmap_flat).unflatten(0,(b,n_trgt))+1)
        depth = depth_inp + res_depth

        # Lift each point track into a 2.5D image-aligned surface: the second
        # output of get_world_rays (presumably ray directions) scaled by sampled depth.
        rds_track = geometry.get_world_rays(model_input["pred_tracks"],model_input["intrinsics"],None)[1]
        eye_surf_track = rds_track * grid_samp(depth,model_input["pred_tracks"].unsqueeze(-2)).squeeze(2)

        # Visibility mask with a trailing singleton axis, reused for masked means below.
        vis = model_input["pred_visibility"].unsqueeze(-1)

        # Affinity embeddings: per-track embedding is the visibility-masked mean
        # over frames; aff_sim holds dot products from each query track to every track.
        affinity_emb_unnorm = self.affinities_conv(fmap_flat).unflatten(0,(b,n_trgt))
        affinity_emb = F.normalize(affinity_emb_unnorm, dim=2)
        aff_emb_pertrack_allframe = ch_sec(grid_samp(affinity_emb,model_input["pred_tracks"].unsqueeze(-2)))
        aff_emb_pertrack = (aff_emb_pertrack_allframe*vis).sum(dim=1)/vis.sum(dim=1).clip(min=1)
        aff_sim = torch.einsum('b p c, b q c -> b p q', aff_emb_pertrack[:,track_idxs], aff_emb_pertrack)

        # General per-frame confidence of each track (zeroed where not visible).
        general_conf = self.general_confidence_conv(fmap_flat).unflatten(0,(b,n_trgt)).sigmoid().clip(min=1e-4)
        general_conf_track = grid_samp(general_conf,model_input["pred_tracks"].unsqueeze(-2)).squeeze(2) * vis

        # Adjacent-frame weighted Procrustes, one problem per query track, using
        # every solve_stride-th track.  Clamp to 1: with < 3000 tracks the stride
        # was 0 and the ``::0`` slices raised ValueError.
        solve_stride = max(1, model_input["pred_tracks"].size(-2)//3000)
        n_query = aff_sim.size(1)
        # Rigid tracks get affinity 1 to every other track.
        aff_sim_rig = torch.where(rig_samp.bool()[:,track_idxs,None].expand(-1,-1,aff_sim.size(-1)), torch.ones_like(aff_sim), aff_sim)
        src_pts = eye_surf_track[:,1:,::solve_stride].expand(n_query,-1,-1,-1)
        tgt_pts = eye_surf_track[:,:-1,::solve_stride].expand(n_query,-1,-1,-1)
        weights = (general_conf_track[:,:-1,::solve_stride].expand(n_query,-1,-1,-1)*aff_sim_rig[:,:,::solve_stride].flatten(0,1)[:,None,:,None]).clip(min=1e-4)
        poses = geometry.efficient_procrustes(src_pts, tgt_pts, weights)[1]
        # Chain adjacent transforms: afterwards poses[:, k] = P_0 @ P_1 @ ... @ P_k.
        for i in range(n_trgt-1,0,-1):
            poses = torch.cat((poses[:,:i],poses[:,[i-1]]@poses[:,i:]),1)
        # Prepend the identity as the pose of frame 0.
        poses = torch.cat((torch.eye(4).to(poses)[None,None].expand(poses.size(0),-1,-1,-1),poses),1)

        # Static solve: a single pose shared by every track.
        if static_solve:
            track_idxs = torch.arange(model_input["pred_tracks"].size(-2))
            poses = poses.expand(eye_surf_track.size(-2),-1,-1,-1)

        # Reproject each query track from every frame s into every frame t:
        # poses_all_to_all[p, s, t] = inv(poses[p, t]) @ poses[p, s].
        poses_all_to_all = repeat(poses.inverse(),"p t x y -> p s t x y",s=n_trgt)@repeat(poses,"p t x y -> p t s x y",s=n_trgt)
        point_track_surf_reproj = torch.einsum('pstij,stpj->stpi',poses_all_to_all,hom(repeat(eye_surf_track[:,:,track_idxs].squeeze(0),"t p c -> t s p c",s=n_trgt)))[None,...,:3]
        point_track_reproj = project(point_track_surf_reproj,model_input["intrinsics"]).clip(0,1)
        point_track_loss = ( (point_track_reproj - model_input["pred_tracks"][:,None,:,track_idxs]) * model_input["pred_visibility"][:,None,:,track_idxs,None] ).square().mean()

        # Visibility-masked mean color (all tracks) and world coords (query tracks) for visualisation.
        rgb_pertrack = ch_sec(grid_samp(model_input["rgb"],model_input["pred_tracks"].unsqueeze(-2)))
        rgb_pertrack = (rgb_pertrack*vis).sum(dim=1)/vis.sum(dim=1).clip(min=1)
        vis_query = model_input["pred_visibility"][:,:,track_idxs].unsqueeze(-1)
        worldcrds_pertrack = torch.einsum('ptij,tpj->tpi',poses,hom(eye_surf_track[0,:,track_idxs]))[...,:3]
        worldcrds_pertrack = (worldcrds_pertrack*vis_query).sum(dim=1)/vis_query.sum(dim=1).clip(min=1)
        return out | {
            "worldcrds_pertrack":worldcrds_pertrack,
            "rgb_pertrack":rgb_pertrack,
            "rig_pertrack":rig_samp,
            "res_depth":ch_sec(res_depth),
            "depth":ch_sec(depth),
            # poses of the middle query track, keeping a length-1 leading axis
            "poses":poses[[len(poses)//2]],
            "poses_all":poses,
            "depth_inp":model_input["depth_inp"],
            "point_track_loss":point_track_loss,
            # reprojection from frame 0 into every frame
            "point_track_reproj": point_track_reproj[:,0],
            "corr_weights": general_conf,
            "aff_sim": aff_sim,
            "affinity_emb": affinity_emb,
            "affinity_emb_unnorm": affinity_emb_unnorm,
        }

class PixelNeRFEncoder(nn.Module):
    """Multi-scale ResNet feature encoder (adapted from pixelNeRF).

    Runs the stem and the first ``num_layers - 1`` residual stages of a
    torchvision ResNet, resizes every collected activation to a common
    resolution, concatenates them along channels and maps the result to 512
    channels with a 1x1 conv (``self.out``).
    """

    def __init__(
        self,
        backbone="resnet34",
        pretrained=True,
        num_layers=4,
        index_interp="bilinear",
        index_padding="border",
        upsample_interp="bilinear",
        feature_scale=1.0,
        use_first_pool=True,
        norm_type="batch",
        in_ch=3,
    ):
        """
        Args:
            backbone: name of a ``torchvision.models`` constructor.
            pretrained: forwarded to the torchvision constructor.
            num_layers: how many feature levels to collect (stem counts as one).
            index_interp, index_padding: stored only; not used by ``forward``.
            upsample_interp: ``F.interpolate`` mode used to resize each level.
            feature_scale: input resize factor applied before the backbone.
            use_first_pool: apply the ResNet max-pool before ``layer1``.
            norm_type: ``"batch" | "instance" | "group" | "none"``.
            in_ch: input channels; when != 3 ``conv1`` is replaced with a
                freshly initialised conv, discarding its pretrained weights.
        """
        super().__init__()

        def get_norm_layer(norm_type="instance", group_norm_groups=32):
            """Return a normalization layer
            Parameters:
                norm_type (str) -- the name of the normalization layer: batch | instance | none
            For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
            For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
            """
            if norm_type == "batch":
                norm_layer = functools.partial(
                    nn.BatchNorm2d, affine=True, track_running_stats=True
                )
            elif norm_type == "instance":
                norm_layer = functools.partial(
                    nn.InstanceNorm2d, affine=False, track_running_stats=False
                )
            elif norm_type == "group":
                norm_layer = functools.partial(nn.GroupNorm, group_norm_groups)
            elif norm_type == "none":
                norm_layer = None
            else:
                raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
            return norm_layer

        self.feature_scale = feature_scale
        self.use_first_pool = use_first_pool
        norm_layer = get_norm_layer(norm_type)

        print("Using torchvision", backbone, "encoder")
        self.model = getattr(torchvision.models, backbone)(
            pretrained=pretrained, norm_layer=norm_layer
        )

        if in_ch != 3:
            # NOTE(review): unlike the ResNet stem this conv has a bias (default
            # bias=True); kept as-is for checkpoint compatibility.
            self.model.conv1 = nn.Conv2d(
                in_ch,
                self.model.conv1.weight.shape[0],
                self.model.conv1.kernel_size,
                self.model.conv1.stride,
                self.model.conv1.padding,
                padding_mode=self.model.conv1.padding_mode,
            )

        # Strip the classification head; only the convolutional stages are used.
        self.model.fc = nn.Sequential()
        self.model.avgpool = nn.Sequential()
        # Channels of the concatenated latent for resnet18/34-style widths.
        self.latent_size = [0, 64, 128, 256, 512, 1024][num_layers]

        self.num_layers = num_layers
        self.index_interp = index_interp
        self.index_padding = index_padding
        self.upsample_interp = upsample_interp
        self.register_buffer("latent", torch.empty(1, 1, 1, 1), persistent=False)
        self.register_buffer(
            "latent_scaling", torch.empty(2, dtype=torch.float32), persistent=False
        )

        self.out = nn.Sequential(
            nn.Conv2d(self.latent_size, 512, 1),
        )

    def forward(self, x, custom_size=None):
        """Encode an image batch into a 512-channel feature map.

        Args:
            x: image batch ``(B, C, H, W)``.  Inputs with more than 4 dims have
                their first two axes folded into the batch (recursively) and
                restored on output.
            custom_size: optional output ``(H, W)``; defaults to the resolution
                of the stem (conv1) activation.

        Returns:
            ``self.out`` applied to the concatenated multi-scale features.
            Side effects: caches ``self.latents``, ``self.latent`` and
            ``self.latent_scaling``.
        """
        if len(x.shape) > 4:
            return self(x.flatten(0, 1), custom_size).unflatten(0, x.shape[:2])

        if self.feature_scale != 1.0:
            x = F.interpolate(
                x,
                scale_factor=self.feature_scale,
                mode="bilinear" if self.feature_scale > 1.0 else "area",
                align_corners=True if self.feature_scale > 1.0 else None,
                recompute_scale_factor=True,
            )
        x = x.to(device=self.latent.device)
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)

        latents = [x]
        if self.num_layers > 1:
            if self.use_first_pool:
                x = self.model.maxpool(x)
            x = self.model.layer1(x)
            latents.append(x)
        if self.num_layers > 2:
            x = self.model.layer2(x)
            latents.append(x)
        if self.num_layers > 3:
            x = self.model.layer3(x)
            latents.append(x)
        if self.num_layers > 4:
            x = self.model.layer4(x)
            latents.append(x)

        self.latents = latents
        # F.interpolate only accepts align_corners for interpolating modes and
        # raises for e.g. "nearest"/"area".  The previous check compared
        # index_interp (not the mode actually used) to the misspelled
        # "nearest ", so it always passed True.
        interpolating_modes = ("linear", "bilinear", "bicubic", "trilinear")
        align_corners = True if self.upsample_interp in interpolating_modes else None
        latent_sz = latents[0].shape[-2:]
        for i in range(len(latents)):
            latents[i] = F.interpolate(
                latents[i],
                latent_sz if custom_size is None else custom_size,
                mode=self.upsample_interp,
                align_corners=align_corners,
            )
        self.latent = torch.cat(latents, dim=1)
        self.latent_scaling[0] = self.latent.shape[-1]
        self.latent_scaling[1] = self.latent.shape[-2]
        self.latent_scaling = self.latent_scaling / (self.latent_scaling - 1) * 2.0
        return self.out(self.latent)
