"""
Generate a video showing DINO PCA feature consistency under camera viewpoint
rotation and object position shifts in LIBERO.

The video runs in two phases: first the bowl and plate slide across the table
with the camera fixed, then the camera orbits the frozen scene. DINO PCA
colors stay consistent on each object throughout, demonstrating that modern
image features are consistent across viewpoints and object positions.

Usage:
    export PYTHONPATH=/data/cameron/LIBERO:$PYTHONPATH
    export DINO_REPO_DIR=/data/cameron/keygrip/dinov3
    export DINO_WEIGHTS_PATH=/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth
    python generate_dino_pca_video.py --output_dir /data/cameron/para/.agents/reports/project_site/media/
"""

import argparse
import os
import sys
import numpy as np
import torch
import cv2
from sklearn.decomposition import PCA

# ── LIBERO imports ──
from libero.libero.envs import OffScreenRenderEnv
from libero.libero import benchmark as bm_lib, get_libero_path
import h5py

# ── Constants ──
IMAGE_SIZE = 448
DINO_PATCH_SIZE = 16
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])
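# With a 448x448 input and 16-pixel patches, the backbone yields a 28x28 grid
# of patch tokens (448 / 16 = 28, so 784 tokens per frame).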


def load_dino_backbone(device):
    """Load frozen DINO ViT-S/16 backbone."""
    repo = os.environ.get("DINO_REPO_DIR", "/data/cameron/keygrip/dinov3")
    weights = os.environ.get("DINO_WEIGHTS_PATH",
        "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")

    sys.path.insert(0, repo)
    model = torch.hub.load(repo, "dinov3_vits16plus", source="local", weights=weights)
    model = model.to(device).eval()
    for p in model.parameters():
        p.requires_grad = False
    return model


def extract_dino_features(model, rgb_uint8, device):
    """Extract patch features from a single RGB image.

    Args:
        model: DINO model
        rgb_uint8: (H, W, 3) uint8 numpy array
        device: torch device

    Returns:
        patch_features: (H_p, W_p, C) numpy array of patch tokens
        H_p, W_p: patch-grid height and width
    """
    # Normalize to ImageNet stats
    img = rgb_uint8.astype(np.float32) / 255.0
    img = (img - IMAGENET_MEAN) / IMAGENET_STD
    img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float().to(device)

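    # Run the ViT manually: tokenize, pass each block its rotary position
    # embedding, then apply the final norm. This follows the structure of the
    # DINOv3 forward pass and exposes per-patch tokens rather than the pooled
    # CLS output.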
    with torch.no_grad():
        x_tokens, (H_p, W_p) = model.prepare_tokens_with_masks(img_tensor)
        rope_sincos = model.rope_embed(H=H_p, W=W_p) if model.rope_embed is not None else None
        for blk in model.blocks:
            x_tokens = blk(x_tokens, rope_sincos)
        x_tokens = model.norm(x_tokens)

        n_storage = getattr(model, "n_storage_tokens", 0)
        patch_tokens = x_tokens[:, n_storage + 1:]  # skip CLS + storage
        patch_features = patch_tokens.reshape(H_p, W_p, -1)

    return patch_features.cpu().numpy(), H_p, W_p


def compute_pca_images(features_list, rgb_list, n_components=3):
    """Compute joint PCA across all frames and produce pure PCA images.

    Args:
        features_list: list of (H_p, W_p, C) feature arrays
        rgb_list: list of (H, W, 3) uint8 images (for resolution reference)
        n_components: PCA components (3 for RGB)

    Returns:
        list of (H, W, 3) uint8 PCA images (no blending, pure PCA)
    """
    H, W = rgb_list[0].shape[:2]
    H_p, W_p = features_list[0].shape[:2]

    # Stack all features for joint PCA
    all_feats = np.concatenate([f.reshape(-1, f.shape[-1]) for f in features_list], axis=0)
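    # A single PCA fit over all frames' patches (together with the shared
    # min/max normalization below) keeps each object's color stable across
    # frames; fitting PCA per frame would let the color mapping flicker.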

    pca = PCA(n_components=n_components)
    all_pca = pca.fit_transform(all_feats)

    # Normalize globally to [0, 1]
    pca_min = all_pca.min(axis=0)
    pca_max = all_pca.max(axis=0)
    pca_range = pca_max - pca_min
    pca_range[pca_range == 0] = 1.0
    all_pca = (all_pca - pca_min) / pca_range

    # Split back into per-frame
    n_patches = H_p * W_p
    pca_images = []
    for i in range(len(rgb_list)):
        pca_frame = all_pca[i * n_patches:(i + 1) * n_patches]
        pca_rgb = pca_frame.reshape(H_p, W_p, n_components)

        # Upsample to image resolution
        pca_upsampled = cv2.resize(pca_rgb, (W, H), interpolation=cv2.INTER_LINEAR)
        pca_uint8 = np.clip(pca_upsampled * 255, 0, 255).astype(np.uint8)
        pca_images.append(pca_uint8)

    return pca_images


def _compute_cam_from_spherical(look_at, radius, default_dir_norm, phi_rad, theta_rad):
    """Compute camera position and quaternion from spherical offsets.

    Args:
        look_at: (3,) world point to look at
        radius: distance from look_at
        default_dir_norm: (3,) unit vector from look_at to default camera
        phi_rad: azimuth offset in radians
        theta_rad: elevation offset in radians

    Returns:
        cam_pos: (3,) camera position
        cam_quat: (4,) MuJoCo quaternion (w, x, y, z)
    """
    from scipy.spatial.transform import Rotation as R

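    # Parametrize the pose as look_at + radius * d, where d is the default
    # view direction rotated by phi about world Z (azimuth), then by theta
    # about the local right axis (elevation).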
    # Azimuth rotation around world Z
    cos_phi, sin_phi = np.cos(phi_rad), np.sin(phi_rad)
    rot_z = np.array([
        [cos_phi, -sin_phi, 0],
        [sin_phi,  cos_phi, 0],
        [0,        0,       1]
    ])
    rotated_dir = rot_z @ default_dir_norm

    # Elevation rotation
    right = np.cross(np.array([0, 0, 1.0]), rotated_dir)
    if np.linalg.norm(right) > 1e-6:
        right = right / np.linalg.norm(right)
        rot_elev = R.from_rotvec(theta_rad * right).as_matrix()
        rotated_dir = rot_elev @ rotated_dir

    cam_pos = look_at + radius * rotated_dir

    # Look-at quaternion (MuJoCo convention: -z = forward, y = up, x = right)
    forward = (look_at - cam_pos)
    forward = forward / np.linalg.norm(forward)
    world_up = np.array([0, 0, 1.0])
    right = np.cross(forward, world_up)
    if np.linalg.norm(right) < 1e-6:
        right = np.array([1, 0, 0.0])
    right = right / np.linalg.norm(right)
    up = np.cross(right, forward)
    up = up / np.linalg.norm(up)

    cam_mat = np.stack([right, up, -forward], axis=1)
    cam_rot = R.from_matrix(cam_mat)
    quat_xyzw = cam_rot.as_quat()  # scipy: [x, y, z, w]
    cam_quat = np.array([quat_xyzw[3], quat_xyzw[0],
                          quat_xyzw[1], quat_xyzw[2]])  # MuJoCo: [w, x, y, z]

    return cam_pos, cam_quat


def generate_smooth_trajectory(n_frames, env, sim, cam_id, state_0):
    """Generate a two-phase trajectory:
      Phase 1: Object slides left→right, camera stays at default.
      Phase 2: Object freezes at final position, camera orbits.

    Returns:
        trajectory: list of (dx, dy, cam_pos, cam_quat) tuples
        phase_boundary: index where phase 2 starts
    """
    from scipy.spatial.transform import Rotation as R

    # Get default camera parameters
    default_pos = sim.model.cam_pos[cam_id].copy()
    default_quat = sim.model.cam_quat[cam_id].copy()

    default_rot = R.from_quat([default_quat[1], default_quat[2],
                                default_quat[3], default_quat[0]])  # xyzw
    default_mat = default_rot.as_matrix()
    forward_dir = -default_mat[:, 2]

    # Compute look-at point
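    # Intersect the default view ray with the table plane; table_z = 0.85 m is
    # an assumed approximate tabletop height for this scene, not queried from
    # the MuJoCo model.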
    table_z = 0.85
    if abs(forward_dir[2]) > 1e-6:
        t = (table_z - default_pos[2]) / forward_dir[2]
        look_at = default_pos + t * forward_dir
    else:
        look_at = default_pos + 0.5 * forward_dir
        look_at[2] = table_z

    radius = np.linalg.norm(default_pos - look_at)
    default_dir = default_pos - look_at
    default_dir_norm = default_dir / np.linalg.norm(default_dir)

    # Split frames: half for object motion, half for camera orbit
    n_phase1 = n_frames // 2
    n_phase2 = n_frames - n_phase1

    trajectory = []

    # ── Phase 1: Object moves, camera stays at default ──
    dy_range = np.linspace(-0.12, 0.12, n_phase1)
    dx_range = np.linspace(-0.04, 0.04, n_phase1)

    for i in range(n_phase1):
        trajectory.append((dx_range[i], dy_range[i], default_pos.copy(), default_quat.copy()))

    # ── Phase 2: Object frozen at final position, camera orbits ──
    final_dx = dx_range[-1]
    final_dy = dy_range[-1]

    phi_range = np.linspace(-35, 35, n_phase2) * np.pi / 180
    theta_range = np.linspace(3, 18, n_phase2) * np.pi / 180
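    # Azimuth sweeps -35 to +35 degrees while elevation climbs from 3 to 18
    # degrees, so the camera rises as it orbits. The orbit begins offset from
    # the default view, so there is a visible cut at the phase boundary.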

    for i in range(n_phase2):
        cam_pos, cam_quat = _compute_cam_from_spherical(
            look_at, radius, default_dir_norm, phi_range[i], theta_range[i])
        trajectory.append((final_dx, final_dy, cam_pos, cam_quat))

    return trajectory, n_phase1


def render_frame(env, sim, cam_id, state_0, dx, dy, cam_pos, cam_quat):
    """Render a single frame with shifted object and camera.

    Returns:
        (H, W, 3) uint8 RGB image
    """
    # Save original camera
    orig_pos = sim.model.cam_pos[cam_id].copy()
    orig_quat = sim.model.cam_quat[cam_id].copy()

    # Shift object position in state
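    # Note: these state indices (bowl and plate free-joint x/y) are hard-coded
    # for this task's MuJoCo state layout and will not transfer to other tasks.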
    state = state_0.copy()
    state[10] += dx    # bowl X
    state[11] += dy    # bowl Y
    state[38] += dx    # plate X
    state[39] += dy    # plate Y

    # Set camera
    sim.model.cam_pos[cam_id] = cam_pos
    sim.model.cam_quat[cam_id] = cam_quat

    # Set state and render
    obs = env.set_init_state(state)
    sim.forward()

    # Re-render after forward() to pick up camera changes
    obs = env.env._get_observations()
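    # MuJoCo offscreen renders come back vertically flipped; flip upright and
    # make the array contiguous so cv2 can consume it directly.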
    rgb = np.ascontiguousarray(np.flipud(obs["agentview_image"]))

    # Restore camera
    sim.model.cam_pos[cam_id] = orig_pos
    sim.model.cam_quat[cam_id] = orig_quat
    sim.forward()

    return rgb


def hide_furniture_and_distractors(sim):
    """Hide furniture and distractor objects for clean visualization."""
    # Hide furniture
    for name in ["wooden_cabinet_1_main", "flat_stove_1_main"]:
        try:
            bid = sim.model.body_name2id(name)
            sim.model.body_pos[bid] = np.array([0, 0, -5.0])
        except Exception:
            pass

    # Hide distractor geometries
    distractor_body_names = [
        "akita_black_bowl_2_main",
        "cookies_1_main",
        "glazed_rim_porcelain_ramekin_1_main",
    ]
    distractor_bids = set()
    for name in distractor_body_names:
        try:
            distractor_bids.add(sim.model.body_name2id(name))
        except Exception:
            pass

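    # Zeroing geom alpha hides distractors from the renderer without touching
    # the physics state, unlike the body teleport used for the furniture above.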
    for geom_id in range(sim.model.ngeom):
        body_id = sim.model.geom_bodyid[geom_id]
        if body_id in distractor_bids:
            sim.model.geom_rgba[geom_id][3] = 0.0

    sim.forward()


def add_text_overlay(frame, text, position="bottom", font_scale=0.7, thickness=2):
    """Add text overlay on frame."""
    h, w = frame.shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX

    # Get text size
    (tw, th), baseline = cv2.getTextSize(text, font, font_scale, thickness)

    if position == "bottom":
        x = (w - tw) // 2
        y = h - 30
    elif position == "top":
        x = (w - tw) // 2
        y = th + 20
    else:
        x, y = position

    # Draw background rectangle
    cv2.rectangle(frame, (x - 8, y - th - 8), (x + tw + 8, y + baseline + 8),
                  (0, 0, 0), -1)
    # Draw text
    cv2.putText(frame, text, (x, y), font, font_scale, (255, 255, 255), thickness,
                cv2.LINE_AA)
    return frame


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str,
                        default="/data/cameron/para/.agents/reports/project_site/media/")
    parser.add_argument("--n_frames", type=int, default=90,
                        help="Total frames (split evenly: half object motion, half camera orbit)")
    parser.add_argument("--fps", type=int, default=15)
    parser.add_argument("--clean_scene", action="store_true", default=True,
                        help="Remove furniture and distractors")
    parser.add_argument("--device", type=str, default=None)
    args = parser.parse_args()

    if args.device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}")

    # ── Load DINO backbone ──
    print("Loading DINO backbone...")
    dino = load_dino_backbone(device)
    print(f"  Embed dim: {dino.embed_dim}")

    # ── Initialize LIBERO environment ──
    print("Initializing LIBERO environment...")
    benchmark = bm_lib.get_benchmark_dict()["libero_spatial"]()
    task = benchmark.get_task(0)
    bddl_file = os.path.join(get_libero_path("bddl_files"),
                              task.problem_folder, task.bddl_file)

    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=IMAGE_SIZE,
        camera_widths=IMAGE_SIZE,
        camera_names=["agentview"],
    )
    env.seed(0)
    env.reset()
    sim = env.env.sim

    # ── Load initial state from first demo ──
    demo_path = os.path.join(get_libero_path("datasets"),
                              benchmark.get_task_demonstration(0))
    with h5py.File(demo_path, "r") as f:
        state_0 = np.array(f["data/demo_0/states"][0])
    print(f"  State shape: {state_0.shape}")

    # ── Clean scene ──
    if args.clean_scene:
        print("Hiding furniture and distractors...")
        hide_furniture_and_distractors(sim)

    # ── Get camera ID ──
    cam_id = sim.model.camera_name2id("agentview")

    # ── Generate two-phase trajectory ──
    print(f"Generating {args.n_frames}-frame trajectory (2 phases)...")
    trajectory, phase_boundary = generate_smooth_trajectory(
        args.n_frames, env, sim, cam_id, state_0)
    print(f"  Phase 1 (object motion): frames 0-{phase_boundary - 1}")
    print(f"  Phase 2 (camera orbit):  frames {phase_boundary}-{len(trajectory) - 1}")

    # ── Render all frames and extract features ──
    print("Rendering frames and extracting DINO features...")
    rgb_frames = []
    feature_maps = []

    for i, (dx, dy, cam_pos, cam_quat) in enumerate(trajectory):
        if (i + 1) % 10 == 0 or i == 0:
            print(f"  Frame {i + 1}/{len(trajectory)}")

        rgb = render_frame(env, sim, cam_id, state_0, dx, dy, cam_pos, cam_quat)
        rgb_frames.append(rgb)

        feats, H_p, W_p = extract_dino_features(dino, rgb, device)
        feature_maps.append(feats)

    # ── Compute joint PCA (pure PCA images, no blending) ──
    print("Computing joint PCA across all frames...")
    pca_images = compute_pca_images(feature_maps, rgb_frames)

    # ── Save video: RGB (left) | DINO PCA (right), side by side ──
    os.makedirs(args.output_dir, exist_ok=True)

    output_path = os.path.join(args.output_dir, "dino_pca_consistency.mp4")
    frame_h, frame_w = rgb_frames[0].shape[:2]
    combined_w = frame_w * 2

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(output_path, fourcc, args.fps, (combined_w, frame_h))

    for i in range(len(rgb_frames)):
        rgb_bgr = cv2.cvtColor(rgb_frames[i], cv2.COLOR_RGB2BGR)
        pca_bgr = cv2.cvtColor(pca_images[i], cv2.COLOR_RGB2BGR)

        # Phase-specific text
        if i < phase_boundary:
            phase_text = "Object position changes"
        else:
            phase_text = "Camera viewpoint changes"

        # Add labels
        rgb_labeled = add_text_overlay(rgb_bgr.copy(), "RGB", position="top", font_scale=0.6)
        pca_labeled = add_text_overlay(pca_bgr.copy(), "DINO PCA Features",
                                       position="top", font_scale=0.6)

        combined = np.hstack([rgb_labeled, pca_labeled])
        combined = add_text_overlay(combined, phase_text, position="bottom", font_scale=0.65)

        writer.write(combined)

    writer.release()
    print(f"Saved side-by-side video: {output_path}")

    # Also save PCA-only video
    pca_only_path = os.path.join(args.output_dir, "dino_pca_only.mp4")
    writer2 = cv2.VideoWriter(pca_only_path, fourcc, args.fps, (frame_w, frame_h))
    for pca_img in pca_images:
        writer2.write(cv2.cvtColor(pca_img, cv2.COLOR_RGB2BGR))
    writer2.release()
    print(f"Saved PCA-only video: {pca_only_path}")

    # ── Re-encode to H.264 ──
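    # OpenCV's mp4v streams often won't play in browsers; re-encode with
    # libx264 and +faststart so the videos embed cleanly in a web page.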
    for path in [output_path, pca_only_path]:
        h264_path = path.replace(".mp4", "_h264.mp4")
        ret = os.system(
            f'ffmpeg -y -i "{path}" -c:v libx264 -preset ultrafast -crf 23 '
            f'-movflags +faststart "{h264_path}" 2>/dev/null'
        )
        if ret == 0:
            os.replace(h264_path, path)
            print(f"  Re-encoded to H.264: {path}")

    env.close()
    print("Done!")


if __name__ == "__main__":
    main()
