from pathlib import Path
from typing import Tuple

import matplotlib.pyplot as plt; 
import cv2
import os
import multiprocessing as mp
import torch.nn.functional as F
import torch
import random
import imageio
import numpy as np
from glob import glob
from collections import defaultdict
from pdb import set_trace as pdb
from itertools import combinations
from random import choice
import matplotlib.pyplot as plt
import imageio.v3 as iio

from torchvision import transforms

import sys

from glob import glob
import os
import gzip
import json
import numpy as np

import argparse
# Command-line interface: both directories are optional and default to "".
parser = argparse.ArgumentParser()
parser.add_argument(
    "-i",
    "--input_dir",
    type=str,
    default="",
    required=False,
    help="src dir",
)
parser.add_argument(
    "-o",
    "--output_dir",
    type=str,
    default="",
    required=False,
    help="where to save poses",
)
args = parser.parse_args()

# Convert the CO3D frame annotations of one sequence into a stacked tensor of
# camera-to-world matrices and save it as <output_dir>/poses.pt.

# NOTE(review): ``root`` is derived from --input_dir but is not used below; the
# annotations are read from the hard-coded ``base_path`` instead.
root = Path(args.input_dir)  # e.g. "/data/nerf_llff_data/"

# Directory holding per-category folders; annotations are read from
# <base_path>/<cat>/frame_annotations.jgz.
base_path = "/data/co3dhydrants/co3d/hydrants/"

from collections import defaultdict

# CO3D category names, kept for reference; only ``cat`` below is processed.
all_cats = [
    "hydrant", "teddybear", "apple", "ball", "bench", "cake", "donut", "plant",
    "suitcase", "vase", "backpack", "banana", "baseballbat", "baseballglove",
    "bicycle", "book", "bottle", "bowl", "broccoli", "car", "carrot",
    "cellphone", "chair", "couch", "cup", "frisbee", "hairdryer", "handbag",
    "hotdog", "keyboard", "kite", "laptop", "microwave", "motorcycle", "mouse",
    "orange", "parkingmeter", "pizza", "remote", "sandwich", "skateboard",
    "stopsign", "toaster", "toilet", "toybus", "toyplane", "toytrain",
    "toytruck", "tv", "umbrella", "wineglass",
]

cat = "hydrant"
# The single sequence whose frames are exported.
seq_name = "106_12648_23157"
print(cat)

# frame_annotations.jgz is gzip-compressed UTF-8 JSON holding a list of
# per-frame dicts; the context manager guarantees the handle is closed.
with gzip.open(
    os.path.join(base_path, cat, "frame_annotations.jgz"), "rt", encoding="utf8"
) as fh:
    dataset = json.load(fh)

# Keep only this sequence's frames, ordered by frame index.
seq_data = sorted(
    (frame for frame in dataset if frame["sequence_name"] == seq_name),
    key=lambda frame: frame["frame_number"],
)
if not seq_data:
    # torch.stack([]) would otherwise fail with an unhelpful message.
    raise SystemExit(
        f"no frames found for sequence {seq_name!r} in category {cat!r}"
    )

# Sign flip applied to the 3x4 world-to-camera [R|T]: negates its x and y rows.
# NOTE(review): pose processing follows a co3d GitHub issue; presumably this
# converts CO3D (PyTorch3D) camera axes to the downstream convention — confirm.
flip_xy = np.diag([-1, -1, 1]).astype(np.float32)

c2ws = []
for frame in seq_data:
    viewpoint = frame["viewpoint"]
    # Only extrinsics are exported; focal length and image size are not used.
    # R is transposed as in the referenced issue (presumably because PyTorch3D
    # stores R for row-vector multiplication — confirm).
    R = np.asarray(viewpoint["R"]).T
    T = np.asarray(viewpoint["T"])
    # float32 4x4 world-to-camera; the [R|T] block is cast on assignment.
    w2c = torch.eye(4)
    w2c[:3, :4] = torch.from_numpy(
        flip_xy @ np.concatenate([R, T[:, None]], axis=1)
    )
    # Inverting world-to-camera yields camera-to-world.
    c2ws.append(w2c.inverse())
c2w = torch.stack(c2ws).float()  # (num_frames, 4, 4)

# os.path.join keeps the default --output_dir "" from writing to "/poses.pt";
# in that case the file lands in the current working directory.
if args.output_dir:
    os.makedirs(args.output_dir, exist_ok=True)
torch.save(c2w.cpu(), os.path.join(args.output_dir, "poses.pt"))
