import os,time
import torch,wandb
from tqdm import trange
from einops import rearrange
import vis,geometry
from copy import deepcopy
import numpy as np
import piqa,kornia
from torchvision.utils import make_grid
from einops import rearrange, repeat
from models import ch_sec
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
import getpass
from glob import glob

import data,models

def to_gpu(ob, device="cuda"):
    """Move a tensor, or a (possibly nested) dict of tensor-like values, to `device`.

    Dict values that are not already tensors (lists, numpy arrays, Python
    scalars) are converted with `torch.as_tensor`; nested dicts are handled
    recursively and keep their structure.

    Args:
        ob: a tensor, a tensor-convertible value, or a dict of those.
        device: target device; defaults to "cuda" (equivalent to `.cuda()`).

    Returns:
        The same structure with every leaf as a tensor on `device`.
    """
    if isinstance(ob, dict):
        # Recurse on the raw value: converting first (torch.tensor(v)) breaks
        # on nested dicts and needlessly copies values that are already tensors.
        return {k: to_gpu(v, device) for k, v in ob.items()}
    if not isinstance(ob, torch.Tensor):
        ob = torch.as_tensor(ob)
    return ob.to(device)

def train_flowmap(run, train_dataset, until_img=25, until_vid=100, until_save=500, optim=None, single_data=None):
    """Fit `run.model` on video clips streamed from an image folder.

    The dataset is built here from `run.args` (`imgpath`, `vid_len`, `n_skip`,
    `sf`); the loader is re-created whenever it is exhausted, and the loop runs
    for `run.args.n_train_steps` iterations.

    Args:
        run: namespace (see `make_run`) providing `model`, `args` and `save_dir`.
        train_dataset: unused; kept for call compatibility.
        until_img: log a `vis.wandb_summary` every this many completed steps.
        until_vid: unused; kept for call compatibility.
        until_save: write `checkpoint.pt` every this many completed steps
            (only when `run.args.save_model` is set).
        optim: optimizer to use; when None an Adam optimizer over
            `run.model.parameters()` with `lr=run.args.lr` is created.
        single_data: unused; kept for call compatibility.

    Raises:
        RuntimeError: if the dataset yields no batches at all.
    """
    # Experimental regularizers on the affinity embedding; disabled by default.
    use_edge_aware_smoothness = False
    use_laplacian_smoothness = False

    def loss_fn(model_out, gt, model_input, model, step):
        """Return a dict of named, already-weighted loss terms for one step.

        `gt`, `model_input` and `model` are accepted for experimentation but
        are not used by any currently active term.
        """
        losses = {}

        if use_edge_aware_smoothness:
            # Penalize embedding gradients, down-weighted where the image has edges.
            feature_map = model_out["affinity_emb"].flatten(0, 1)
            image = model_out["imgs_premask"].flatten(0, 1) * .5 + .5
            edges = kornia.filters.sobel(kornia.color.rgb_to_grayscale(image))
            edge_weights = torch.exp(-10 * edges.abs())  # 1 where flat, decays toward 0 on strong edges
            grad_x = feature_map[:, :, :, :-1] - feature_map[:, :, :, 1:]
            grad_y = feature_map[:, :, :-1, :] - feature_map[:, :, 1:, :]
            smoothness_loss = ((edge_weights[:, :, :, :-1] * grad_x.pow(2)).mean()
                               + (edge_weights[:, :, :-1, :] * grad_y.pow(2)).mean())
            losses["metrics/smoothness"] = smoothness_loss * 1e4

        if use_laplacian_smoothness:
            fmap = model_out["affinity_emb"].squeeze(1)
            losses["metrics/smoothness"] = kornia.filters.laplacian(fmap, 5).abs().mean() * 1e3

        if "contrastive_loss" in model_out:
            losses["metrics/track_classification"] = model_out["contrastive_loss"] * 1e1

        if step % 10 == 0:
            print(losses, step)
        return losses

    if optim is None:
        optim = torch.optim.Adam(lr=run.args.lr, params=run.model.parameters())

    dataset = data.ImageFolder(path="/data/cameron/monocular_ests/davis/" + run.args.imgpath,
                               num_trgt=run.args.vid_len + 1, n_skip=run.args.n_skip, sf=run.args.sf)

    def make_loader():
        # Single source of loader settings so restarts use the same worker count.
        return iter(torch.utils.data.DataLoader(
            dataset,
            batch_size=run.args.batch_size,
            num_workers=min(run.args.n_workers, run.args.batch_size),
            shuffle=True,
            pin_memory=True,
        ))

    dataloader = make_loader()

    # `step` counts completed optimizer steps; skipped iterations don't advance it.
    step = 0
    for _ in trange(run.args.n_train_steps, desc="Fitting"):
        try:
            model_input, ground_truth = next(dataloader)
        except StopIteration:
            print("done w dataset")
            dataloader = make_loader()
            try:
                model_input, ground_truth = next(dataloader)
            except StopIteration:
                raise RuntimeError("dataset produced no batches") from None
        model_input, ground_truth = to_gpu(model_input), to_gpu(ground_truth)

        # Run model and calculate losses
        out = run.model(model_input)
        losses = loss_fn(out, ground_truth, model_input, run.model, step)
        if not losses:
            print("No loss terms produced, skipping backpropagation.")
            continue
        total_loss = sum(losses.values())

        log_dict = {loss_name: loss.item() for loss_name, loss in losses.items()}
        log_dict["loss"] = total_loss.item()
        wandb.log(log_dict, step=step)

        if not torch.isfinite(total_loss):
            print("NaN or Inf detected in loss, skipping backpropagation.")
            continue
        total_loss.backward()
        optim.step()
        optim.zero_grad()

        # Image summaries and checkpoint
        if step == 0:
            run.time = time.time()
        if step % until_img == 0:
            with torch.no_grad():
                vis.wandb_summary(0, out, model_input, ground_truth, None, step=step)
        if step % until_save == 0 and step and run.args.save_model:
            print(f"Saving to {run.save_dir}")
            torch.save({'step': step, 'model_state_dict': run.model.state_dict()},
                       os.path.join(run.save_dir, "checkpoint.pt"))

        if run.args.export_poses and step % 100 == 0:  # export poses+geom if overfitting on single scene
            with torch.no_grad():
                out = run.model.forward_allpts(model_input)  # predictions for all point tracks
            out["intrinsics"] = model_input["intrinsics"]
            if "c2w" in model_input:
                out["gt_poses"], out["gt_intrinsics"] = model_input["c2w"], model_input["gt_intrinsics"]
            torch.save({k: v.detach() for k, v in out.items() if isinstance(v, torch.Tensor) and v.dim() > 1},
                       f"/data/cameron/pose_exps/poses_{run.args.name}.pt")
            print("exported poses")
        if run.args.export_feats and step % 100 == 0:  # manual feature export todo refactor
            # Detach so the saved tensor carries no autograd state.
            torch.save(out["affinity_emb"].detach(), f"/data/cameron/feat_exp_exps/featexp_{run.args.name}_oursaff.pt")
            torch.save(model_input["fmap"], f"/data/cameron/feat_exp_exps/featexp_{run.args.name}_oursfmap.pt")
        step += 1

