"""
Visualize the 3D volume from a saved inference in a matplotlib 3D plot.
Uses the same inference tensor (volume_logits, camera_pose, cam_K, minmax_height);
no second view or heatmap warping.
"""
import argparse
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import torch

import os
import sys
sys.path.insert(0, os.path.dirname(__file__))

from utils import recover_3d_from_direct_keypoint_and_height

IMAGE_SIZE = 448
N_HEIGHT_BINS = 32

INFERENCE_PATH = "volume_dino_tracks/scratch/vistest_live_test_model_ik_data.pth"


def load_inference(path: str) -> dict:
    """Load a saved inference file and normalize everything to numpy.

    The file is a dict saved with ``torch.save`` expected to contain:
      - "rgb": image (H, W, 3) — tensor or array
      - "camera_pose": camera extrinsics — tensor or array
      - "cam_K": 3x3 intrinsics — tensor or array
      - "pred_logits" or "volume_logits": volume logits,
        (1, N_WINDOW, N_HEIGHT_BINS, H, W) per softmax_volume's contract
      - "minmax_height" (optional): (min_m, max_m); defaults to (0.03, 0.17)

    Returns a dict with keys "rgb", "camera_pose", "cam_K",
    "volume_logits" (numpy arrays) and "minmax_height" (tuple of floats).

    Raises:
        KeyError: if neither 'pred_logits' nor 'volume_logits' is present,
            or a required key is missing.
    """
    def _to_numpy(x):
        # Accept torch tensors or values already in array form.
        return x.cpu().numpy() if isinstance(x, torch.Tensor) else x

    # weights_only=False: the file stores arbitrary Python objects, not just
    # tensors. Only load inference files from trusted sources.
    data = torch.load(path, map_location="cpu", weights_only=False)

    # Validate the volume key up front so a malformed file fails with a
    # clear message rather than a generic KeyError mid-way through.
    volume_logits = data.get("pred_logits", data.get("volume_logits"))
    if volume_logits is None:
        raise KeyError("Inference file must contain 'pred_logits' or 'volume_logits'")

    minmax_height = data.get("minmax_height", (0.03, 0.17))
    if isinstance(minmax_height, (list, tuple)):
        minmax_height = tuple(float(x) for x in minmax_height)

    return {
        "rgb": _to_numpy(data["rgb"]),
        "camera_pose": _to_numpy(data["camera_pose"]),
        "cam_K": _to_numpy(data["cam_K"]),
        "volume_logits": _to_numpy(volume_logits),
        "minmax_height": minmax_height,
    }


def softmax_volume(volume_logits, timestep=-1):
    """Compute a softmax over every voxel of one timestep's volume.

    volume_logits: (1, N_WINDOW, N_HEIGHT_BINS, H, W) logits.
    timestep: index into the N_WINDOW axis (-1 = last).

    Returns a float32 array of shape (N_HEIGHT_BINS, H, W) whose entries
    are non-negative and sum to 1 over the entire volume.
    """
    logits = np.asarray(volume_logits[0, timestep], dtype=np.float64)
    # Shift by the global max for numerical stability before exponentiating;
    # max/sum over the whole array makes flattening unnecessary.
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    return probs.astype(np.float32)


def volume_to_world_points_and_values(
    vol_values,  # (N_HEIGHT_BINS, H, W) — logits or probs
    height_values,  # (N_HEIGHT_BINS,) in meters
    camera_pose,
    cam_K,
    step_xy=7,
    step_h=1,
):
    """Sample the volume on a regular grid and lift each sample to 3D.

    Every ``step_xy``-th pixel and every ``step_h``-th height bin is
    converted from (col, row, height) into a world-space point via
    recover_3d_from_direct_keypoint_and_height; samples the helper
    rejects (returns None, e.g. behind the camera) are dropped.

    Returns xs, ys, zs (world coordinates) and values, four equal-length
    1D numpy arrays.
    """
    n_bins, n_rows, n_cols = vol_values.shape
    xs, ys, zs, values = [], [], [], []
    for r in range(0, n_rows, step_xy):
        for c in range(0, n_cols, step_xy):
            for b in range(0, n_bins, step_h):
                point = recover_3d_from_direct_keypoint_and_height(
                    np.array([float(c), float(r)], dtype=np.float64),
                    float(height_values[b]),
                    camera_pose,
                    cam_K,
                )
                if point is None:
                    continue  # sample not recoverable in world space
                xs.append(point[0])
                ys.append(point[1])
                zs.append(point[2])
                values.append(float(vol_values[b, r, c]))
    return np.array(xs), np.array(ys), np.array(zs), np.array(values)


