# TODO: this should probably be replaced with a more general COLMAP-based pose estimate (with a rigid-object mask), or with the approach from the LEAP-VO paper.
from pathlib import Path
from typing import Tuple

import numpy as np
import torch
import torchvision.transforms as tf
from einops import rearrange, repeat
from jaxtyping import Float, Int64
from omegaconf import DictConfig
from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset
from tqdm import tqdm

import argparse
def llff_to_opencv_extrinsics(metadata) -> Tensor:
    """Convert LLFF ``poses_bounds.npy`` rows to OpenCV-style camera-to-world matrices.

    Args:
        metadata: Array-like (numpy array or tensor) of shape ``(b, 17)``. Each
            row is a row-major 3x5 matrix ``[R | t | (h, w, f)]`` followed by
            the near/far bounds.

    Returns:
        A float32 tensor of shape ``(b, 4, 4)`` holding the converted
        extrinsics. Intrinsics (h, w, f) and near/far bounds are discarded.

    Raises:
        ValueError: If ``metadata`` is not a 2-D array with 17 columns.
    """
    metadata = torch.as_tensor(metadata)
    if metadata.ndim != 2 or metadata.shape[1] != 17:
        raise ValueError(
            "expected poses_bounds metadata of shape (b, 17), "
            f"got {tuple(metadata.shape)}"
        )
    b = metadata.shape[0]

    # Drop the trailing near/far pair and unflatten the row-major 3x5 matrix.
    # Explicit `b` (not -1) keeps an empty (0, 17) input reshapeable.
    cameras = metadata[:, :-2].reshape(b, 3, 5)
    # cameras[:, :, 4] holds (height, width, focal); it is not needed here.

    # Embed rotation and translation into homogeneous 4x4 matrices. The float32
    # identity fixes the output dtype; float64 inputs are cast on assignment.
    extrinsics = torch.eye(4, dtype=torch.float32).repeat(b, 1, 1)
    extrinsics[:, :3, :3] = cameras[:, :, :3]
    extrinsics[:, :3, 3] = cameras[:, :, 3]

    # Right-multiplying permutes the camera-axis columns: swap columns 0 and 1
    # and negate column 2, leaving the translation column untouched. This maps
    # LLFF's (down, right, back) axes to OpenCV's (right, down, forward).
    conversion = torch.zeros((4, 4), dtype=torch.float32)
    conversion[0, 1] = 1
    conversion[1, 0] = 1
    conversion[2, 2] = -1
    conversion[3, 3] = 1
    return extrinsics @ conversion


def main(argv=None) -> None:
    """Read ``<input_dir>/poses_bounds.npy`` and save ``<output_dir>/poses.pt``.

    Args:
        argv: Command-line arguments to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input_dir', type=str, default="", required=False, help="src dir")
    parser.add_argument('-o', '--output_dir', type=str, default="", required=False, help="where to save poses")
    args = parser.parse_args(argv)

    root = Path(args.input_dir)
    metadata = np.load(root / "poses_bounds.npy")
    extrinsics = llff_to_opencv_extrinsics(metadata)

    # Join with pathlib: the former "%s/poses.pt" formatting turned the default
    # empty --output_dir into "/poses.pt" at the filesystem root. An empty value
    # now means the current directory, and missing directories are created.
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    torch.save(extrinsics.cpu(), out_dir / "poses.pt")


if __name__ == "__main__":
    main()