# Data/args setup and run
import argparse
parser = argparse.ArgumentParser(description='simple training job')
# Flags that default to True keep action='store_true' (passing them is a no-op,
# so existing command lines still work) and are each paired with a
# '--no_<flag>' store_false switch sharing the same dest, which turns them off.
# logging parameters
parser.add_argument('-n','--name', type=str,default="",required=False,help="wandb training name")
parser.add_argument('-c','--init_ckpt', type=str,default=None,required=False,help="File for checkpoint loading. If folder specific, will use latest .pt file")
parser.add_argument('-o','--online', default=False, action='store_true')
parser.add_argument('-s','--save_model', default=True, action='store_true')
parser.add_argument('--no_save_model', dest='save_model', action='store_false', help="Disable checkpoint saving")
parser.add_argument('--viser', default=False, action='store_true')
parser.add_argument('--save_opt_vis', default=False, action='store_true')
# data/training parameters
parser.add_argument('-d','--dataset', type=str,default="hydrant")
parser.add_argument('--imgpath', type=str,default="")
parser.add_argument('-b','--batch_size', type=int,default=1,help="number of videos/sequences per training step")
parser.add_argument('-v','--vid_len', type=int,default=6,help="video length or number of images per batch")
parser.add_argument('--n_workers',type=int,default=4,help="number of workers per dataloader")
parser.add_argument('--until_save',type=int,default=500,help="number of steps until model save")
parser.add_argument('--lr',type=float,default=1e-4,help="learning rate")
parser.add_argument('--n_train_steps',type=int,default=int(1e8),help="number of training iterations")
parser.add_argument('--overfit', default=True, action='store_true',help="Whether to overfit on a single scene")
parser.add_argument('--no_overfit', dest='overfit', action='store_false', help="Do not overfit on a single scene")
parser.add_argument('--until_img', type=int,default=50,help="Number of steps until image summary. ")
parser.add_argument('--sf', type=float,default=1,help="Image resolution scale factor (fractional is cheaper)")
parser.add_argument('--load_save', default=False, action='store_true',help="Whether to load the previously saved data if overfitting (to avoid running flow again)")
parser.add_argument('--splat_src', type=str,default=None,required=False,help="splat src pt file")
# model parameters
parser.add_argument('--time_stride', type=int,default=1,help="Number of frames to flatten in the encoder like a temporal convolution (to save memory). ")
parser.add_argument('--n_skip', type=int,default=1,help="Number of frames to skip between adjacent frames in dataloader. ")
parser.add_argument('--use_gt_intrinsics', default=True, action='store_true',help="Whether to use GT intrinsics instead of predicting them. Useful for pretraining scene rep.")
parser.add_argument('--no_use_gt_intrinsics', dest='use_gt_intrinsics', action='store_false', help="Predict intrinsics instead of using GT intrinsics")
parser.add_argument('--point_track', default=True, action='store_true',help="Whether to use point tracking")
parser.add_argument('--no_point_track', dest='point_track', action='store_false', help="Disable point tracking")
parser.add_argument('--export_poses', default=False, action='store_true',help="Export poses when overfitting")
parser.add_argument('--export_feats', default=False, action='store_true',help="Export feats when overfitting")