def main():
    """Render a saved inference volume as a colored 3D scatter plot.

    Loads the inference file, rescales the camera intrinsics from the RGB
    resolution to the IMAGE_SIZE grid the volume lives on, lifts sampled
    voxels to world space, and saves/shows a matplotlib 3D figure with the
    raw-logit argmax keypoint overlaid.
    """
    parser = argparse.ArgumentParser(description="Visualize inference volume as 3D matplotlib plot")
    parser.add_argument("--inference", type=str, default=INFERENCE_PATH, help="Path to saved inference .pth")
    parser.add_argument("--out", type=str, default="volume_dino_tracks/scratch/volume_3d_vis.png", help="Output figure path")
    parser.add_argument("--timestep", type=int, default=-1, help="Volume timestep to visualize (-1 = last)")
    parser.add_argument("--step_xy", type=int, default=13, help="Spatial downsampling (pixels)")
    parser.add_argument("--step_h", type=int, default=1, help="Height bin step (1 = all bins)")
    parser.add_argument("--use_softmax", action="store_true", help="Use softmax over volume; default is raw logits")
    parser.add_argument("--gamma", type=float, default=0.4, help="Power scaling when using softmax")
    args = parser.parse_args()

    inf = load_inference(args.inference)
    rgb = inf["rgb"]
    H1_rgb, W1_rgb = rgb.shape[:2]
    camera_pose = inf["camera_pose"]

    # The saved intrinsics are for the RGB resolution; the volume is indexed
    # on an IMAGE_SIZE x IMAGE_SIZE grid, so rescale focal lengths and
    # principal point accordingly.
    cam_K_rgb = inf["cam_K"]
    cam_K = np.eye(3, dtype=np.float64)
    cam_K[0, 0] = cam_K_rgb[0, 0] * (IMAGE_SIZE / W1_rgb)
    cam_K[0, 2] = cam_K_rgb[0, 2] * (IMAGE_SIZE / W1_rgb)
    cam_K[1, 1] = cam_K_rgb[1, 1] * (IMAGE_SIZE / H1_rgb)
    cam_K[1, 2] = cam_K_rgb[1, 2] * (IMAGE_SIZE / H1_rgb)

    # Height bins span [min, max] meters, linearly spaced.
    min_height_m = float(inf["minmax_height"][0])
    max_height_m = float(inf["minmax_height"][1])
    n_height_bins = inf["volume_logits"].shape[2]
    height_values = np.linspace(min_height_m, max_height_m, n_height_bins, dtype=np.float64)

    # Volume at the chosen timestep: raw logits (default) or softmax probs.
    vol_logits_t = inf["volume_logits"][0, args.timestep]
    if hasattr(vol_logits_t, "numpy"):
        vol_logits_t = vol_logits_t.numpy()
    vol_logits_t = np.asarray(vol_logits_t, dtype=np.float64)
    if args.use_softmax:
        vol_values = softmax_volume(inf["volume_logits"], timestep=args.timestep)
        # Normalize to [0, 1], then gamma-compress so small probabilities
        # remain visible in the plot.
        vol_max = float(vol_values.max())
        if vol_max > 0:
            vol_values = vol_values / vol_max
        vol_values = np.power(np.clip(vol_values, 0, None), args.gamma)
    else:
        vol_values = vol_logits_t

    # Sample the volume and convert each sample to world 3D.
    xs, ys, zs, values = volume_to_world_points_and_values(
        vol_values, height_values, camera_pose, cam_K,
        step_xy=args.step_xy, step_h=args.step_h,
    )
    if len(xs) == 0:
        print("No valid 3D points (all behind camera?). Try smaller step_xy.")
        return

    # Normalize values to [0, 1] so the full colormap is used (blue=low, red=high).
    v_min, v_max = values.min(), values.max()
    values_norm = (values - v_min) / (v_max - v_min + 1e-9)
    # Opacity: power law so low-intensity (blue) is transparent; high (red) fully opaque.
    alpha_power = 2.5
    alphas = np.power(values_norm, alpha_power)
    alphas[values_norm >= 0.4] = 1.0  # high-intensity points fully opaque
    # Marker size: larger for high-intensity (red) so they stand out,
    # ranging from ~1.5 (blue/low) to ~7.5 (red/high).
    sizes = 1.5 + 6.0 * values_norm

    # Argmax keypoint in 3D for overlay (always from raw logits, so the
    # marker is the same regardless of --use_softmax).
    flat_idx = np.argmax(vol_logits_t)
    h_bin, row, col = np.unravel_index(flat_idx, vol_logits_t.shape)
    height_m_kp = height_values[h_bin]
    pt_3d_kp = recover_3d_from_direct_keypoint_and_height(
        np.array([float(col), float(row)], dtype=np.float64),
        height_m_kp,
        camera_pose,
        cam_K,
    )

    # 3D scatter plot. Note: per-point alpha arrays require matplotlib >= 3.4.
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection="3d")
    sc = ax.scatter(
        xs, ys, zs,
        c=values_norm,
        s=sizes,
        alpha=alphas,
        cmap="coolwarm",
        vmin=0,
        vmax=1,
    )
    if pt_3d_kp is not None:
        ax.scatter(
            [pt_3d_kp[0]], [pt_3d_kp[1]], [pt_3d_kp[2]],
            c="lime", s=120, marker="*", edgecolors="black", linewidths=1, label="argmax keypoint",
        )
    ax.set_xlabel("X (m)")
    ax.set_ylabel("Y (m)")
    ax.set_zlabel("Z (m)")
    ax.set_title(f"Volume 3D (timestep {args.timestep}) — color = {'prob' if args.use_softmax else 'logit'} (normalized)")
    plt.colorbar(sc, ax=ax, shrink=0.6, label="intensity (0=min, 1=max)")
    if pt_3d_kp is not None:
        ax.legend()
    plt.tight_layout()
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, dpi=150, bbox_inches="tight")
    print(f"Saved: {out_path}")
    plt.show()
    plt.close()


if __name__ == "__main__":
    main()
