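"""Debug silhouette alignment between DROID exterior-camera frames and MuJoCo
renders of the Franka arm.

For a few episodes this sweeps Euler conventions for the 6-D extrinsics,
compares the original (parquet) extrinsics against the posed ones, sweeps the
vertical FOV with and without the hand model, and writes side-by-side
real/render comparisons to OUTPUT_DIR.

Usage:
    MUJOCO_GL=egl python debug_alignment.py
"""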

import json
import os
import numpy as np
import pandas as pd
import imageio.v3 as iio
import cv2
import mujoco
from scipy.spatial.transform import Rotation as R

# Dataset root and the manifest of episodes with posed exterior-2 extrinsics.
DROID_ROOT = "/data/cameron/droid"
MANIFEST_PATH = os.path.join(DROID_ROOT, "manifest_posed_ext2.json")

# MuJoCo XML models of the Franka arm: without the hand, and with the hand.
_FRANKA_DIR = "/data/cameron/para/droid_testing/franka_panda"
FRANKA_XML = f"{_FRANKA_DIR}/panda_nohand.xml"
FRANKA_XML_HAND = f"{_FRANKA_DIR}/panda.xml"

# Where the debug PNGs are written.
OUTPUT_DIR = "/data/cameron/para_droid_pretrain/posed_droid/silhouette_debug"

# Size (pixels) of every render and every resized real frame / grid panel.
IMG_W = 640
IMG_H = 360


def load_franka_assets(xml_path):
    """Read every regular file in ``<directory of xml_path>/assets`` into memory.

    Returns a dict mapping ``os.path.join("assets", filename)`` to the file's
    raw bytes, in the form passed as the ``assets`` argument of
    ``mujoco.MjModel.from_xml_string``. Subdirectories are skipped (no
    recursion). If the ``assets`` directory does not exist, returns ``{}``.
    """
    asset_dir = os.path.join(os.path.dirname(os.path.abspath(xml_path)), "assets")
    if not os.path.isdir(asset_dir):
        return {}

    loaded = {}
    for name in os.listdir(asset_dir):
        full_path = os.path.join(asset_dir, name)
        if not os.path.isfile(full_path):
            continue
        with open(full_path, "rb") as fh:
            loaded[os.path.join("assets", name)] = fh.read()
    return loaded


# Asset bytes handed to every MjModel compiled in render_robot. FRANKA_XML and
# FRANKA_XML_HAND sit in the same directory, so this one dict serves both.
# Note: this reads the asset files at import time.
ASSETS = load_franka_assets(xml_path=FRANKA_XML)


def ext6d_to_cam_pose(ext_6d, euler_convention="xyz"):
    """Turn a 6-D extrinsic vector into a MuJoCo camera position and rotation.

    ``ext_6d`` is ``[x, y, z, a, b, c]``: a translation followed by three Euler
    angles. The angles are interpreted by SciPy's ``Rotation.from_euler`` with
    ``euler_convention`` (radians, SciPy's default). In SciPy, lower case means
    intrinsic and upper case means extrinsic. This function applies exactly one
    convention; sweeping conventions is the caller's job.

    Returns ``(cam_pos, cam_rot)``:
      * ``cam_pos`` is the translation as a new array.
      * ``cam_rot`` is the 3x3 rotation with its Y and Z columns negated. This
        maps OpenCV camera axes (y down, z forward) onto MuJoCo camera axes
        (y up, looking along -z).

    NOTE(review): this treats the rotation's columns as the camera axes in the
    world (camera-to-world). Confirm against the extrinsics' definition.
    """
    angles = ext_6d[3:6]
    translation = np.array(ext_6d[:3])

    rot_cv = R.from_euler(euler_convention, angles).as_matrix()
    # Right-multiplying by diag(1, -1, -1) negates columns 1 and 2.
    cv_to_mujoco = np.diag([1, -1, -1])
    rot_mj = rot_cv @ cv_to_mujoco

    return translation, rot_mj


def render_robot(xml_path, joints, cam_pos, cam_rot, width, height, fovy=50.0):
    """Render the robot model from an injected camera and return ``(rgb, mask)``.

    A camera named ``"posed"`` is spliced into the XML text immediately after
    the ``<light name="top" pos="0 0 2" mode="trackcom"/>`` element. Its
    position is ``cam_pos``. Its orientation is given by MuJoCo ``xyaxes``,
    taken from the first two columns of ``cam_rot``. The model is compiled with
    the module-level ``ASSETS``. ``joints[:7]`` is written into ``qpos[:7]``
    before ``mj_forward``.

    Args:
        xml_path: MuJoCo XML file to load.
        joints: joint positions; only the first seven entries are used.
        cam_pos: camera position (3 values).
        cam_rot: 3x3 matrix whose columns are the MuJoCo camera x, y, z axes.
        width, height: render resolution in pixels.
        fovy: value for the camera's ``fovy`` attribute (MuJoCo: vertical
            field of view in degrees).

    Returns:
        rgb: the colour render returned by ``mujoco.Renderer.render``.
        mask: uint8 array, 1 where rendered depth < 0.99 * the frame's maximum
            depth, else 0. This treats the farthest pixels as background.
            TODO confirm this holds when geometry fills the whole view.

    Raises:
        ValueError: if the light anchor is not in the XML, so the camera could
            not be injected. Without this check, ``mj_name2id`` returns -1 and
            the scene would silently render from MuJoCo's free camera.
    """
    anchor = '<light name="top" pos="0 0 2" mode="trackcom"/>'
    with open(xml_path) as f:
        xml = f.read()
    if anchor not in xml:
        raise ValueError(
            f"{xml_path}: anchor {anchor!r} not found; cannot inject the 'posed' camera"
        )

    cam_x = cam_rot[:, 0]
    cam_y = cam_rot[:, 1]
    cam_str = (
        f'<camera name="posed" pos="{cam_pos[0]} {cam_pos[1]} {cam_pos[2]}" '
        f'xyaxes="{cam_x[0]} {cam_x[1]} {cam_x[2]} {cam_y[0]} {cam_y[1]} {cam_y[2]}" '
        f'fovy="{fovy}"/>'
    )
    # Insert only once so a repeated anchor cannot create duplicate camera names.
    xml = xml.replace(anchor, f"{anchor}\n    {cam_str}", 1)

    model = mujoco.MjModel.from_xml_string(xml, ASSETS)
    data = mujoco.MjData(model)
    data.qpos[:7] = joints[:7]
    mujoco.mj_forward(model, data)

    cam_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "posed")
    if cam_id < 0:
        # Defensive: -1 would make update_scene fall back to the free camera.
        raise ValueError(f"{xml_path}: injected camera 'posed' not found in compiled model")

    renderer = mujoco.Renderer(model, height=height, width=width)
    try:
        renderer.update_scene(data, camera=cam_id)
        rgb = renderer.render().copy()

        renderer.enable_depth_rendering()
        renderer.update_scene(data, camera=cam_id)
        depth = renderer.render().copy()
        renderer.disable_depth_rendering()
    finally:
        # Release the GL context even when rendering fails; callers catch
        # exceptions and keep rendering other configurations.
        renderer.close()

    mask = (depth < depth.max() * 0.99).astype(np.uint8)
    return rgb, mask


def contour_overlay(real_frame, mask, color=(0, 255, 0), thickness=2):
    """Outline and tint the silhouette ``mask`` on top of ``real_frame``.

    ``real_frame`` is first resized to the mask's height and width. The outer
    contours of ``mask`` are drawn with ``thickness`` on one copy and filled
    solid on another copy. The two copies are blended 0.7 / 0.3, so the
    interior gets a 30% ``color`` tint and the outline is mostly ``color``.

    Returns the blended image.
    """
    height, width = mask.shape[:2]
    base = cv2.resize(real_frame, (width, height))
    outer, _hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    outlined = base.copy()
    cv2.drawContours(outlined, outer, -1, color, thickness)

    filled = base.copy()
    for poly in outer:
        cv2.fillPoly(filled, [poly], color)

    return cv2.addWeighted(outlined, 0.7, filled, 0.3, 0)


def _caption(img, text, color=(255, 255, 255), scale=0.5):
    """Draw ``text`` at the top-left of ``img`` (in place) and return ``img``."""
    cv2.putText(img, text, (5, 20), cv2.FONT_HERSHEY_SIMPLEX, scale, color, 1)
    return img


def _fail_panel(text):
    """Return a black IMG_H x IMG_W panel captioned with a failure message."""
    return _caption(np.zeros((IMG_H, IMG_W, 3), dtype=np.uint8), text, scale=0.4)


def _overlay_panel(real, jpos, ext, convention, xml_path, fovy, text, color=(255, 255, 255)):
    """Render the robot for extrinsics ``ext`` and return a captioned silhouette overlay."""
    cam_pos, cam_rot = ext6d_to_cam_pose(ext, euler_convention=convention)
    _rgb, mask = render_robot(xml_path, jpos, cam_pos, cam_rot, IMG_W, IMG_H, fovy=fovy)
    return _caption(contour_overlay(real, mask), text, color)


def _tile(panels, ncols=2):
    """Tile equally sized panels into rows of ``ncols``, padding with black panels."""
    panels = list(panels)
    while len(panels) % ncols:
        panels.append(np.zeros_like(panels[0]))
    rows = [np.concatenate(panels[i:i + ncols], axis=1) for i in range(0, len(panels), ncols)]
    return np.concatenate(rows, axis=0)


def _save_rgb(path, img, what):
    """Write RGB ``img`` to ``path`` (OpenCV expects BGR) and report the outcome."""
    if cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)):
        print(f"  {what}: {path}")
    else:
        print(f"  WARNING: could not write {what} to {path}")


def main():
    """Write silhouette-alignment debug images for a few hand-picked episodes.

    For each manifest episode whose ``ep_idx`` is 2, 4 or 9, this writes three
    images to OUTPUT_DIR:
      * ``ep{ep_idx:06d}_euler_sweep.png``: the real middle frame, overlays for
        several Euler conventions using the posed extrinsics, and an overlay
        for the original parquet extrinsics with ``xyz``.
      * ``ep{ep_idx:06d}_fov_model.png``: overlays for several ``fovy`` values,
        each with and without the hand model.
      * ``ep{ep_idx:06d}_side_by_side.png``: the real frame next to the raw
        MuJoCo render, at the first, middle and last frames.

    A failed render is printed and replaced by a captioned black panel, so one
    bad configuration does not abort the sweep.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    with open(MANIFEST_PATH) as f:
        manifest = json.load(f)
    episodes = manifest["episodes"]

    # Use ep9 since user flagged it — also try ep2 and ep4
    wanted = {2, 4, 9}
    test_eps = [e for e in episodes if e["ep_idx"] in wanted]
    if not test_eps:
        print(f"No manifest episodes with ep_idx in {sorted(wanted)}")

    # Euler conventions to try (SciPy: lower case = intrinsic, upper = extrinsic)
    euler_conventions = [
        "xyz",   # intrinsic xyz (what friend's code uses)
        "XYZ",   # extrinsic XYZ
        "zyx",   # intrinsic zyx (common RPY interpretation)
        "ZYX",   # extrinsic ZYX
        "xzy",   # intrinsic xzy
        "yxz",   # intrinsic yxz
    ]

    for ep_info in test_eps:
        ep_idx = ep_info["ep_idx"]
        chunk = ep_idx // 1000
        print(f"\n{'='*60}")
        print(f"Episode {ep_idx} — {ep_info['posed_ep_id']}")

        # Load data
        parquet = os.path.join(DROID_ROOT, f"data/chunk-{chunk:03d}/episode_{ep_idx:06d}.parquet")
        vid = os.path.join(DROID_ROOT, f"videos/chunk-{chunk:03d}/observation.images.exterior_2_left/episode_{ep_idx:06d}.mp4")
        df = pd.read_parquet(parquet)
        joints = np.stack(df["observation.state.joint_position"].values)
        frames = iio.imread(vid, plugin="pyav")

        n = min(len(joints), len(frames))
        if n == 0:
            print("  No frames or joint states; skipping episode.")
            continue
        mid = n // 2
        real = frames[mid]
        jpos = joints[mid]

        # Get both original (parquet) and posed extrinsics
        orig_ext = np.array(df["camera_extrinsics.exterior_2_left"].iloc[0])
        posed_ext = np.array(ep_info["posed_ext2"])

        print(f"  Original ext: {orig_ext}")
        print(f"  Posed ext:    {posed_ext}")
        print(f"  Diff:         {np.abs(orig_ext - posed_ext)}")
        print(f"  Max diff:     {np.max(np.abs(orig_ext - posed_ext)):.6f}")

        # === Test 1: Euler convention sweep with posed extrinsics ===
        panels = [_caption(cv2.resize(real, (IMG_W, IMG_H)), f"Real ep{ep_idx} f={mid}")]
        for conv in euler_conventions:
            try:
                panels.append(_overlay_panel(real, jpos, posed_ext, conv, FRANKA_XML, 50.0,
                                             f"posed euler={conv}"))
            except Exception as e:
                print(f"  euler={conv} failed: {e}")
                panels.append(_fail_panel(f"euler={conv}: FAIL {str(e)[:50]}"))

        # Also test original extrinsics with xyz convention
        try:
            panels.append(_overlay_panel(real, jpos, orig_ext, "xyz", FRANKA_XML, 50.0,
                                         "ORIGINAL euler=xyz", color=(0, 255, 255)))
        except Exception as e:
            print(f"  original extrinsics failed: {e}")
            panels.append(_fail_panel(f"ORIGINAL: FAIL {str(e)[:50]}"))

        # real + 6 conventions + original = 8 panels -> 2 columns x 4 rows
        _save_rgb(os.path.join(OUTPUT_DIR, f"ep{ep_idx:06d}_euler_sweep.png"),
                  _tile(panels), "Euler sweep")

        # === Test 2: FOV sweep with the xyz convention, with and without the hand ===
        fov_panels = []
        for fovy in [42, 50, 58, 69]:
            for xml_path, label in [(FRANKA_XML, "nohand"), (FRANKA_XML_HAND, "hand")]:
                try:
                    fov_panels.append(_overlay_panel(real, jpos, posed_ext, "xyz", xml_path, fovy,
                                                     f"fov={fovy} {label}"))
                except Exception as e:
                    print(f"  fov={fovy} {label} failed: {e}")
                    fov_panels.append(_fail_panel(f"fov={fovy} {label}: FAIL {str(e)[:50]}"))

        _save_rgb(os.path.join(OUTPUT_DIR, f"ep{ep_idx:06d}_fov_model.png"),
                  _tile(fov_panels), "FOV+model")

        # === Test 3: Side-by-side MuJoCo render vs real at multiple frames ===
        # Pure side-by-side: real | render (no overlay, easier to compare)
        sbs_rows = []
        for fi in [0, mid, n - 1]:
            real_r = _caption(cv2.resize(frames[fi], (IMG_W, IMG_H)), f"Real f={fi}")
            try:
                cam_pos, cam_rot = ext6d_to_cam_pose(posed_ext, "xyz")
                rgb, _mask = render_robot(FRANKA_XML, joints[fi], cam_pos, cam_rot,
                                          IMG_W, IMG_H, fovy=50.0)
                render = _caption(rgb, f"MuJoCo f={fi}")
            except Exception as e:
                print(f"  side-by-side f={fi} failed: {e}")
                render = _fail_panel(f"MuJoCo f={fi}: FAIL {str(e)[:50]}")
            sbs_rows.append(np.concatenate([real_r, render], axis=1))

        _save_rgb(os.path.join(OUTPUT_DIR, f"ep{ep_idx:06d}_side_by_side.png"),
                  np.concatenate(sbs_rows, axis=0), "Side-by-side")

    print(f"\nAll outputs: {OUTPUT_DIR}")


if __name__ == "__main__":
    # Run the sweep only when executed as a script. Importing the module still
    # reads the Franka asset files via the module-level ASSETS assignment.
    main()