def make_run(args=None, val=False):
    """Parse CLI args, start wandb, build the model and optionally load a checkpoint.

    Args:
        args: argument list for `parser.parse_args`; None means the process's
            command line.
        val: unused; kept for call compatibility.

    Returns:
        argparse.Namespace with `args`, `wandb` (the object returned by
        `wandb.init`), `save_dir`, `dataset` (always None here), `model` (a
        `models.FlowMap` moved to CUDA) and, with `--viser`, `viser_server`.

    Raises:
        FileNotFoundError: if `--init_ckpt` is neither a file nor a directory
            containing `*.pt` files.
    """
    # Toggle: set False to skip loading checkpoint weights whose key contains "aff".
    load_affinity_weights = True

    args = parser.parse_args(args)
    run_state = argparse.Namespace()
    user = getpass.getuser()
    print(f"user={user}")

    # Wandb init; mode is "disabled" unless --online is passed.
    wb_run = wandb.init(entity="cameronsmithbusiness", project="biasing",
                        mode="online" if args.online else "disabled",
                        name=args.name, dir="/tmp/wandb")
    wandb.run.log_code(".")
    run_state.save_dir = "/tmp/" + args.name
    os.makedirs(run_state.save_dir, exist_ok=True)
    wandb.save(os.path.join(run_state.save_dir, "checkpoint*"))
    wandb.save(os.path.join(run_state.save_dir, "video*"))

    run_state.args = args
    run_state.wandb = wb_run
    if args.viser:
        # Imported lazily: the viewer is optional and only needed with --viser.
        import viser
        run_state.viser_server = viser.ViserServer()

    # No dataset is prepared here; train_flowmap constructs its own.
    run_state.dataset = None
    print("dummy multivid test")

    # Make model and load checkpoint
    run_state.model = models.FlowMap(args).cuda()
    if args.init_ckpt is not None:
        # Expand "~" once and use the expanded path for every lookup and load.
        ckpt_path = os.path.expanduser(args.init_ckpt)
        if os.path.isfile(ckpt_path):
            ckpt_file = ckpt_path
        else:
            candidates = glob(os.path.join(ckpt_path, "*.pt"))
            if not candidates:
                raise FileNotFoundError(f"No checkpoint file or *.pt files found at {args.init_ckpt!r}")
            ckpt_file = max(candidates, key=os.path.getctime)  # newest .pt in the folder

        state_dict = torch.load(ckpt_file)["model_state_dict"]
        if not load_affinity_weights:
            state_dict = {k: v for k, v in state_dict.items() if "aff" not in k}
            print("not loading aff emb")
        # strict=False: missing/unexpected keys are tolerated.
        run_state.model.load_state_dict(state_dict, strict=False)

    return run_state

# Script entry: build the run state from the command line, then train.
run = make_run()
# Debug mode: backward raises on NaN and reports the forward op that produced
# the failing gradient; PyTorch documents that it slows execution.
torch.autograd.set_detect_anomaly(True)
train_flowmap(
    run,
    run.dataset,
    until_img=run.args.until_img,
    until_vid=300 if run.args.overfit else 100,
    until_save=run.args.until_save,
)
