"""
Visualize the 3D volume from a saved inference as a point cloud in viser.
Uses the same inference tensor (volume_logits, camera_pose, cam_K, minmax_height);
no second view or heatmap warping.
"""
import argparse
from pathlib import Path
import time

import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import torch

import os
import sys
sys.path.insert(0, os.path.dirname(__file__))

from utils import recover_3d_from_direct_keypoint_and_height

IMAGE_SIZE = 448
N_HEIGHT_BINS = 32

INFERENCE_PATH = "volume_dino_tracks/scratch/vistest_live_test_model_ik_data.pth"


def load_inference(path: str):
    """Load a saved inference file and normalize its contents.

    Returns a dict with keys: rgb, camera_pose, cam_K, volume_logits,
    minmax_height. Torch tensors are converted to numpy arrays, and
    minmax_height becomes a tuple of floats (defaults to (0.03, 0.17)
    when the key is absent). Raises KeyError when neither 'pred_logits'
    nor 'volume_logits' is present.
    """
    def as_array(obj):
        # Tensors become numpy arrays; anything else passes through untouched.
        return obj.numpy() if isinstance(obj, torch.Tensor) else obj

    data = torch.load(path, map_location="cpu", weights_only=False)

    rgb = as_array(data["rgb"])
    camera_pose = as_array(data["camera_pose"])
    cam_K = as_array(data["cam_K"])

    # The volume may be stored under either key depending on the writer.
    volume_logits = data.get("pred_logits", data.get("volume_logits"))
    if volume_logits is None:
        raise KeyError("Inference file must contain 'pred_logits' or 'volume_logits'")
    if isinstance(volume_logits, torch.Tensor):
        volume_logits = volume_logits.cpu().numpy()

    minmax_height = data.get("minmax_height", (0.03, 0.17))
    if isinstance(minmax_height, (list, tuple)):
        minmax_height = tuple(float(x) for x in minmax_height)

    return {
        "rgb": rgb,
        "camera_pose": camera_pose,
        "cam_K": cam_K,
        "volume_logits": volume_logits,
        "minmax_height": minmax_height,
    }


def softmax_volume(volume_logits, timestep=-1):
    """Numerically stable softmax over the whole volume at one timestep.

    volume_logits: (1, N_WINDOW, N_HEIGHT_BINS, H, W) logits.
    Returns a (N_HEIGHT_BINS, H, W) float32 array forming a single
    probability distribution over the entire volume (sums to 1 globally,
    not per pixel).
    """
    slab = np.asarray(volume_logits[0, timestep], dtype=np.float64)
    # Shift by the global max before exponentiating to avoid overflow.
    exp_slab = np.exp(slab - slab.max())
    return (exp_slab / exp_slab.sum()).astype(np.float32)


def volume_to_world_points_and_values(
    vol_values,  # (N_HEIGHT_BINS, H, W) — logits or probs
    height_values,  # (N_HEIGHT_BINS,) in meters
    camera_pose,
    cam_K,
    step_xy=7,
    step_h=1,
):
    """Sample the volume on a regular (row, col, height-bin) grid and lift
    each sample to world coordinates.

    Samples whose back-projection fails (helper returns None, e.g. behind
    the camera) are skipped. Returns four equal-length 1-D arrays:
    xs, ys, zs (world coordinates) and values (the sampled volume entries).
    """
    n_bins, n_rows, n_cols = vol_values.shape
    kept = []  # (x, y, z, value) for every sample that projects validly
    for r in range(0, n_rows, step_xy):
        for c in range(0, n_cols, step_xy):
            for k in range(0, n_bins, step_h):
                world = recover_3d_from_direct_keypoint_and_height(
                    np.array([float(c), float(r)], dtype=np.float64),
                    float(height_values[k]),
                    camera_pose,
                    cam_K,
                )
                if world is None:
                    continue
                kept.append((world[0], world[1], world[2], float(vol_values[k, r, c])))
    if not kept:
        return np.array([]), np.array([]), np.array([]), np.array([])
    xs, ys, zs, values = zip(*kept)
    return np.array(xs), np.array(ys), np.array(zs), np.array(values)


def values_norm_to_rgb(values_norm, cmap_name="coolwarm"):
    """Map values in [0, 1] through a matplotlib colormap to (N, 3) float32 RGB."""
    clipped = np.clip(values_norm, 0.0, 1.0).astype(np.float32)
    rgba = plt.get_cmap(cmap_name)(clipped)
    # Drop the alpha channel; callers only want RGB.
    return rgba[:, :3].astype(np.float32)


