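"""
Warp per-height heatmaps from a saved inference (view 1) onto a second camera view (view 2)
using known camera poses. Each height slice of the predicted volume is mapped into view 2 through
the homography induced by the plane at that height, and overlaid on the second view's RGB image.
"""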
import argparse
from pathlib import Path

import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch

# Make this script's own directory importable so the sibling `utils` module resolves
import os
import sys
sys.path.insert(0, os.path.dirname(__file__))

from utils import recover_3d_from_direct_keypoint_and_height, project_3d_to_2d, unproject_2d_to_ray

# Side length (pixels) of the square view-1 heatmap grid; view-1 intrinsics are rescaled to this size.
IMAGE_SIZE = 448
# Nominal number of height bins; the actual count is read from the logits' shape and may differ.
N_HEIGHT_BINS = 32

# Default input files (overridable via --inference / --second_view).
INFERENCE_PATH = "volume_dino_tracks/scratch/vistest_live_test_model_ik_data.pth"
SECOND_VIEW_PATH = "volume_dino_tracks/scratch/second_view.pth"


def load_inference(path: str):
    """Load a saved view-1 inference file.

    Args:
        path: Path to a ``torch.save``-d dict with keys ``rgb``, ``camera_pose``, ``cam_K``,
            ``pred_logits`` (or ``volume_logits``) and optionally ``minmax_height``.

    Returns:
        dict with ``rgb``, ``camera_pose``, ``cam_K``, ``volume_logits`` (torch tensors converted
        to numpy arrays, other values passed through) and ``minmax_height`` as a ``(min, max)``
        tuple of floats (defaults to ``(0.03, 0.17)`` when absent).

    Raises:
        KeyError: if ``rgb``/``camera_pose``/``cam_K`` is missing, or if neither
            ``pred_logits`` nor ``volume_logits`` is present.
    """
    # NOTE: weights_only=False unpickles arbitrary Python objects -- only load trusted files.
    data = torch.load(path, map_location="cpu", weights_only=False)

    def _to_numpy(value):
        # detach() so tensors saved with requires_grad=True can still be converted.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    rgb = _to_numpy(data["rgb"])
    camera_pose = _to_numpy(data["camera_pose"])
    cam_K = _to_numpy(data["cam_K"])

    volume_logits = data.get("pred_logits", data.get("volume_logits"))
    if volume_logits is None:
        raise KeyError("Inference file must contain 'pred_logits' or 'volume_logits'")
    volume_logits = _to_numpy(volume_logits)

    # Accept list/tuple/tensor/ndarray and normalize to a plain tuple of floats.
    minmax_height = _to_numpy(data.get("minmax_height", (0.03, 0.17)))
    if isinstance(minmax_height, (list, tuple, np.ndarray)):
        minmax_height = tuple(float(x) for x in minmax_height)

    return {
        "rgb": rgb,
        "camera_pose": camera_pose,
        "cam_K": cam_K,
        "volume_logits": volume_logits,
        "minmax_height": minmax_height,
    }


def load_second_view(path: str):
    """Load the second camera view.

    Args:
        path: Path to a ``torch.save``-d dict with keys ``rgb``, ``cam_K`` and ``camera_pose``.

    Returns:
        dict with the same three keys; torch tensors are converted to numpy arrays,
        other values are returned unchanged.

    Raises:
        KeyError: if any of the three keys is missing.
    """
    # NOTE: weights_only=False unpickles arbitrary Python objects -- only load trusted files.
    data = torch.load(path, map_location="cpu", weights_only=False)
    view = {}
    for key in ("rgb", "cam_K", "camera_pose"):
        value = data[key]
        if isinstance(value, torch.Tensor):
            # detach() so tensors saved with requires_grad=True can still be converted.
            value = value.detach().cpu().numpy()
        view[key] = value
    return view


def softmax_volume(volume_logits, timestep=-1):
    """Jointly softmax every (height, row, col) cell of one timestep's logit volume.

    Args:
        volume_logits: array indexed as ``[batch, timestep, height_bin, row, col]``; only batch 0
            is used.
        timestep: which timestep to normalize (default: the last one).

    Returns:
        float32 array of shape ``(n_height_bins, H, W)`` whose entries sum to 1 over the whole volume.
    """
    slab = np.asarray(volume_logits[0, timestep], dtype=np.float64)
    # Subtract the global max before exponentiating to avoid overflow.
    shifted = slab.reshape(-1) - slab.max()
    weights = np.exp(shifted)
    total = weights.sum()
    return (weights / total).reshape(slab.shape).astype(np.float32)


def alpha_composite_layer(backdrop, layer_rgb, layer_alpha):
    """Blend ``layer_rgb`` over ``backdrop`` with per-pixel opacity ``layer_alpha``.

    Computes ``backdrop * (1 - alpha) + layer_rgb * alpha`` and clips to [0, 1].

    Args:
        backdrop: (H, W, 3) image.
        layer_rgb: (H, W, 3) image to place on top.
        layer_alpha: (H, W) or (H, W, 1) opacity; a 2-D map is broadcast over channels.

    Returns:
        (H, W, 3) float32 composite.
    """
    alpha = layer_alpha[:, :, np.newaxis] if layer_alpha.ndim != 3 else layer_alpha
    blended = backdrop * (1 - alpha) + layer_rgb * alpha
    return np.clip(blended, 0, 1).astype(np.float32)


def _recover_3d_at_plane(point_2d, plane_axis, plane_value, camera_pose, cam_K):
    """Intersect the camera ray through ``point_2d`` with the world plane ``p[plane_axis] == plane_value``.

    The ray comes from ``unproject_2d_to_ray``. Pass ``plane_axis=1`` for a Y-up world or 2 for Z-up.

    Returns:
        The 3D intersection point, or None when the ray is (nearly) parallel to the plane
        or the plane lies behind the camera (negative ray parameter).
    """
    origin, direction = unproject_2d_to_ray(point_2d, camera_pose, cam_K)
    denom = direction[plane_axis]
    if abs(denom) < 1e-9:
        return None
    ray_param = (plane_value - origin[plane_axis]) / denom
    return None if ray_param < 0 else origin + ray_param * direction


def get_volume_boundary_edges_in_view2(
    camera_pose_v1,
    cam_K_v1,
    camera_pose_v2,
    cam_K_v2,
    W1,
    H1,
    min_height_m,
    max_height_m,
    height_axis=2,
):
    """Project the 12 edges of the view-1 prediction volume into view 2.

    The volume is the box whose 8 corners are the view-1 image corners (0,0), (W1,0), (W1,H1), (0,H1)
    back-projected onto the planes ``world[height_axis] == min_height_m`` (corners 0-3) and
    ``world[height_axis] == max_height_m`` (corners 4-7). ``height_axis=2`` assumes Z-up; try 1 for Y-up.

    Returns:
        List of ``((x0, y0), (x1, y1))`` view-2 pixel segments. Returns ``[]`` if any corner ray fails
        to hit its plane; individual edges are dropped when either endpoint fails to project.
    """
    image_corners = ((0, 0), (W1, 0), (W1, H1), (0, H1))
    box_corners = []
    for plane_height in (min_height_m, max_height_m):
        for corner in image_corners:
            hit = _recover_3d_at_plane(
                np.asarray(corner, dtype=np.float64), height_axis, plane_height, camera_pose_v1, cam_K_v1
            )
            if hit is None:
                return []
            box_corners.append(hit)
    box_corners = np.array(box_corners)

    # Bottom ring (0-1-2-3), top ring (4-5-6-7), then the four verticals (k -> k+4).
    bottom_ring = [(k, (k + 1) % 4) for k in range(4)]
    top_ring = [(a + 4, b + 4) for a, b in bottom_ring]
    verticals = [(k, k + 4) for k in range(4)]

    projected = [project_3d_to_2d(corner, camera_pose_v2, cam_K_v2) for corner in box_corners]
    return [
        (tuple(projected[a]), tuple(projected[b]))
        for a, b in bottom_ring + top_ring + verticals
        if projected[a] is not None and projected[b] is not None
    ]


def compute_plane_homography_view1_to_view2(
    camera_pose_v1,
    cam_K_v1,  # intrinsics for view 1 at heatmap resolution (W1, H1)
    camera_pose_v2,
    cam_K_v2,
    height_m,
    H1,
    W1,
    H2,
    W2,
    grid_step=16,
):
    """Estimate the 3x3 homography mapping view-1 pixels to view-2 pixels for the plane at ``height_m``.

    A grid of view-1 pixels (every ``grid_step`` rows/cols) is lifted to 3D at ``height_m``, projected
    into view 2, and pairs that land inside the (W2, H2) image are fed to ``cv2.findHomography``
    (RANSAC, 5 px threshold).

    Returns:
        H such that ``x2 ~ H @ x1`` (homogeneous), or None when fewer than 4 correspondences are
        available (or when OpenCV itself fails to estimate one).
    """
    correspondences = []
    for row in range(0, H1, grid_step):
        for col in range(0, W1, grid_step):
            world_pt = recover_3d_from_direct_keypoint_and_height(
                np.array([col, row], dtype=np.float64), height_m, camera_pose_v1, cam_K_v1
            )
            if world_pt is None:
                continue
            projected = project_3d_to_2d(world_pt, camera_pose_v2, cam_K_v2)
            if projected is None:
                continue
            u2, v2 = projected[0], projected[1]
            if 0 <= u2 < W2 and 0 <= v2 < H2:
                correspondences.append(((col, row), (u2, v2)))

    if len(correspondences) < 4:
        return None
    src = np.array([s for s, _ in correspondences], dtype=np.float32)
    dst = np.array([d for _, d in correspondences], dtype=np.float32)
    homography, _inlier_mask = cv2.findHomography(src, dst, cv2.RANSAC, ransacReprojThreshold=5.0)
    return homography


def warp_heatmap_to_view2(
    heatmap_v1,  # (H1, W1) e.g. (448, 448)
    camera_pose_v1,
    cam_K_v1,
    camera_pose_v2,
    cam_K_v2,
    height_m,
    H2,
    W2,
    grid_step=16,
):
    """Resample a view-1 heatmap into view 2 via the plane homography at ``height_m``.

    Pixels that map from outside the source image are filled with 0 (bilinear sampling).

    Returns:
        (H2, W2) float32 heatmap; all zeros (with a printed warning) when no homography is available.
    """
    src_h, src_w = heatmap_v1.shape
    homography = compute_plane_homography_view1_to_view2(
        camera_pose_v1,
        cam_K_v1,
        camera_pose_v2,
        cam_K_v2,
        height_m,
        src_h,
        src_w,
        H2,
        W2,
        grid_step=grid_step,
    )
    if homography is None:
        print("Warning: could not compute homography (too few valid correspondences), returning zeros")
        return np.zeros((H2, W2), dtype=np.float32)
    # OpenCV takes the output size as (width, height).
    warped = cv2.warpPerspective(
        heatmap_v1,
        homography,
        (W2, H2),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=0,
    )
    return warped.astype(np.float32)


# Display tuning for the overlays rendered by main().
_WASH_GRAY = 0.55  # backgrounds are blended 50/50 toward this gray so the heatmap stands out
_ALPHA_MIN, _ALPHA_MAX = 0.25, 0.95  # heatmap opacity range for normalized values in [0, 1]
_GAMMA = 0.35  # gamma < 1 boosts small probabilities so they're visible
_GAIN = 1.2  # post-gamma gain; values pushed above 1 saturate once clipped for colouring


def _scale_intrinsics(cam_K, src_w, src_h, dst_w, dst_h):
    """Rescale fx, cx by dst_w/src_w and fy, cy by dst_h/src_h; other entries are reset to identity."""
    scaled = np.eye(3, dtype=np.float64)
    scaled[0, 0] = cam_K[0, 0] * (dst_w / src_w)
    scaled[0, 2] = cam_K[0, 2] * (dst_w / src_w)
    scaled[1, 1] = cam_K[1, 1] * (dst_h / src_h)
    scaled[1, 2] = cam_K[1, 2] * (dst_h / src_h)
    return scaled


def _wash_out(rgb):
    """Convert an RGB (or grayscale) image to float RGB blended halfway toward _WASH_GRAY."""
    # assumes 0-255 valued input (e.g. uint8) -- TODO confirm if rgb can already be float in [0, 1]
    img = rgb.astype(np.float32) / 255.0
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    return img * 0.5 + _WASH_GRAY * 0.5


def _to_uint8(img):
    """Clip a float image to [0, 1] and convert to uint8."""
    return (np.clip(img, 0, 1) * 255).astype(np.uint8)


def _colorize_overlay(backdrop, strength, presence, cmap):
    """Composite cmap(strength) over backdrop; opacity scales with strength where presence > 0."""
    layer_rgb = cmap(strength)[:, :, :3].astype(np.float32)
    alpha = np.where(presence > 0, _ALPHA_MIN + (_ALPHA_MAX - _ALPHA_MIN) * strength, 0.0)
    return alpha_composite_layer(backdrop, layer_rgb, alpha)


def _draw_volume_edges(overlay_bgr, edges):
    """Draw projected volume edges in black, in place."""
    for pa, pb in edges:
        pt_a = (int(round(pa[0])), int(round(pa[1])))
        pt_b = (int(round(pb[0])), int(round(pb[1])))
        cv2.line(overlay_bgr, pt_a, pt_b, (0, 0, 0), thickness=2)


def _draw_keypoint_on_bgr(overlay_bgr, kp_2d, color=(0, 0, 255), radius=16, thickness=4):
    """Draw a circle + cross at kp_2d (default red in BGR), in place; skips None / off-image points."""
    if kp_2d is None:
        return
    cx, cy = int(round(kp_2d[0])), int(round(kp_2d[1]))
    if not (0 <= cx < overlay_bgr.shape[1] and 0 <= cy < overlay_bgr.shape[0]):
        return
    cv2.circle(overlay_bgr, (cx, cy), radius, color, thickness)
    cv2.drawMarker(overlay_bgr, (cx, cy), color, cv2.MARKER_CROSS, radius * 2, thickness)


def _warp_layers(vol_probs, homographies, H2, W2):
    """Warp each height slice with its precomputed homography; slices without one stay zero."""
    warped = np.zeros((len(homographies), H2, W2), dtype=np.float32)
    for h, homography in enumerate(homographies):
        if homography is None:
            continue
        # OpenCV takes the output size as (width, height).
        warped[h] = cv2.warpPerspective(
            vol_probs[h],
            homography,
            (W2, H2),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=0,
        )
    return warped


def main():
    """Render view-1 volume heatmaps warped onto view 2 for every timestep.

    Writes into ``--anim_dir``: ``rgb.png`` and ``height_values.npy``; per-height-layer frames for t=0
    (``t0_heatmap_hXX.png`` in view 1, ``t0_volume_hXX.png`` warped into view 2);
    ``t0_volume_nokp.png``; and per timestep ``tT_max_heatmap.png`` (view 1) and ``tT_volume_kp.png``
    (view 2 max projection with argmax keypoints). The final-timestep view-2 figure is saved to ``--out``.
    """
    parser = argparse.ArgumentParser(description="Warp inference min-height heatmap onto second view")
    parser.add_argument("--inference", type=str, default=INFERENCE_PATH, help="Path to saved inference .pth")
    parser.add_argument("--second_view", type=str, default=SECOND_VIEW_PATH, help="Path to second view .pth")
    parser.add_argument("--grid_step", type=int, default=16, help="Grid step for homography correspondences (larger = faster)")
    parser.add_argument(
        "--out",
        type=str,
        default="volume_dino_tracks/scratch/heatmap_vis_second_view.png",
        help="Where to save the final-timestep view-2 overlay figure",
    )
    parser.add_argument("--height_axis", type=int, default=2, choices=[0, 1, 2], help="World axis for height (0=X, 1=Y, 2=Z). Default 2 (Z-up) so volume lines show; try 1 for Y-up.")
    parser.add_argument("--anim_dir", type=str, default="volume_dino_tracks/scratch/anim_pics", help="Directory for per-frame PNGs")
    parser.add_argument("--show", action="store_true", help="Display figures interactively (plt.show)")
    args = parser.parse_args()

    inf = load_inference(args.inference)
    v2 = load_second_view(args.second_view)

    # --- View 1 ---
    rgb1 = inf["rgb"]
    H1_rgb, W1_rgb = rgb1.shape[:2]
    camera_pose_v1 = inf["camera_pose"]
    volume_logits = np.asarray(inf["volume_logits"])  # indexed [batch, timestep, height_bin, row, col]
    n_timesteps, n_height_bins = volume_logits.shape[1], volume_logits.shape[2]
    heat_H, heat_W = volume_logits.shape[-2:]
    # Heatmap pixels are indexed at the volume's own resolution (IMAGE_SIZE for the default model),
    # so the view-1 intrinsics must be rescaled from the RGB resolution to that grid.
    cam_K_v1_heat = _scale_intrinsics(inf["cam_K"], W1_rgb, H1_rgb, heat_W, heat_H)

    # --- View 2 ---
    rgb2 = v2["rgb"]
    H2, W2 = rgb2.shape[:2]
    camera_pose_v2 = v2["camera_pose"]
    cam_K_v2 = v2["cam_K"]

    min_height_m = float(inf["minmax_height"][0])
    max_height_m = float(inf["minmax_height"][1])
    height_values = np.linspace(min_height_m, max_height_m, n_height_bins, dtype=np.float64)

    anim_pics_dir = Path(args.anim_dir)
    anim_pics_dir.mkdir(parents=True, exist_ok=True)

    # Everything below depends only on the cameras and heights, not on the timestep: compute once.
    rgb1_float = _wash_out(rgb1)
    rgb2_float = _wash_out(rgb2)
    cmap = plt.get_cmap("coolwarm")  # blue at 0, red at 1
    volume_edges = get_volume_boundary_edges_in_view2(
        camera_pose_v1,
        cam_K_v1_heat,
        camera_pose_v2,
        cam_K_v2,
        heat_W,
        heat_H,
        min_height_m,
        max_height_m,
        height_axis=args.height_axis,
    )
    homographies = []
    for h, height_m in enumerate(height_values):
        homography = compute_plane_homography_view1_to_view2(
            camera_pose_v1, cam_K_v1_heat, camera_pose_v2, cam_K_v2,
            height_m, heat_H, heat_W, H2, W2, grid_step=args.grid_step,
        )
        if homography is None:
            print(f"Warning: could not compute homography for height layer {h} (too few valid correspondences); layer will be blank")
        homographies.append(homography)

    # Argmax 3D keypoint of each timestep, projected into view 2 (drawn cumulatively below).
    all_kp_2d_v2 = []
    for t in range(n_timesteps):
        vol_logits_t = volume_logits[0, t]
        h_bin, row, col = np.unravel_index(np.argmax(vol_logits_t), vol_logits_t.shape)
        pt_3d_kp = recover_3d_from_direct_keypoint_and_height(
            np.array([float(col), float(row)], dtype=np.float64),
            height_values[h_bin],
            camera_pose_v1,
            cam_K_v1_heat,
        )
        kp_2d = project_3d_to_2d(pt_3d_kp, camera_pose_v2, cam_K_v2) if pt_3d_kp is not None else None
        all_kp_2d_v2.append(kp_2d)

    for t_ in range(n_timesteps):
        # Softmax over the full (height, row, col) volume at this timestep, as at inference.
        vol_probs = softmax_volume(volume_logits, timestep=t_)
        # Softmax values are tiny: normalize so the volume max is 1, then gamma-boost for visibility.
        vol_max = float(vol_probs.max())
        if vol_max > 0:
            vol_probs = vol_probs / vol_max
        vol_probs = np.power(np.clip(vol_probs, 0, None), _GAMMA) * _GAIN

        warped_stack = _warp_layers(vol_probs, homographies, H2, W2)
        keypoints_so_far = all_kp_2d_v2[: t_ + 1]

        if t_ == 0:
            np.save(str(anim_pics_dir / "height_values.npy"), height_values)
            cv2.imwrite(str(anim_pics_dir / "rgb.png"), cv2.cvtColor(_to_uint8(rgb1_float), cv2.COLOR_RGB2BGR))
            for h in range(n_height_bins):
                # View 1: raw layer resized to the RGB resolution (no warping).
                heatmap_v1 = cv2.resize(vol_probs[h], (W1_rgb, H1_rgb), interpolation=cv2.INTER_LINEAR)
                layer_v1 = _colorize_overlay(rgb1_float, np.clip(heatmap_v1, 0, 1), heatmap_v1, cmap)
                cv2.imwrite(
                    str(anim_pics_dir / f"t0_heatmap_h{h:02d}.png"),
                    cv2.cvtColor(_to_uint8(layer_v1), cv2.COLOR_RGB2BGR),
                )

                # View 2: warped layer with volume outline and keypoints.
                layer_v2 = _colorize_overlay(rgb2_float, np.clip(warped_stack[h], 0, 1), warped_stack[h], cmap)
                layer_bgr = cv2.cvtColor(_to_uint8(layer_v2), cv2.COLOR_RGB2BGR)
                _draw_volume_edges(layer_bgr, volume_edges)
                for kp in keypoints_so_far:
                    _draw_keypoint_on_bgr(layer_bgr, kp)
                cv2.imwrite(str(anim_pics_dir / f"t0_volume_h{h:02d}.png"), layer_bgr)

                if args.show:
                    plt.imshow(cv2.cvtColor(layer_bgr, cv2.COLOR_BGR2RGB))
                    plt.title(f"t={t_} height layer {h}/{n_height_bins} (z={height_values[h]*1000:.0f}mm)")
                    plt.axis("off")
                    plt.tight_layout()
                    plt.show()
                    plt.close()

        # View 1: max over height of the raw (unwarped) heatmaps.
        max_v1 = vol_probs.max(axis=0)
        if t_ > 0:
            # Spatial softmax over the 2-D map for later timesteps.
            # NOTE(review): this is applied to gamma-scaled values in [0, _GAIN], so the result is
            # nearly flat (min/max ratio >= e^-_GAIN) -- confirm this is the intended look.
            flat = max_v1.ravel().astype(np.float64)
            exp_flat = np.exp(flat - flat.max())
            max_v1 = (exp_flat / exp_flat.sum()).reshape(max_v1.shape).astype(np.float32)
        max_v1_resized = cv2.resize(max_v1, (W1_rgb, H1_rgb), interpolation=cv2.INTER_LINEAR)
        v_norm = np.clip(max_v1_resized / (max_v1_resized.max() + 1e-9), 0, 1)
        max_overlay_v1 = _colorize_overlay(rgb1_float, v_norm, max_v1_resized, cmap)
        cv2.imwrite(
            str(anim_pics_dir / f"t{t_}_max_heatmap.png"),
            cv2.cvtColor(_to_uint8(max_overlay_v1), cv2.COLOR_RGB2BGR),
        )

        # View 2: max projection of the warped stack, volume outline, keypoints up to t.
        proj_max = warped_stack.max(axis=0)
        overlay = _colorize_overlay(rgb2_float, np.clip(proj_max, 0, 1), proj_max, cmap)
        overlay_bgr = cv2.cvtColor(_to_uint8(overlay), cv2.COLOR_RGB2BGR)
        _draw_volume_edges(overlay_bgr, volume_edges)
        if t_ == 0:
            cv2.imwrite(str(anim_pics_dir / "t0_volume_nokp.png"), overlay_bgr.copy())
        for kp in keypoints_so_far:
            _draw_keypoint_on_bgr(overlay_bgr, kp)
        cv2.imwrite(str(anim_pics_dir / f"t{t_}_volume_kp.png"), overlay_bgr)

        is_last = t_ == n_timesteps - 1
        if is_last or args.show:
            fig, ax = plt.subplots(1, 1, figsize=(10, 6))
            ax.imshow(cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB))
            # Keypoints are drawn with BGR (0, 0, 255), i.e. red.
            ax.set_title("Volume heatmaps (softmax) warped onto second view — red = argmax 3D keypoint")
            ax.axis("off")
            fig.tight_layout()
            if is_last:
                out_path = Path(args.out)
                out_path.parent.mkdir(parents=True, exist_ok=True)
                fig.savefig(str(out_path), dpi=150, bbox_inches="tight")
            if args.show:
                plt.show()
            plt.close(fig)


if __name__ == "__main__":
    main()
