import os,time
import torch,wandb
from tqdm import trange
from einops import rearrange
import vis,geometry
from copy import deepcopy
import numpy as np
import piqa,kornia
from torchvision.utils import make_grid
from einops import rearrange, repeat
from our_models import ch_sec
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
import getpass
from glob import glob

import data
import our_models as models

def to_gpu(ob, device="cuda"):
    """Move a batch (possibly a nested dict) onto `device`, converting leaves to tensors.

    Dicts are recursed into and rebuilt with the same keys. Any leaf without a
    `.to` method (lists, numbers, numpy arrays, ...) is first converted with
    `torch.as_tensor`. Existing tensors are moved directly instead of being
    re-wrapped with `torch.tensor`, which warns and makes a redundant copy.

    Args:
        ob: a tensor, array-like, or dict of those (dicts may nest).
        device: target device; defaults to "cuda" as before.

    Returns:
        The same structure with every leaf as a tensor on `device`.
    """
    if isinstance(ob, dict):
        return {k: to_gpu(v, device) for k, v in ob.items()}
    if not hasattr(ob, "to"):
        ob = torch.as_tensor(ob)
    return ob.to(device)

def _flowmap_losses(model_out, model_input):
    """Return the weighted MSE loss terms for one step, keyed by wandb metric name.

    Both arguments are dicts holding "cam_pointmap", "obj_pointmap" and "segs";
    `model_input` holds the targets (its "segs" is cast to float before the diff).
    """
    return {
        "metrics/cam_pointmap_loss": (model_input["cam_pointmap"] - model_out["cam_pointmap"]).square().mean() * 3e2,
        "metrics/obj_pointmap_loss": (model_input["obj_pointmap"] - model_out["obj_pointmap"]).square().mean() * 4e2,
        "metrics/segs": (model_input["segs"].float() - model_out["segs"]).square().mean() * 1e2,
    }


def _make_loader(run, val=False):
    """Return a DataLoader over data.PointCloudFolder(run.args.dirpath) (its val split when `val`)."""
    dataset = data.PointCloudFolder(run.args.dirpath, val=True) if val else data.PointCloudFolder(run.args.dirpath)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=run.args.batch_size,
        num_workers=min(run.args.n_workers, run.args.batch_size),
        pin_memory=True,
    )