def main():
    """CLI entry point: load a saved inference, lift its volume to world-space
    points, and serve them (plus the argmax keypoint) as colored point clouds
    via a viser server. Blocks in a sleep loop until Ctrl+C.
    """
    parser = argparse.ArgumentParser(description="Visualize inference volume as 3D point cloud in viser")
    parser.add_argument("--inference", type=str, default=INFERENCE_PATH, help="Path to saved inference .pth")
    parser.add_argument("--timestep", type=int, default=-1, help="Volume timestep to visualize (-1 = last)")
    parser.add_argument("--step_xy", type=int, default=13, help="Spatial downsampling (pixels)")
    parser.add_argument("--step_h", type=int, default=1, help="Height bin step (1 = all bins)")
    parser.add_argument("--use_softmax", action="store_true", help="Use softmax over volume; default is raw logits")
    parser.add_argument("--gamma", type=float, default=0.4, help="Power scaling when using softmax")
    parser.add_argument("--point_size", type=float, default=0.008, help="Viser point cloud point size (meters)")
    args = parser.parse_args()

    # Imported here rather than at module top, so the rest of the module can
    # be imported without viser installed.
    import viser

    inf = load_inference(args.inference)
    rgb = inf["rgb"]
    H1_rgb, W1_rgb = rgb.shape[:2]
    camera_pose = inf["camera_pose"]
    cam_K_rgb = inf["cam_K"]
    # Rescale intrinsics from the original RGB resolution to IMAGE_SIZE x
    # IMAGE_SIZE — presumably the resolution the volume was predicted at
    # (TODO confirm against the inference pipeline).
    cam_K = np.eye(3, dtype=np.float64)
    cam_K[0, 0] = cam_K_rgb[0, 0] * (IMAGE_SIZE / W1_rgb)
    cam_K[0, 2] = cam_K_rgb[0, 2] * (IMAGE_SIZE / W1_rgb)
    cam_K[1, 1] = cam_K_rgb[1, 1] * (IMAGE_SIZE / H1_rgb)
    cam_K[1, 2] = cam_K_rgb[1, 2] * (IMAGE_SIZE / H1_rgb)

    # One metric height per height bin, spanning [min, max] inclusive.
    min_height_m = float(inf["minmax_height"][0])
    max_height_m = float(inf["minmax_height"][1])
    n_height_bins = inf["volume_logits"].shape[2]
    height_values = np.linspace(min_height_m, max_height_m, n_height_bins, dtype=np.float64)

    # Volume at chosen timestep: raw logits (default) or softmax
    vol_logits_t = inf["volume_logits"][0, args.timestep]
    # Defensive: load_inference already converts tensors to numpy, but accept
    # a tensor-like here as well.
    if hasattr(vol_logits_t, "numpy"):
        vol_logits_t = vol_logits_t.numpy()
    vol_logits_t = np.asarray(vol_logits_t, dtype=np.float64)
    if args.use_softmax:
        vol_values = softmax_volume(inf["volume_logits"], timestep=args.timestep)
        # Normalize to [0, 1], then gamma-compress (gamma < 1 lifts small
        # probabilities so they remain visible in the color map).
        vol_max = float(vol_values.max())
        if vol_max > 0:
            vol_values = vol_values / vol_max
        vol_values = np.power(np.clip(vol_values, 0, None), args.gamma)
    else:
        vol_values = vol_logits_t

    # Sample and convert to world 3D
    xs, ys, zs, values = volume_to_world_points_and_values(
        vol_values, height_values, camera_pose, cam_K,
        step_xy=args.step_xy, step_h=args.step_h,
    )
    if len(xs) == 0:
        print("No valid 3D points (all behind camera?). Try smaller step_xy.")
        return

    # Normalize values to [0, 1] for color (blue=low, red=high)
    v_min, v_max = values.min(), values.max()
    values_norm = (values - v_min) / (v_max - v_min + 1e-9)
    colors_rgb = values_norm_to_rgb(values_norm)

    # Argmax keypoint in 3D (always from raw logits)
    # vol_values shares the (n_bins, H, W) shape of vol_logits_t in both
    # branches above, so unraveling against vol_values.shape is equivalent.
    flat_idx = np.argmax(vol_logits_t)
    h_bin, row, col = np.unravel_index(flat_idx, vol_values.shape)
    height_m_kp = height_values[h_bin]
    pt_3d_kp = recover_3d_from_direct_keypoint_and_height(
        np.array([float(col), float(row)], dtype=np.float64),
        height_m_kp,
        camera_pose,
        cam_K,
    )

    # Viser server and point cloud
    server = viser.ViserServer()
    points = np.stack([xs, ys, zs], axis=1).astype(np.float32)
    server.scene.add_point_cloud(
        "/volume_point_cloud",
        points=points,
        colors=colors_rgb,
        point_size=args.point_size,
    )

    # Argmax keypoint as a single lime point (larger size)
    # pt_3d_kp can be None when the back-projection helper fails.
    if pt_3d_kp is not None:
        server.scene.add_point_cloud(
            "/argmax_keypoint",
            points=pt_3d_kp.reshape(1, 3).astype(np.float32),
            colors=np.array([[0.0, 1.0, 0.0]], dtype=np.float32),  # lime
            point_size=0.02,
        )

    print(f"✓ Volume point cloud: {len(points)} points (timestep {args.timestep})")
    print("✓ Viser server running. Open the URL in the browser. Ctrl+C to exit.")
    try:
        # Keep the process (and hence the server) alive until interrupted.
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("\nExiting...")


if __name__ == "__main__":
    main()
