"""
Visualize the 3D volume from a saved inference as a point cloud in Open3D.
Uses the same inference tensor (volume_logits, camera_pose, cam_K, minmax_height);
no second view or heatmap warping.
"""
import argparse

import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import torch
import open3d as o3d

import os
import sys
sys.path.insert(0, os.path.dirname(__file__))

from utils import recover_3d_from_direct_keypoint_and_height

IMAGE_SIZE = 448
N_HEIGHT_BINS = 32

INFERENCE_PATH = "volume_dino_tracks/scratch/vistest_live_test_model_ik_data.pth"


def _to_numpy(value):
    """Return *value* as a numpy array, converting from a torch Tensor if needed."""
    if isinstance(value, torch.Tensor):
        # .cpu() is a no-op for CPU tensors but makes this safe for CUDA tensors too.
        return value.cpu().numpy()
    return value


def load_inference(path: str):
    """Load a saved inference file.

    Parameters
    ----------
    path : str
        Path to a ``.pth`` file produced by the inference script. Must contain
        ``rgb``, ``camera_pose``, ``cam_K`` and either ``pred_logits`` or
        ``volume_logits``; ``minmax_height`` is optional.

    Returns
    -------
    dict with keys ``rgb``, ``camera_pose``, ``cam_K``, ``volume_logits``
    (all numpy arrays) and ``minmax_height`` (tuple of two floats).

    Raises
    ------
    KeyError
        If neither ``pred_logits`` nor ``volume_logits`` is present.
    """
    # weights_only=False: the file stores arbitrary Python objects, not just tensors.
    # NOTE: only load inference files you trust (pickle deserialization).
    data = torch.load(path, map_location="cpu", weights_only=False)
    rgb = _to_numpy(data["rgb"])
    camera_pose = _to_numpy(data["camera_pose"])
    cam_K = _to_numpy(data["cam_K"])
    # Newer files use "pred_logits"; fall back to the legacy "volume_logits" key.
    volume_logits = data.get("pred_logits", data.get("volume_logits"))
    if volume_logits is None:
        raise KeyError("Inference file must contain 'pred_logits' or 'volume_logits'")
    volume_logits = _to_numpy(volume_logits)
    # Default height range (meters) when the file predates this field.
    minmax_height = data.get("minmax_height", (0.03, 0.17))
    if isinstance(minmax_height, (list, tuple)):
        minmax_height = tuple(float(x) for x in minmax_height)
    return {
        "rgb": rgb,
        "camera_pose": camera_pose,
        "cam_K": cam_K,
        "volume_logits": volume_logits,
        "minmax_height": minmax_height,
    }


def softmax_volume(volume_logits, timestep=-1):
    """Compute a softmax over every voxel of one timestep's volume.

    Parameters
    ----------
    volume_logits : array-like, shape (1, N_WINDOW, N_HEIGHT_BINS, H, W)
        Raw network logits.
    timestep : int
        Which timestep to use (-1 = last).

    Returns
    -------
    (N_HEIGHT_BINS, H, W) float32 array whose entries sum to 1 over the
    whole volume.
    """
    logits = np.asarray(volume_logits[0, timestep], dtype=np.float64)
    # Subtract the global max before exponentiating for numerical stability;
    # normalizing by the total sum makes the whole volume a distribution.
    shifted = logits - logits.max()
    weights = np.exp(shifted)
    probs = weights / weights.sum()
    return probs.astype(np.float32)


def volume_to_world_points_and_values(
    vol_values,  # (N_HEIGHT_BINS, H, W) — logits or probs
    height_values,  # (N_HEIGHT_BINS,) in meters
    camera_pose,
    cam_K,
    step_xy=7,
    step_h=1,
):
    """Sample the volume on a strided grid and lift each sample to world space.

    Every (column, row, height-bin) sample is back-projected with
    recover_3d_from_direct_keypoint_and_height; samples for which the
    helper returns None are dropped.

    Returns four equal-length 1-D arrays: world x, y, z and the sampled
    volume values.
    """
    n_bins, n_rows, n_cols = vol_values.shape
    world_x, world_y, world_z, samples = [], [], [], []
    for r in range(0, n_rows, step_xy):
        for c in range(0, n_cols, step_xy):
            for b in range(0, n_bins, step_h):
                point = recover_3d_from_direct_keypoint_and_height(
                    np.array([float(c), float(r)], dtype=np.float64),
                    float(height_values[b]),
                    camera_pose,
                    cam_K,
                )
                # None means the sample is not recoverable (e.g. behind the camera).
                if point is None:
                    continue
                world_x.append(point[0])
                world_y.append(point[1])
                world_z.append(point[2])
                samples.append(float(vol_values[b, r, c]))
    return np.array(world_x), np.array(world_y), np.array(world_z), np.array(samples)