def _cycle_batches(loader):
    """Yield batches from `loader` forever, starting a fresh pass whenever one ends.

    Unlike itertools.cycle nothing is cached, so every pass re-iterates the loader.
    Raises RuntimeError if a whole pass yields nothing (would otherwise spin forever).
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise RuntimeError("dataloader yielded no batches; is the dataset empty?")


def train_flowmap(run,until_img=25,until_vid=100,until_save=500,optim=None,single_data=None):
    """Fit `run.model` on point-cloud batches, logging losses and visual summaries to wandb.

    Args:
        run: namespace returned by `make_run`; reads `run.model`, `run.save_dir` and
            `run.args` (lr, n_train_steps, dirpath, batch_size, n_workers, save_model).
        until_img: call `vis.wandb_summary` every `until_img` steps.
        until_vid: unused.
        until_save: overwrite `<save_dir>/checkpoint.pt` every `until_save` steps
            (step 0 excluded) when `run.args.save_model` is set.
        optim: optimizer to step; when None, an Adam over `run.model.parameters()`
            with `lr=run.args.lr` is created (a caller-supplied one is no longer discarded).
        single_data: unused.
    """
    if optim is None:
        optim = torch.optim.Adam(lr=run.args.lr, params=run.model.parameters())

    train_batches = _cycle_batches(_make_loader(run))
    val_batches = None  # built on the first validation step only
    overfit = False  # when True, the first batch is reused for every step

    for step in trange(run.args.n_train_steps, desc="Fitting"):
        val = False  # previously toggled by: 2 < step % 100 < 7 and not overfit
        suffix = "val" if val else ""

        # Get data (exhausted loaders restart transparently, so no step is skipped)
        if overfit:
            print("overfitting on sample")
        if step == 0 or not overfit:
            if val:
                if val_batches is None:
                    val_batches = _cycle_batches(_make_loader(run, val=True))
                batch = next(val_batches)
            else:
                batch = next(train_batches)
            model_input = ground_truth = to_gpu(batch)
            # fg = max of "segs" over dim 1 (presumably the per-part mask channels);
            # ~fg is prepended to "segs" as an extra background channel.
            model_input["fg"] = fg = model_input["segs"].max(dim=1)[0][:, None]
            model_input["segs"] = torch.cat((~fg, model_input["segs"]), 1)

        # Validation steps need no autograd graph.
        with torch.set_grad_enabled(not val):
            out = run.model(model_input)
            # Mask background out of both target and predicted pointmaps.
            for key in ("cam_pointmap", "obj_pointmap"):
                for d in (model_input, out):
                    d[key] = d[key] * fg
            losses = _flowmap_losses(out, model_input)
            total_loss = sum(losses.values())

        # One log call per step; the total carries the same suffix as its terms so
        # validation totals do not land in the training "loss" chart.
        log = {name + suffix: loss.item() for name, loss in losses.items()}
        log["loss" + suffix] = total_loss.item()
        wandb.log(log, step=step)

        if not val:
            total_loss.backward()
            optim.step()
        optim.zero_grad()

        if step % until_img == 0 or val:
            with torch.no_grad():
                vis.wandb_summary(0, out, model_input, ground_truth, None, step=step, prefix=suffix)

        if run.args.save_model and step and step % until_save == 0:
            print(f"Saving to {run.save_dir}")
            torch.save(
                {"step": step, "model_state_dict": run.model.state_dict()},
                os.path.join(run.save_dir, "checkpoint.pt"),
            )

# Data/args setup and run
import argparse
parser = argparse.ArgumentParser(description='simple training job')
# logging parameters
parser.add_argument('-n','--name', type=str,default="",required=False,help="wandb training name")
parser.add_argument('-c','--init_ckpt', type=str,default=None,required=False,help="File for checkpoint loading. If folder specific, will use latest .pt file")
parser.add_argument('-o','--online', default=False, action='store_true')
parser.add_argument('-s','--save_model', default=False, action='store_true')
parser.add_argument('--viser', default=False, action='store_true')
parser.add_argument('--save_opt_vis', default=False, action='store_true')
# data/training parameters
parser.add_argument('-d','--dataset', type=str,default="hydrant")
parser.add_argument('--imgpath', type=str,default="")
parser.add_argument('--dirpath', type=str,default="/data/cameron/robot_calibration_testing/kuka_render/")#rotation_gripper_dataset/")
parser.add_argument('-b','--batch_size', type=int,default=1,help="number of videos/sequences per training step")
parser.add_argument('-v','--vid_len', type=int,default=6,help="video length or number of images per batch")
parser.add_argument('--n_workers',type=int,default=4,help="number of workers per dataloader")
parser.add_argument('--until_save',type=int,default=500,help="number of steps until model save")
parser.add_argument('--lr',type=float,default=1e-4,help="learning rate")
parser.add_argument('--n_train_steps',type=int,default=10000,help="n train steps ")
parser.add_argument('--overfit', default=True, action='store_true',help="Whether to overfit on a single scene")
# --overfit defaults to True, so by itself it can never be switched off; this
# flag writes False into the same dest (defaults kept in agreement).
parser.add_argument('--no_overfit', dest='overfit', default=True, action='store_false',help="Disable --overfit")
parser.add_argument('--until_img', type=int,default=50,help="Number of steps until image summary. ")
parser.add_argument('--sf', type=float,default=1,help="Image resolution scale factor (fractional is cheaper)")
parser.add_argument('--load_save', default=False, action='store_true',help="Whether to load the previously saved data if overfitting (to avoid running flow again)")
# model parameters
# NOTE(review): make_run currently ignores --noncanon (the model selection that read it is commented out).
parser.add_argument('--noncanon', default=False, action='store_true',help="use canon frame model or naive baseline")

# Module-level PyBullet session in DIRECT mode, with the bundled
# pybullet_data assets on the URDF search path.
import pybullet as p
import pybullet_data

physicsClient = p.connect(p.DIRECT)
p.setGravity(0, 0, -9.8)
p.setAdditionalSearchPath(pybullet_data.getDataPath())

# Body id of the Franka Panda URDF loaded from pybullet_data.
rob_id = p.loadURDF("franka_panda/panda.urdf")

def make_run(args=None,val=False):
    """Parse CLI args, initialise wandb, build the model and optionally load a checkpoint.

    Args:
        args: list of CLI tokens for `parser.parse_args`; None reads sys.argv.
        val: unused.

    Returns:
        argparse.Namespace with `args`, `wandb` (the wandb run), `save_dir`, `model`
        and, when --viser is given, `viser_server`.

    Raises:
        FileNotFoundError: if --init_ckpt names a directory containing no *.pt file.
    """
    args = parser.parse_args(args)
    self = argparse.Namespace()
    user = getpass.getuser()
    print(f"user={user}")

    # Wandb init (only "online" when --online is passed)
    run = wandb.init(entity="cameronsmithbusiness",project="biasing",mode="online" if args.online else "disabled",name=args.name,dir="/tmp/wandb")
    wandb.run.log_code(".")
    self.save_dir = "/tmp/"+args.name
    os.makedirs(self.save_dir,exist_ok=True)
    wandb.save(os.path.join(self.save_dir, "checkpoint*"))
    wandb.save(os.path.join(self.save_dir, "video*"))

    self.args=args
    self.wandb=run
    if args.viser:
        # Imported lazily: only the optional viewer needs it (it was previously used unimported).
        import viser
        self.viser_server=viser.ViserServer()

    # Make model and load checkpoint (--noncanon is not consulted here)
    self.model = models.CanonRobotTrajPredTransfEuclidean(args).cuda()
    if args.init_ckpt is not None:
        # Expand "~" once and use the expanded path for the check, the glob AND the load.
        ckpt_path = os.path.expanduser(args.init_ckpt)
        if os.path.isfile(ckpt_path):
            ckpt_file = ckpt_path
        else:
            candidates = glob(os.path.join(ckpt_path, "*.pt"))
            if not candidates:
                raise FileNotFoundError(f"No .pt checkpoint found at {args.init_ckpt!r}")
            ckpt_file = max(candidates, key=os.path.getctime)
        # strict=False: keys missing from / unexpected in the checkpoint are skipped, not fatal.
        self.model.load_state_dict(torch.load(ckpt_file)["model_state_dict"],strict=False)
    return self

if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse make_run / train_flowmap)
    # does not parse the importer's argv and launch a full training run.
    run = make_run()
    torch.autograd.set_detect_anomaly(False)  # anomaly mode is slow; enable only when debugging NaNs
    train_flowmap(
        run,
        until_save=run.args.until_save,
        until_vid=100 if not run.args.overfit else 300,
        until_img=run.args.until_img,
    )
