import os
import cv2
import copy
import math
import argparse
import numpy as np
from time import time
from tqdm import tqdm
from easydict import EasyDict
import imageio
from einops import rearrange

import wandb

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import matplotlib.pyplot as plt 

from torcheval.metrics import FrechetInceptionDistance

from data import get_metadata, get_dataset, fix_legacy_dict
import unets, models, diffusion

from torchvision.utils import make_grid

import vis_scene_graph
import sgvis2

def to_gpu(ob):
    """Move ``ob`` to the current CUDA device, recursing through nested dicts.

    A dict is rebuilt key-by-key with each value processed recursively; any
    other object is moved by calling its ``.cuda()`` method.
    """
    if not isinstance(ob, dict):
        return ob.cuda()
    return {key: to_gpu(value) for key, value in ob.items()}

class loss_logger:
    """Record per-step losses and report an exponentially smoothed value.

    Args:
        max_steps: expected total number of steps; only shown in the progress
            message printed by ``log``.
    """

    def __init__(self, max_steps):
        self.max_steps = max_steps
        self.loss = []          # every value passed to ``log``, in call order
        self.start_time = time()  # wall-clock reference for the elapsed-time readout
        self.ema_loss = None    # stays None until the first value is logged
        self.ema_w = 0.9        # weight kept from the previous EMA value

    def log(self, v, display=False):
        """Append ``v``, update the running EMA and optionally print progress.

        The first logged value seeds the EMA directly; afterwards the EMA is
        ``ema_w * previous + (1 - ema_w) * v``.
        """
        self.loss.append(v)
        previous = self.ema_loss
        self.ema_loss = v if previous is None else self.ema_w * previous + (1 - self.ema_w) * v
        if not display:
            return
        elapsed_hours = (time() - self.start_time) / 3600
        print(
            f"Steps: {len(self.loss)}/{self.max_steps} \t loss (ema): {self.ema_loss:.3f} "
            + f"\t Time elapsed: {elapsed_hours:.3f} hr"
        )