def values_norm_to_rgb(values_norm, cmap_name="coolwarm"):
    """Map values in [0, 1] to (N, 3) float32 RGB via a matplotlib colormap.

    Values outside [0, 1] are clipped; the colormap's alpha channel is dropped.
    """
    clipped = np.clip(values_norm, 0.0, 1.0).astype(np.float32)
    colormap = plt.get_cmap(cmap_name)
    rgba = colormap(clipped)
    return rgba[:, :3].astype(np.float32)


def _alpha_by_intensity(values_norm, power=2.5, min_alpha=0.05):
    """Per-point opacity: low intensity -> transparent, high -> opaque."""
    return np.clip(min_alpha + (1.0 - min_alpha) * np.power(values_norm, power), 0.0, 1.0)


def main():
    """Entry point: load a saved inference, lift the volume to world-space
    points, color by intensity, and display in Open3D (with optional
    transparency) together with a green sphere at the argmax voxel."""
    parser = argparse.ArgumentParser(description="Visualize inference volume as 3D point cloud in Open3D")
    parser.add_argument("--inference", type=str, default=INFERENCE_PATH, help="Path to saved inference .pth")
    parser.add_argument("--timestep", type=int, default=-1, help="Volume timestep to visualize (-1 = last)")
    parser.add_argument("--step_xy", type=int, default=13, help="Spatial downsampling (pixels)")
    parser.add_argument("--step_h", type=int, default=1, help="Height bin step (1 = all bins)")
    parser.add_argument("--use_softmax", action="store_true", help="Use softmax over volume; default is raw logits")
    parser.add_argument("--gamma", type=float, default=0.4, help="Power scaling when using softmax")
    parser.add_argument("--point_size", type=float, default=2.0, help="Open3D point size (pixels)")
    parser.add_argument("--sphere_radius", type=float, default=0.012, help="Argmax keypoint sphere radius (meters)")
    parser.add_argument("--opacity", type=float, default=1.0, help="Uniform point cloud opacity in [0,1] (uses draw() with transparency)")
    parser.add_argument("--opacity_by_intensity", action="store_true", help="Per-point opacity: low intensity more transparent (splits into bins)")
    parser.add_argument("--opacity_bins", type=int, default=8, help="Number of opacity bins when using --opacity_by_intensity")
    parser.add_argument("--opacity_power", type=float, default=2.5, help="Power for alpha = values_norm^power when using --opacity_by_intensity")
    args = parser.parse_args()

    inf = load_inference(args.inference)
    rgb = inf["rgb"]
    H1_rgb, W1_rgb = rgb.shape[:2]
    camera_pose = inf["camera_pose"]
    cam_K_rgb = inf["cam_K"]
    # Rescale the saved intrinsics from the original rgb resolution to
    # IMAGE_SIZE x IMAGE_SIZE — presumably the resolution the network volume
    # is expressed in (TODO confirm against the inference script).
    cam_K = np.eye(3, dtype=np.float64)
    cam_K[0, 0] = cam_K_rgb[0, 0] * (IMAGE_SIZE / W1_rgb)
    cam_K[0, 2] = cam_K_rgb[0, 2] * (IMAGE_SIZE / W1_rgb)
    cam_K[1, 1] = cam_K_rgb[1, 1] * (IMAGE_SIZE / H1_rgb)
    cam_K[1, 2] = cam_K_rgb[1, 2] * (IMAGE_SIZE / H1_rgb)

    # Height bins are spread linearly over [min, max] meters; the bin count is
    # read from the volume itself rather than the N_HEIGHT_BINS constant.
    min_height_m = float(inf["minmax_height"][0])
    max_height_m = float(inf["minmax_height"][1])
    n_height_bins = inf["volume_logits"].shape[2]
    height_values = np.linspace(min_height_m, max_height_m, n_height_bins, dtype=np.float64)

    # Volume at chosen timestep: raw logits (default) or softmax
    vol_logits_t = inf["volume_logits"][0, args.timestep]
    # Defensive: load_inference already returns numpy, but tolerate tensors.
    if hasattr(vol_logits_t, "numpy"):
        vol_logits_t = vol_logits_t.numpy()
    vol_logits_t = np.asarray(vol_logits_t, dtype=np.float64)
    if args.use_softmax:
        vol_values = softmax_volume(inf["volume_logits"], timestep=args.timestep)
        # Rescale to [0, 1] then gamma-compress so small probabilities stay visible.
        vol_max = float(vol_values.max())
        if vol_max > 0:
            vol_values = vol_values / vol_max
        vol_values = np.power(np.clip(vol_values, 0, None), args.gamma)
    else:
        vol_values = vol_logits_t

    # Sample and convert to world 3D
    xs, ys, zs, values = volume_to_world_points_and_values(
        vol_values, height_values, camera_pose, cam_K,
        step_xy=args.step_xy, step_h=args.step_h,
    )
    if len(xs) == 0:
        print("No valid 3D points (all behind camera?). Try smaller step_xy.")
        return

    # Normalize values to [0, 1] for color (blue=low, red=high)
    v_min, v_max = values.min(), values.max()
    values_norm = (values - v_min) / (v_max - v_min + 1e-9)
    colors_rgb = values_norm_to_rgb(values_norm)

    # Argmax keypoint in 3D (always from raw logits)
    flat_idx = np.argmax(vol_logits_t)
    # vol_values has the same shape as vol_logits_t, so unraveling against it is safe.
    h_bin, row, col = np.unravel_index(flat_idx, vol_values.shape)
    height_m_kp = height_values[h_bin]
    pt_3d_kp = recover_3d_from_direct_keypoint_and_height(
        np.array([float(col), float(row)], dtype=np.float64),
        height_m_kp,
        camera_pose,
        cam_K,
    )

    points = np.stack([xs, ys, zs], axis=1).astype(np.float32)
    # Transparency requires the tensor-based draw() path; the legacy
    # Visualizer below only renders opaque points.
    use_transparency = args.opacity < 1.0 or args.opacity_by_intensity

    if use_transparency:
        # Use draw() with MaterialRecord for opacity (requires tensor PointCloud)
        mat_base = o3d.visualization.rendering.MaterialRecord()
        mat_base.shader = "defaultLitTransparency"
        mat_base.point_size = args.point_size

        if args.opacity_by_intensity:
            # Split by values_norm into bins; each bin = one point cloud with its alpha
            bins = np.linspace(0, 1, args.opacity_bins + 1)
            draw_list = []
            for k in range(args.opacity_bins):
                lo, hi = bins[k], bins[k + 1]
                # Last bin is closed on the right so values_norm == 1.0 is included.
                mask = (values_norm >= lo) & (values_norm < hi) if k < args.opacity_bins - 1 else (values_norm >= lo) & (values_norm <= hi)
                if not np.any(mask):
                    continue
                pts_k = points[mask]
                cols_k = colors_rgb[mask]
                # One alpha per bin, derived from the bin's mean intensity.
                alpha = float(_alpha_by_intensity(np.array([values_norm[mask].mean()]), power=args.opacity_power)[0])
                pcd_t = o3d.t.geometry.PointCloud(o3d.core.Tensor(pts_k))
                pcd_t.point["colors"] = o3d.core.Tensor(cols_k.astype(np.float32))
                mat = o3d.visualization.rendering.MaterialRecord()
                mat.shader = "defaultLitTransparency"
                # Alpha is carried in base_color's 4th component; RGB stays white
                # so per-point colors are not tinted.
                mat.base_color = [1.0, 1.0, 1.0, alpha]
                mat.point_size = args.point_size
                draw_list.append({"name": f"volume_bin_{k}", "geometry": pcd_t, "material": mat})
            if not draw_list:
                # Fallback: every bin was empty — draw the whole cloud at uniform opacity.
                draw_list = [{"name": "volume", "geometry": o3d.t.geometry.PointCloud(o3d.core.Tensor(points)), "material": mat_base}]
                draw_list[0]["material"].base_color = [1.0, 1.0, 1.0, args.opacity]
        else:
            pcd_t = o3d.t.geometry.PointCloud(o3d.core.Tensor(points))
            pcd_t.point["colors"] = o3d.core.Tensor(colors_rgb.astype(np.float32))
            mat_base.base_color = [1.0, 1.0, 1.0, args.opacity]
            draw_list = [{"name": "volume", "geometry": pcd_t, "material": mat_base}]

        # Green sphere at the argmax voxel (skipped if back-projection failed).
        if pt_3d_kp is not None:
            sphere = o3d.geometry.TriangleMesh.create_sphere(radius=args.sphere_radius)
            sphere.translate(pt_3d_kp)
            sphere.paint_uniform_color([0.0, 1.0, 0.0])
            draw_list.append({"name": "argmax_keypoint", "geometry": sphere})

        print(f"✓ Volume point cloud: {len(points)} points (timestep {args.timestep}), opacity enabled")
        print("✓ Open3D window: close the window to exit.")
        o3d.visualization.draw(draw_list, title="Volume", width=1024, height=768, bg_color=(0.1, 0.1, 0.1, 1.0), show_skybox=False)
    else:
        # Legacy Visualizer (opaque only)
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points.astype(np.float64))
        pcd.colors = o3d.utility.Vector3dVector(colors_rgb.astype(np.float64))
        geometries = [pcd]
        if pt_3d_kp is not None:
            sphere = o3d.geometry.TriangleMesh.create_sphere(radius=args.sphere_radius)
            sphere.translate(pt_3d_kp)
            sphere.paint_uniform_color([0.0, 1.0, 0.0])
            geometries.append(sphere)
        print(f"✓ Volume point cloud: {len(points)} points (timestep {args.timestep})")
        print("✓ Open3D window: close the window to exit.")
        vis = o3d.visualization.Visualizer()
        vis.create_window(window_name="Volume", width=1024, height=768)
        for g in geometries:
            vis.add_geometry(g)
        vis.get_render_option().point_size = args.point_size
        vis.run()
        vis.destroy_window()


# Script entry point.
if __name__ == "__main__":
    main()
