import matplotlib.pyplot as plt; 
import cv2
import os
import statistics
import multiprocessing as mp
import torch.nn.functional as F
import torch
import random
import imageio
import numpy as np
from glob import glob
from collections import defaultdict
from pdb import set_trace as pdb
from itertools import combinations
from random import choice
import matplotlib.pyplot as plt
import imageio.v3 as iio

from torchvision import transforms

import sys

from glob import glob
import os
import gzip
import json
import numpy as np

from einops import rearrange, repeat

sys.path.append("/home/cameronsmith/repos/multivid_point_track_sfm/")
from data import make_sample

def ch_sec(x):
    """Move the channel axis last and flatten the two trailing spatial axes.

    Rearranges ``(..., c, x, y)`` into ``(..., x * y, c)``; equivalent to
    einops ``rearrange(x, "... c x y -> ... (x y) c")``.
    """
    # (..., c, x, y) -> (..., x, y, c) -> (..., x*y, c); row-major (x y) order.
    return x.movedim(-3, -1).flatten(-3, -2)


def hom(x, i=-1):
    """Append a size-1 slice of ones to ``x`` along dimension ``i``.

    With the default ``i=-1`` this turns points ``(..., d)`` into homogeneous
    coordinates ``(..., d + 1)``. The ones share ``x``'s dtype and device.
    """
    ones = torch.ones_like(x.narrow(i, 0, 1))
    return torch.cat((x, ones), i)

import argparse

# Command-line options. --input_dir is expanded with glob() by the loop below,
# so it is a pattern matching the directories to package, not a single dir.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-i",
    "--input_dir",
    type=str,
    default="/data/cameron/monocular_ests/pets_dogs/*",
    required=False,
    help="dir to pkg",
)
args = parser.parse_args()

from tqdm import tqdm

n_trgt = 10  # number of frames kept per output sample
num_skip = 3  # temporal stride between kept frames
sf = 1  # scale factor applied to the low-res output size (author also tried .6)

paths = glob(args.input_dir)

# ---------------------------------------------------------------------------
# Loop-invariant settings (hoisted out of the per-directory loop).
# ---------------------------------------------------------------------------
# Set True to skip directories that already contain the output package. Off by
# default, matching the previous hard-disabled check, which also tested the
# wrong file name ("lowrespkg.pt" instead of the "hirespkg.pt" written below).
SKIP_EXISTING = False
# Indices of the frames kept per sample: n_trgt frames, num_skip apart.
frame_sl = slice(None, n_trgt * num_skip, num_skip)
gs = 1  # spatial stride over the track grid (1 keeps every track)
track_sl = 64  # tracks are unfolded as a track_sl x track_sl grid (42 was an alternative)
h, s = 3, 1  # hires_factor passed to make_sample, and the budget divisor term
budget = 192 * 640 / (8 // s)
# Output sizes are given in reversed order relative to (128, 224) / (640, 1024).
# An order chosen from frame width vs. height was considered but is not used.
low_res = [int(224 * sf), int(128 * sf)]
hi_res = [1024, 640]

for path in tqdm(paths):
    out_path = os.path.join(path, "hirespkg.pt")
    if SKIP_EXISTING and os.path.exists(out_path):
        print("already done for ", path)
        continue

    # --- Load per-directory inputs ------------------------------------------
    try:
        raw_tracks = list(torch.load(os.path.join(path, "pred_tracks_offline.pt")))
    except Exception:  # not a bare except: Ctrl-C / SystemExit must still abort
        print("skipping bc no pred tracks offline for ", path)
        continue
    imgs = torch.load(os.path.join(path, "imgs.pt"))
    # Keep index 0 of axis 1, preserved as a size-1 axis.
    rig_flow_masks = torch.load(os.path.join(path, "rig_flow_masks.pt"))[:, :1]
    # Fold the leading group axis g into the point axis, then take batch entry 0.
    all_tracks = rearrange(raw_tracks[0], "g b t p c -> b t (p g) c")[0]
    all_visibility = rearrange(raw_tracks[1], "g b t p -> b t (p g)")[0]
    f = torch.load(os.path.join(path, "intrinsics.pt"))
    print("Done loading data")

    # --- Images and intrinsics ----------------------------------------------
    frames = imgs
    if frames.max() > 2:
        # Values above 2 are taken to mean a 0-255 range; rescale to [0, 1].
        frames = frames / 255

    # One 3x3 matrix per image: principal point 0.5 (presumably normalized
    # image coordinates), fx = f, fy = f * size(-1) / size(-2).
    # NOTE(review): assumes imgs is laid out (..., H, W) — confirm.
    intrinsics = repeat(torch.eye(3), "i j -> b i j", b=len(imgs)).clone()
    intrinsics[:, :2, 2] = 0.5
    intrinsics[:, 0, 0] = f
    intrinsics[:, 1, 1] = f * imgs.size(-1) / imgs.size(-2)

    org_ratio = frames[0].size(-2) / frames[0].size(-1)

    # --- Subsample tracks in time and (optionally) space ---------------------
    pred_tracks = all_tracks[frame_sl]
    pred_visibility = all_visibility[frame_sl]
    try:
        pred_tracks = rearrange(
            rearrange(pred_tracks, "t (x y s) c -> (t s) c x y", y=track_sl, x=track_sl)[..., ::gs, ::gs],
            "(t s) c x y -> t (x y s) c",
            t=n_trgt,
        )
        # Same unfold as above; kept inside the try so a shape mismatch skips
        # the directory instead of crashing the whole run.
        pred_visibility = rearrange(
            rearrange(pred_visibility, "t (x y s) -> (t s) x y", y=track_sl, x=track_sl)[..., ::gs, ::gs],
            "(t s) x y -> t (x y s)",
            t=n_trgt,
        )
    except Exception:
        print("skipping because unfolding error")
        continue

    # --- Build and save the package -----------------------------------------
    sample = {
        "intrinsics": intrinsics[frame_sl],
        "rgb": frames[frame_sl] * 2 - 1,  # linear map [0, 1] -> [-1, 1]
        "org_ratio": org_ratio,
        # One fewer mask than frames; presumably one per consecutive frame
        # pair — confirm against make_sample.
        "rig_flow_masks": rig_flow_masks[frame_sl][:-1],
        "pred_tracks": pred_tracks,
        "pred_visibility": pred_visibility,
    }
    out_dict = make_sample(
        sample,
        1 / org_ratio,
        hires_factor=h,
        budget=budget,
        low_res=low_res,
        hi_res=hi_res,
    )

    n_frames = out_dict[0]["rgb"].size(0)
    if n_frames != n_trgt:
        # Previously compared against a hard-coded 10 and still printed the
        # output path even though nothing was saved.
        print(f"skipping {path}: expected {n_trgt} frames, got {n_frames}")
        continue
    torch.save(out_dict, out_path)
    print(out_path)