def _str2bool(value):
    """Interpret a command-line value as a boolean flag.

    Without a ``type``, argparse passes the raw string through. Any non-empty
    string is truthy, so ``--online False`` used to *enable* online logging.
    Explicit false-like spellings now map to ``False``. Every other value keeps
    its previous truthy meaning, so existing invocations such as ``--online 1``
    or ``--online True`` behave as before.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() not in ("", "0", "false", "f", "no", "n", "off")


def _grid_to_numpy(grid):
    """Convert a CHW image grid to an HWC float numpy array clipped to [0, 1]."""
    return grid.permute(1, 2, 0).float().detach().clip(0, 1).cpu().numpy()


def main():
    """Train a joint RGB / scene-graph diffusion model and log to Weights & Biases.

    Parses CLI arguments and builds the model named by ``--arch`` together with
    a ``diffusion.GaussianDiffusion`` process. It optionally loads a pretrained
    checkpoint, then iterates over the dataset for ``--epochs`` epochs. Each step:

    * noises ``rgb`` (and ``scene_graph`` / ``cameras`` when the batch has them)
      at one shared random timestep ``t`` per batch element;
    * runs the model and sums MSE losses between output keys that also appear in
      the input dict, plus any model-returned ``*loss*`` terms. Only keys
      containing ``"rgb"`` or ``"eps"`` contribute;
    * periodically logs image grids and scene-graph samples to wandb;
    * updates an EMA copy of the weights (``args.ema_dict``). Every 1000 steps,
      rank 0 saves both the raw and the EMA weights to ``--save-dir``.
    """
    parser = argparse.ArgumentParser("Minimal implementation of diffusion models")
    # diffusion model
    parser.add_argument("--arch", type=str, help="Neural network architecture")
    parser.add_argument("--class-cond", action="store_true", default=False, help="train class-conditioned diffusion model")
    parser.add_argument("--diffusion-steps", type=int, default=1000, help="Number of timesteps in diffusion process")
    parser.add_argument("--sampling-steps", type=int, default=250, help="Number of timesteps used when sampling")
    parser.add_argument("--ddim", action="store_true", default=False, help="Sampling using DDIM update step")
    parser.add_argument("--no-diffusion", action="store_true", default=False, help="whether to use diffusion target")
    parser.add_argument("--no-labels", action="store_true", default=False, help="whether to use the gt labels")
    parser.add_argument("--always-skip", action="store_true", default=False, help="whether to use the skip connection always")
    parser.add_argument("--no-gt-labels", action="store_true", default=False, help="whether to use gt labels target if dataset provides it")
    # dataset
    parser.add_argument("--dataset", type=str)
    parser.add_argument("--data-dir", type=str, default="./dataset/")
    # optimizer
    parser.add_argument("--batch-size", type=int, default=128, help="batch-size per gpu")
    parser.add_argument("--lr", type=float, default=0.0001)
    parser.add_argument("--epochs", type=int, default=100000000)
    parser.add_argument("--ema_w", type=float, default=0.9995)
    # sampling/finetuning
    parser.add_argument("--pretrained-ckpt", type=str, help="Pretrained model ckpt")
    parser.add_argument("--delete-keys", nargs="+", help="Keys to drop from the pretrained ckpt before loading")
    parser.add_argument("--sampling-only", action="store_true", default=False, help="No training, just sample images (will save them in --save-dir)")
    parser.add_argument("--num-sampled-images", type=int, default=50000, help="Number of images required to sample from the model")
    # misc
    parser.add_argument("--save-dir", type=str, default="./trained_models/")
    parser.add_argument("--name", type=str, default="test")
    parser.add_argument("--n_workers", default=8, type=int)
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument("--seed", default=112233, type=int)
    # "--online", "--online 1" and "--online True" enable; "--online False"/"0" now really disable.
    parser.add_argument("--online", type=_str2bool, nargs="?", const=True, default=False, required=False)
    parser.add_argument("--no_sampling", type=_str2bool, nargs="?", const=True, default=False, required=False)

    # setup
    args = parser.parse_args()
    metadata = get_metadata(args.dataset)
    torch.backends.cudnn.benchmark = True
    args.device = "cuda:{}".format(args.local_rank)
    torch.cuda.set_device(args.device)
    # Offset the seeds by rank so separate processes draw different noise.
    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    if args.local_rank == 0:
        print(args)

    wandb.init(entity="cameronsmithbusiness", project="biasing", mode="online" if args.online else "disabled", name=args.name, dir="/tmp/wandb")
    wandb.run.log_code(".")

    # Create model and diffusion process
    model = models.__dict__[args.arch](
        image_size=metadata.image_size,
        in_channels=metadata.num_channels,
        out_channels=metadata.num_channels,
        num_classes=metadata.num_classes if args.class_cond else None,
    ).to(args.device)

    # Optional experiment: freeze the decoder.
    # for param in model.output_blocks.parameters(): param.requires_grad = False
    # for param in model.out.parameters(): param.requires_grad = False

    diffusion_ = diffusion.GaussianDiffusion(args.diffusion_steps, args.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    # load pre-trained model
    if args.pretrained_ckpt:
        print(f"Loading pretrained model from {args.pretrained_ckpt}")
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        d = fix_legacy_dict(torch.load(args.pretrained_ckpt, map_location=args.device))
        dm = model.state_dict()
        for k in args.delete_keys or []:
            if k not in d:
                print(f"Key {k} not found in ckpt; nothing to delete")
                continue
            model_shape = dm[k].shape if k in dm else "missing"
            print(f"Deleting key {k} because its shape in ckpt ({d[k].shape}) doesn't match " + f"with shape in model ({model_shape})")
            del d[k]
        model.load_state_dict(d, strict=False)
        print("Mismatched keys in ckpt and model: ", set(d.keys()) ^ set(dm.keys()))
        print(f"Loaded pretrained model from {args.pretrained_ckpt}")

    # Load dataset
    train_set = get_dataset(args.dataset, args.data_dir, metadata)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, sampler=None, num_workers=args.n_workers, pin_memory=True)
    if args.local_rank == 0:
        print(f"Training dataset loaded: Number of batches: {len(train_loader)}, Number of images: {len(train_set)}")
    logger = loss_logger(len(train_loader) * args.epochs)

    # ema model: an independent copy of the weights, updated after every step.
    args.ema_dict = copy.deepcopy(model.state_dict())

    # Per-key multipliers for the MSE terms. NOTE(review): the rgb/eps filter in
    # the loss loop below means these keys currently never reach the weighting.
    loss_weights = {"scene_graph": 6, "seg": 1e1, "invdepth": 5e0, "cams": 1e-2}
    # Model output/input keys that are never rendered as image grids.
    grid_skip_substrings = ["loss", "eps", "cameras", "cams", "idx", "global_latent", "clip", "bboxs0", "bboxs1", "point_cloud"]

    # lets start training the model
    global_step = -1
    for epoch in range(args.epochs):

        model.train()
        for step, model_input in tqdm(enumerate(train_loader), total=len(train_loader)):
            global_step += 1

            model_input = to_gpu(model_input)
            log_dict = {}

            # NOTE(review): the original code states images must be in [-1, 1];
            # presumably the dataset already returns that range -- confirm.
            images = model_input["rgb"]
            t = torch.randint(diffusion_.timesteps, (len(images),), dtype=torch.int64).to(args.device)
            xt_rgb, eps_rgb = diffusion_.sample_from_forward_process(images, t)
            if not args.no_diffusion:
                model_input["noised_rgb"] = xt_rgb
                model_input["eps_rgb"] = eps_rgb

            # (b, o, c) tensors are reshaped to the 5-D image-like layout the
            # forward process takes, noised with the same t, then reshaped back.
            for key in ("scene_graph", "cameras"):
                if key not in model_input:
                    continue
                xt, eps = diffusion_.sample_from_forward_process(rearrange(model_input[key], "b o c -> b 1 c o 1"), t)
                model_input["noised_" + key] = rearrange(xt, "b 1 c o 1 -> b o c ")
                model_input["eps_" + key] = rearrange(eps, "b 1 c o 1 -> b o c ")

            # Disabled experiment ("and 0"): with --no-diffusion, feed the clean
            # tensors as the "noised_*" inputs, zero the eps targets and use t=0.
            if args.no_diffusion and 0:
                for k, v in model_input.items():
                    if "noised" in k: model_input[k] = model_input[k[7:]]
                    if "eps" in k: model_input[k] = torch.zeros_like(v)
                t = torch.zeros_like(t)

            # Teacher forcing alternates deterministically on even/odd steps of
            # the epoch; the skip connection itself is always on.
            teacher_forcing = step % 2 == 0

            # Run model
            model_out = model(model_input, t, use_skip=True, teacher_forcing=teacher_forcing)

            # Losses: MSE for output keys that also exist in the input, or the
            # mean of any model-provided "*loss*" key -- restricted to rgb/eps keys.
            loss = 0
            for k in model_out:
                if k not in model_input and "loss" not in k:
                    continue
                if "rgb" not in k and "eps" not in k:
                    continue
                if "loss" in k:
                    loss_ = model_out[k].mean()
                else:
                    # Targets whose shape does not match the prediction are skipped.
                    if model_input[k].shape != model_out[k].shape:
                        continue
                    loss_ = (model_input[k] - model_out[k]).square().mean()

                if k in loss_weights:
                    loss_ = loss_ * loss_weights[k]

                log_dict["metrics/%s_loss" % k] = loss_.item()
                loss += loss_
            # Use below if doing autodecoder
            # log_dict["metrics/latent_penalty"]=model_out["global_latent"].square().mean()/2
            # loss += log_dict["metrics/latent_penalty"]
            print(log_dict)

            save_local = False  # debug switch: dump visualizations to ./images/ and stop

            optimizer.zero_grad()
            if not save_local and log_dict:
                loss.backward()
                optimizer.step()
            else:
                print("not doing backward")

            if global_step % 100 in (0, 1) or (global_step % 30 == 0 and global_step < 100):  # Visualize images
                print("Logging images")
                img_grids = []

                # (name, scene graph, cameras) triples to render as 3D boxes.
                scene_graphs = []
                gt_cams = model_input.get("cameras")
                if gt_cams is not None:
                    if "scene_graph" in model_input: scene_graphs += [("scene_graph_inp", model_input["scene_graph"], gt_cams)]
                    if "noised_scene_graph" in model_input: scene_graphs += [("noised_scene_graph_inp", model_input["noised_scene_graph"], gt_cams)]
                    if "scene_graph" in model_out: scene_graphs += [("scene_graph_pred_gtcam", model_out["scene_graph"], gt_cams)]
                if "scene_graph" in model_out and "cameras" in model_out:
                    scene_graphs += [("scene_graph_pred", model_out["scene_graph"], model_out["cameras"])]

                if "bboxs" in model_input:
                    min_bbox_vals, max_bbox_vals = model_input["bboxs"][0].unbind(1)
                    bboxs_img = sgvis2.visualize_colored_point_cloud_matplotlib(
                        model_input["point_cloud"][0].cpu(),
                        model_input["rgb"][0, 0].permute(1, 2, 0).flatten(0, 1).clip(-1, 1).cpu() * .5 + .5,
                        None, min_bbox_vals.cpu(), max_bbox_vals.cpu())
                    model_input["inp_vis_bbox"] = torch.nn.functional.interpolate(torch.from_numpy(bboxs_img).permute(2, 0, 1)[None], scale_factor=.5)[:, None].float() / 255 * 2 - 1

                if "scene_graph_intermed" in model_out:
                    scene_graphs += [("scene_graph_intermed",
                                      rearrange(model_out["scene_graph_intermed"][:1], "b 1 t c o 1 -> (b t) o c"),
                                      rearrange(model_out["cameras_intermed"][:1], "b 1 t c 1 1 -> (b t) 1 c"))]

                for k, sg, cam_src in scene_graphs:
                    bbox_vis_render = vis_scene_graph.plot_3d_bounding_boxes((sg if save_local else sg[:8]).detach().cpu(), cam_src.detach().cpu())
                    bbox_vis = torch.stack([torch.from_numpy(x[0]) for x in bbox_vis_render]).permute(0, 3, 1, 2)
                    renderer_vis = [x[1] for x in bbox_vis_render]
                    model_input["%s_vis_bbox" % k] = torch.nn.functional.interpolate(bbox_vis, scale_factor=.5)[:, None].float() / 255 * 2 - 1
                    model_input["%s_vis_renderer" % k] = torch.stack(renderer_vis).permute(0, 3, 1, 2)[:, None]

                for pref, dic in [("input/", model_input), ("output/", model_out)]:
                    for k, v in dic.items():
                        if (any(x in k for x in grid_skip_substrings)
                                or k.endswith("graph") or ("graph" in k and "vis" not in k)):
                            continue
                        v = v.flatten(0, 1)
                        if not save_local:
                            v = v[:16]  # just log first N batch elements for size constraints
                        # if img is two channel instead of 3, split it
                        tmp = [("", v)] if v.size(1) != 2 else [("0", v[:, [0]]), ("1", v[:, [1]])]
                        for suf, v_ in tmp:
                            print(k)
                            nrow = model_input["rgb"].size(1) if "intermed" not in k else 5
                            img_grids.append((pref + k + suf, make_grid(v_.cpu().detach(), nrow=nrow, normalize=True, scale_each=False)))
                            if "rgb" == k and 0:  # disabled: also log rgb as a video
                                wandb.log({"vid/" + pref + k + suf:
                                    wandb.Video(((v_ * .5 + .5).unflatten(0, (-1, 8)).permute(1, 2, 0, 3, 4).flatten(2, 3).cpu().detach() * 256).to(torch.uint8).numpy(), fps=6, format="mp4")
                                    }, step=global_step)

                if save_local:
                    print("saving images locally")
                    for k, v in img_grids:
                        try:
                            plt.imsave("images/" + k + ".png", _grid_to_numpy(v))
                        except Exception as err:  # best effort: one bad grid must not stop the dump
                            print(f"could not save {k}: {err}")
                    if 0 and model_out["rgb"].size(1) > 1:  # disabled: write output video
                        vid_frames = ((model_out["rgb"] * .5 + .5) * 255).permute(1, 3, 0, 4, 2).flatten(2, 3).detach().cpu().numpy().astype('uint8')
                        with imageio.get_writer('images/model_out.mp4', fps=8) as writer:
                            for x in vid_frames:
                                writer.append_data(x)
                    if args.no_diffusion:
                        raise SystemExit("save_local: visualizations written to images/, stopping")
                wandb.log({"img/" + k: wandb.Image(_grid_to_numpy(v)) for k, v in img_grids}, step=global_step)

            # Scene-graph sampling. Slicing [0, 1] to [:1] disables the extra
            # direct-prediction sample at step % 100 == 1, so ``direct`` is False here.
            if global_step % 100 in [0, 1][:1] and not args.no_diffusion:
                print("doing scene graph sampling")
                n_frames = model_input["noised_rgb"].size(1)
                n_gen = 8
                direct = global_step % 100 == 1
                sampled_sg, intermed_samples_ = diffusion.sample_N_images(
                    n_gen, model, diffusion_,
                    {"rgb": model_input["rgb"]},
                    args.sampling_steps, n_gen, metadata.num_channels,
                    metadata.image_size, n_frames, metadata.num_classes, args, scene_graph=True,
                    conditioning={k: model_input[k][:n_gen] for k in ["bboxs", "clip_embs"]},
                    use_direct=direct)

                if "rgb" in sampled_sg:
                    suffix = "direct" if direct else ""
                    sampled_grid = make_grid(sampled_sg["rgb"].flatten(0, 1) * .5 + .5, nrow=n_frames, normalize=False)
                    wandb.log({"sampled_images" + suffix: wandb.Image(_grid_to_numpy(sampled_grid))}, step=global_step)
                    intermed_grid_rgb = make_grid(intermed_samples_["rgb"].flatten(0, 2) * .5 + .5, nrow=intermed_samples_["rgb"].size(2), normalize=False)
                    wandb.log({"intermed_images" + suffix: wandb.Image(_grid_to_numpy(intermed_grid_rgb))}, step=global_step)

                    if save_local:
                        plt.imsave("images/" + "sampled_rgb_intermed.png", _grid_to_numpy(intermed_grid_rgb))

                if "scene_graph" in sampled_sg:
                    sampled_sg_vis = vis_scene_graph.plot_3d_bounding_boxes(sampled_sg["scene_graph"].detach().cpu(), sampled_sg["cameras"].cpu())
                    sampled_sg_vis, sampled_sg_render_vis = [x[0] for x in sampled_sg_vis], [x[1] for x in sampled_sg_vis]
                    sampled_sg_render_vis = torch.stack(sampled_sg_render_vis).permute(0, 3, 1, 2)
                    sampled_sg_vis = torch.stack([torch.from_numpy(x) for x in sampled_sg_vis]).permute(0, 3, 1, 2)
                    sampled_sg_vis = torch.nn.functional.interpolate(sampled_sg_vis, scale_factor=.5).float()[:, None] / 255 * 2 - 1

                    intermed_sg_vis = vis_scene_graph.plot_3d_bounding_boxes(
                        intermed_samples_["scene_graph"].flatten(0, 2).squeeze().detach().cpu().permute(0, 2, 1),
                        intermed_samples_["cameras"].flatten(0, 2).squeeze(-1))
                    intermed_sg_vis, intermed_render_vis = [x[0] for x in intermed_sg_vis], [x[1] for x in intermed_sg_vis]
                    intermed_sg_vis = torch.stack([torch.from_numpy(x) for x in intermed_sg_vis]).permute(0, 3, 1, 2)
                    intermed_sg_vis = torch.nn.functional.interpolate(intermed_sg_vis, scale_factor=.5).float() / 255 * 2 - 1
                    intermed_render_vis = torch.stack(intermed_render_vis).permute(0, 3, 1, 2)

                    # Grid row length for intermediate samples; only the first two rows are logged.
                    n_intermed = intermed_samples_["scene_graph"].size(2)

                    sampled_sg_grid = make_grid(sampled_sg_vis.flatten(0, 1) * .5 + .5, nrow=n_frames, normalize=False)
                    wandb.log({"sampled_bboxs": wandb.Image(_grid_to_numpy(sampled_sg_grid))}, step=global_step)

                    sampled_render_grid = make_grid(sampled_sg_render_vis * .5 + .5, nrow=n_frames, normalize=False)
                    wandb.log({"sampled_renderer_vis": wandb.Image(_grid_to_numpy(sampled_render_grid))}, step=global_step)

                    sampled_grid_renderer = make_grid(intermed_render_vis[:n_intermed * 2] * .5 + .5, nrow=n_intermed, normalize=False)
                    wandb.log({"sampled_renderer_vis_intermed": wandb.Image(_grid_to_numpy(sampled_grid_renderer))}, step=global_step)

                    intermed_grid = make_grid(intermed_sg_vis[:n_intermed * 2] * .5 + .5, nrow=n_intermed, normalize=False)
                    wandb.log({"intermed_bboxes": wandb.Image(_grid_to_numpy(intermed_grid))}, step=global_step)

                    if save_local:
                        plt.imsave("images/" + "sampled_sg_vis.png", _grid_to_numpy(sampled_sg_grid))
                        raise SystemExit("save_local: scene-graph samples written to images/, stopping")

            # Disabled ("and 0"): plain image-diffusion sampling and FID.
            if global_step % 100 == 0 and not args.no_diffusion and "eps_rgb" in model_out and 0:  # Sample images
                print("doing sampling")
                n_frames = model_input["noised_rgb"].size(1)
                n_gen = 8
                sampled_images, intermed_samples_ = diffusion.sample_N_images(n_gen, model, diffusion_, model_input["rgb"], args.sampling_steps, n_gen, metadata.num_channels,
                                                                               metadata.image_size, n_frames, metadata.num_classes, args)
                sampled_grid = make_grid(sampled_images.flatten(0, 1) * .5 + .5, nrow=n_frames, normalize=False)
                wandb.log({"sampled_images": wandb.Image(_grid_to_numpy(sampled_grid))}, step=global_step)

                intermed_grid = make_grid(intermed_samples_.flatten(0, 2) * .5 + .5, nrow=intermed_samples_.size(2), normalize=False)
                wandb.log({"intermed_images": wandb.Image(_grid_to_numpy(intermed_grid))}, step=global_step)

                if save_local:
                    plt.imsave("images/" + "intermed_images.png", _grid_to_numpy(intermed_grid))
                    if model_out["rgb"].size(1) > 1:
                        vid_frames_ = ((rearrange(intermed_samples_[:, :, :], "v b f c x y ->  b (x) (f v y) c").cpu().numpy() * .5 + .5) * 255).astype('uint8')
                        with imageio.get_writer('images/intermed_vids.mp4', fps=8) as writer:
                            for x in vid_frames_:
                                writer.append_data(x)
                        vid_frames_ = ((rearrange(intermed_samples_[:, :, [-1]], "v b f c x y ->  b (x) (f v y) c").cpu().numpy() * .5 + .5) * 255).astype('uint8')
                        with imageio.get_writer('images/final_vids.mp4', fps=8) as writer:
                            for x in vid_frames_:
                                writer.append_data(x)
                    raise SystemExit("save_local: sampled images written to images/, stopping")

                print("done sampling")
                if global_step % (550 * 1) == 0 and 0:  # FID comp
                    print("Calculating FID")
                    fid = FrechetInceptionDistance()
                    # At least one batch, otherwise torch.cat([]) would fail.
                    n_real_batches = max(1, 500 // (args.batch_size * n_frames))
                    real_imgs = torch.cat([next(iter(train_loader))["rgb"] for _ in range(n_real_batches)])
                    print("got real images")
                    if real_imgs.size(1) == 1:
                        real_imgs = real_imgs.expand(-1, 3, -1, -1)
                        sampled_images = sampled_images.expand(-1, 3, -1, -1)
                    if real_imgs.min() < 0: real_imgs = real_imgs * .5 + .5
                    if sampled_images.min() < 0: sampled_images = sampled_images * .5 + .5
                    fid.update(real_imgs.flatten(0, 1).clip(0, 1), is_real=True)
                    fid.update(sampled_images.flatten(0, 1).clip(0, 1), is_real=False)
                    fid_score = fid.compute()
                    print(f"FID: {fid_score:.3f}")
                    log_dict |= {"metrics/FID": fid_score.item()}

            # ``loss`` stays the int 0 when no loss term was collected; float()
            # handles both that and a 0-d tensor (one device sync instead of two).
            loss_value = float(loss)
            log_dict = log_dict | {"metrics/denoising_loss": loss_value}
            wandb.log(log_dict, step=global_step)

            # update ema_dict (in place for float tensors to avoid reallocating a full model copy per step)
            new_dict = model.state_dict()
            with torch.no_grad():
                for k, ema_v in args.ema_dict.items():
                    if ema_v.is_floating_point():
                        ema_v.mul_(args.ema_w).add_(new_dict[k], alpha=1 - args.ema_w)
                    else:
                        args.ema_dict[k] = args.ema_w * ema_v + (1 - args.ema_w) * new_dict[k]
            logger.log(loss_value, display=not step % 100)

            if global_step % 1000 == 0 and global_step and args.local_rank == 0:
                print("saving model")
                os.makedirs(args.save_dir, exist_ok=True)
                # Same file name every time: each save overwrites the previous checkpoint.
                ckpt_stem = os.path.join(args.save_dir, f"{args.name}_{args.dataset}-epoch_{args.epochs}")
                torch.save(model.state_dict(), ckpt_stem + ".pt")
                torch.save(args.ema_dict, ckpt_stem + "_ema.pt")


if __name__ == "__main__":
    # Autograd anomaly detection is explicitly kept off (it slows training);
    # flip the flag to True when tracking down NaN/inf gradients.
    with torch.autograd.set_detect_anomaly(False):
        main()
