"""
Visualize the 3D volume from a saved inference as a point cloud in PyVista.
Supports per-point opacity (low intensity -> transparent, high -> opaque).
Reads a saved inference file (a torch .pth dict with rgb, camera_pose, cam_K, pred_logits or volume_logits, and optionally minmax_height).
"""
import argparse

import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import torch
import pyvista as pv

import os
import sys
sys.path.insert(0, os.path.dirname(__file__))

from utils import recover_3d_from_direct_keypoint_and_height

# Square image side length in pixels (presumably the model's input resolution —
# NOTE(review): confirm against the training configuration).
IMAGE_SIZE: int = 448
# Default value of --inference: a saved inference .pth, resolved relative to the
# current working directory.
INFERENCE_PATH: str = "volume_dino_tracks/scratch/vistest_live_test_model_ik_data.pth"


def _to_numpy(x):
    """Return ``x`` as a NumPy array if it is a torch tensor, otherwise unchanged.

    ``detach()`` is required because ``Tensor.numpy()`` raises for tensors that
    require grad (e.g. logits saved without ``torch.no_grad()``).
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return x


def load_inference(path: str):
    """Load a saved inference file and convert its tensors to NumPy.

    Args:
        path: Path to a file written with ``torch.save`` holding a dict with keys
            ``rgb``, ``camera_pose``, ``cam_K``, ``pred_logits`` (or
            ``volume_logits``) and, optionally, ``minmax_height``.

    Returns:
        dict with keys ``rgb``, ``camera_pose``, ``cam_K`` and ``volume_logits``
        (converted to NumPy when stored as tensors), plus ``minmax_height``
        (a tuple of floats when stored as a list/tuple/tensor/array; defaults to
        ``(0.03, 0.17)`` when the key is absent).

    Raises:
        KeyError: if ``rgb``, ``camera_pose`` or ``cam_K`` is missing, or if
            neither ``pred_logits`` nor ``volume_logits`` is present.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects; only load
    # inference files from a trusted source.
    data = torch.load(path, map_location="cpu", weights_only=False)
    rgb = _to_numpy(data["rgb"])
    camera_pose = _to_numpy(data["camera_pose"])
    cam_K = _to_numpy(data["cam_K"])
    volume_logits = data.get("pred_logits", data.get("volume_logits"))
    if volume_logits is None:
        raise KeyError("Inference file must contain 'pred_logits' or 'volume_logits'")
    volume_logits = _to_numpy(volume_logits)
    minmax_height = _to_numpy(data.get("minmax_height", (0.03, 0.17)))
    # Normalize array-like storage (tensor/ndarray) to the same float tuple that
    # list/tuple storage produces.
    if isinstance(minmax_height, np.ndarray):
        minmax_height = minmax_height.ravel().tolist()
    if isinstance(minmax_height, (list, tuple)):
        minmax_height = tuple(float(x) for x in minmax_height)
    return {
        "rgb": rgb,
        "camera_pose": camera_pose,
        "cam_K": cam_K,
        "volume_logits": volume_logits,
        "minmax_height": minmax_height,
    }


def softmax_volume(volume_logits, timestep=-1):
    """Normalize one timestep of the logit volume into a probability volume.

    Takes ``volume_logits[0, timestep]`` and applies a single softmax across all
    of its cells (not per axis). The maximum is subtracted first so ``exp``
    cannot overflow.

    Returns:
        float32 array with the same shape as ``volume_logits[0, timestep]``,
        summing to 1.
    """
    logits = np.asarray(volume_logits[0, timestep], dtype=np.float64)
    flat = logits.ravel()
    weights = np.exp(flat - flat.max())
    probabilities = weights / weights.sum()
    return probabilities.reshape(logits.shape).astype(np.float32)


def volume_to_world_points_and_values(
    vol_values,
    height_values,
    camera_pose,
    cam_K,
    step_xy=7,
    step_h=1,
):
    """Back-project a strided subsample of a height volume into world space.

    Visits every ``step_xy``-th row and column and every ``step_h``-th height
    bin of ``vol_values`` (unpacked as ``(n_h, H, W)``). Each (col, row) pixel,
    together with ``height_values[bin]``, is lifted to 3D by
    ``recover_3d_from_direct_keypoint_and_height``. Cells for which that helper
    returns ``None`` are skipped.

    Returns:
        Tuple ``(xs, ys, zs, values)`` of NumPy arrays of equal length: the
        recovered point coordinates and the volume value at each kept cell.
    """
    num_bins, img_h, img_w = vol_values.shape
    coords = ([], [], [])
    kept = []
    for r in range(0, img_h, step_xy):
        for c in range(0, img_w, step_xy):
            for b in range(0, num_bins, step_h):
                world = recover_3d_from_direct_keypoint_and_height(
                    np.array([float(c), float(r)], dtype=np.float64),
                    float(height_values[b]),
                    camera_pose,
                    cam_K,
                )
                if world is None:
                    continue
                for axis, sink in enumerate(coords):
                    sink.append(world[axis])
                kept.append(float(vol_values[b, r, c]))
    x_list, y_list, z_list = coords
    return np.array(x_list), np.array(y_list), np.array(z_list), np.array(kept)


def values_norm_to_rgb(values_norm, cmap_name="coolwarm"):
    """Colour normalized scalars with a matplotlib colormap.

    Inputs are clipped to [0, 1] before the colormap lookup, and the alpha
    channel the colormap returns is dropped.

    Returns:
        float32 array of shape (N, 3) with RGB components in [0, 1].
    """
    colormap = plt.get_cmap(cmap_name)
    clipped = np.clip(values_norm, 0.0, 1.0).astype(np.float32)
    rgba = colormap(clipped)
    return rgba[:, :3].astype(np.float32)


def _alpha_by_intensity(values_norm, power=2.5, min_alpha=0.05):
    """Per-point opacity: low intensity -> transparent, high -> opaque."""
    return np.clip(min_alpha + (1.0 - min_alpha) * np.power(values_norm, power), 0.0, 1.0)


def image_plane_quad_world(camera_pose, cam_K, width, height, depth=0.3):
    """Back-project the four image corners onto a plane at ``depth`` in front of the camera.

    Args:
        camera_pose: 4x4 world -> camera transform; its inverse maps the corners
            back to world coordinates.
        cam_K: 3x3 pinhole intrinsics (fx, fy on the diagonal, cx, cy in the
            last column).
        width, height: image size in pixels.
        depth: camera-space z of the plane.

    Returns:
        float64 array of shape (4, 3): world positions of pixel corners
        (0, 0), (width, 0), (width, height), (0, height), i.e. top-left,
        top-right, bottom-right, bottom-left.
    """
    fx, fy = cam_K[0, 0], cam_K[1, 1]
    cx, cy = cam_K[0, 2], cam_K[1, 2]
    world_from_cam = np.linalg.inv(np.asarray(camera_pose, dtype=np.float64))

    def _back_project(u, v):
        # Pinhole back-projection of pixel (u, v) to camera space at z = depth.
        p_cam = np.array(
            [(u - cx) * depth / fx, (v - cy) * depth / fy, depth, 1.0],
            dtype=np.float64,
        )
        return (world_from_cam @ p_cam)[:3]

    pixel_corners = ((0, 0), (width, 0), (width, height), (0, height))
    return np.array([_back_project(u, v) for u, v in pixel_corners], dtype=np.float64)


def image_quad_mesh(corners_world, rgb_image):
    """Build a textured PyVista quad (two triangles) that displays ``rgb_image``.

    Args:
        corners_world: (4, 3) corner positions ordered top-left, top-right,
            bottom-right, bottom-left of the image.
        rgb_image: (H, W, 3) image; uint8, or float in [0, 1] (other dtypes are
            clipped to [0, 1] and scaled to uint8).

    Returns:
        (quad, tex): the ``pv.PolyData`` quad with active texture coordinates,
        and the ``pv.Texture`` built from the image.
    """
    # Two triangles sharing the diagonal 0-2: [0, 1, 2] and [0, 2, 3].
    faces = np.array([3, 0, 1, 2, 3, 0, 2, 3], dtype=np.int32)
    quad = pv.PolyData(corners_world, faces)
    # VTK texture coordinates have (0, 0) at the bottom-left of the texture, and
    # pv.Texture(ndarray) stores array row 0 as the top row. The image's top
    # edge (corners 0, 1) therefore needs t = 1 and its bottom edge (corners 2, 3)
    # needs t = 0. Using (0, 0) for the top-left renders the image upside down.
    tcoords = np.array([[0, 1], [1, 1], [1, 0], [0, 0]], dtype=np.float32)
    quad.point_data.set_array(tcoords, "Texture Coordinates")
    quad.point_data.active_texture_coordinates_name = "Texture Coordinates"
    img = np.asarray(rgb_image)
    if img.dtype != np.uint8:
        img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
    tex = pv.Texture(img)
    return quad, tex


def _intrinsics_for_resolution(cam_K_rgb, rgb_hw, target_hw):
    """Rescale 3x3 pinhole intrinsics from the RGB image size to ``target_hw``.

    Args:
        cam_K_rgb: 3x3 intrinsics for the RGB image.
        rgb_hw: (height, width) of the RGB image in pixels.
        target_hw: (height, width) of the pixel grid the new intrinsics describe.

    Returns:
        float64 3x3 matrix with fx and cx scaled by the width ratio and fy and cy
        by the height ratio (any skew term is dropped).
    """
    rgb_h, rgb_w = rgb_hw
    target_h, target_w = target_hw
    cam_K = np.eye(3, dtype=np.float64)
    cam_K[0, 0] = cam_K_rgb[0, 0] * (target_w / rgb_w)
    cam_K[0, 2] = cam_K_rgb[0, 2] * (target_w / rgb_w)
    cam_K[1, 1] = cam_K_rgb[1, 1] * (target_h / rgb_h)
    cam_K[1, 2] = cam_K_rgb[1, 2] * (target_h / rgb_h)
    return cam_K


def _volume_values(volume_logits, timestep, use_softmax, gamma):
    """Return ``(logits_t, values_t)`` for one timestep of the volume.

    ``logits_t`` is ``volume_logits[0, timestep]`` as float64, unpacked
    downstream as ``(n_h, H, W)``. ``values_t`` is either the same array or,
    with ``use_softmax``, the whole-volume softmax normalized to a maximum of 1
    and raised to ``gamma``.
    """
    logits_t = volume_logits[0, timestep]
    if hasattr(logits_t, "numpy"):
        logits_t = logits_t.numpy()
    logits_t = np.asarray(logits_t, dtype=np.float64)
    if not use_softmax:
        return logits_t, logits_t
    values_t = softmax_volume(volume_logits, timestep=timestep)
    vol_max = float(values_t.max())
    if vol_max > 0:
        values_t = values_t / vol_max
    return logits_t, np.power(np.clip(values_t, 0, None), gamma)


def _opacity_kwargs(args, mesh, values_norm):
    """Pick the ``add_mesh`` opacity arguments from the CLI flags.

    Precedence: ``--opacity_by_intensity``, then ``--opacity_transfer``, then
    the uniform ``--opacity``. The per-point mode stores an ``opacity`` array on
    ``mesh``.
    """
    if args.opacity_by_intensity:
        # With use_transparency=False PyVista reads the array as opacity, so a
        # higher alpha means a more opaque point.
        mesh["opacity"] = _alpha_by_intensity(values_norm, power=args.opacity_power)
        return {"opacity": "opacity", "use_transparency": False}
    if args.opacity_transfer:
        return {"opacity": args.opacity_transfer}
    return {"opacity": args.opacity}


def _add_image_quad(pl, points, camera_pose, cam_K_rgb, rgb, image_depth):
    """Add the input RGB image as a textured quad in front of the volume points.

    ``camera_pose`` is treated as world -> camera, so a point's camera-space
    depth is ``(camera_pose @ p_world)[2]``.
    """
    rgb_h, rgb_w = rgb.shape[:2]
    points_h = np.hstack([points, np.ones((len(points), 1), dtype=np.float64)])
    points_cam = (np.asarray(camera_pose, dtype=np.float64) @ points_h.T).T[:, :3]
    min_depth = float(np.min(points_cam[:, 2]))
    # Place the plane just in front of the nearest point so no point appears
    # "through" the image. Fall back to the requested depth when that would put
    # the plane within 1 mm of the camera, or behind it.
    image_depth_used = min(image_depth, min_depth - 1e-3)
    if image_depth_used < 1e-3:
        image_depth_used = image_depth
    corners_world = image_plane_quad_world(
        camera_pose, cam_K_rgb, rgb_w, rgb_h, depth=image_depth_used
    )
    quad, tex = image_quad_mesh(corners_world, rgb)
    pl.add_mesh(quad, texture=tex, opacity=0.9)
    print(f"✓ Input image quad at depth {image_depth_used:.3f}m (min volume depth {min_depth:.3f}m)")


def main():
    """CLI entry point: render one timestep of a saved inference volume in PyVista.

    Loads the inference file, back-projects a subsample of the volume to world
    points colored by (optionally softmaxed) intensity, and marks the highest-
    logit cell with a sphere. Unless ``--no_image`` is given, it also shows the
    input RGB image as a textured plane.
    """
    parser = argparse.ArgumentParser(description="Visualize inference volume as 3D point cloud in PyVista")
    parser.add_argument("--inference", type=str, default=INFERENCE_PATH, help="Path to saved inference .pth")
    parser.add_argument("--timestep", type=int, default=-1, help="Volume timestep to visualize (-1 = last)")
    parser.add_argument("--step_xy", type=int, default=13, help="Spatial downsampling (pixels)")
    parser.add_argument("--step_h", type=int, default=1, help="Height bin step (1 = all bins)")
    parser.add_argument("--use_softmax", action="store_true", help="Use softmax over volume; default is raw logits")
    parser.add_argument("--gamma", type=float, default=0.4, help="Power scaling when using softmax")
    parser.add_argument("--point_size", type=float, default=8.0, help="Point size in PyVista (world units or pixels)")
    parser.add_argument("--sphere_radius", type=float, default=0.012, help="Argmax keypoint sphere radius (meters)")
    parser.add_argument("--opacity", type=float, default=1.0, help="Uniform opacity in [0,1] (if not using per-point)")
    parser.add_argument("--opacity_by_intensity", action="store_true", help="Per-point opacity: low intensity more transparent")
    parser.add_argument("--opacity_power", type=float, default=2.5, help="Power for alpha = values_norm^power")
    parser.add_argument("--opacity_transfer", type=str, default=None, choices=["linear", "linear_r", "geom", "geom_r", "sigmoid", "sigmoid_r"],
                        help="Opacity transfer function (scalar -> opacity) instead of uniform or per-point")
    parser.add_argument("--image_depth", type=float, default=0.3, help="Depth (m) of input image plane in camera space for alignment with point cloud")
    parser.add_argument("--no_image", action="store_true", help="Do not plot the input image quad")
    args = parser.parse_args()

    inf = load_inference(args.inference)
    H1_rgb, W1_rgb = inf["rgb"].shape[:2]
    camera_pose = inf["camera_pose"]
    cam_K_rgb = inf["cam_K"]

    vol_logits_t, vol_values = _volume_values(
        inf["volume_logits"], args.timestep, args.use_softmax, args.gamma
    )
    n_height_bins, vol_h, vol_w = vol_logits_t.shape

    # The volume's (col, row) indices are used directly as pixel coordinates for
    # back-projection, so the intrinsics must be expressed at the volume's own
    # spatial resolution. For an IMAGE_SIZE x IMAGE_SIZE volume this equals the
    # previous fixed rescale to IMAGE_SIZE.
    cam_K = _intrinsics_for_resolution(cam_K_rgb, (H1_rgb, W1_rgb), (vol_h, vol_w))

    min_height_m = float(inf["minmax_height"][0])
    max_height_m = float(inf["minmax_height"][1])
    height_values = np.linspace(min_height_m, max_height_m, n_height_bins, dtype=np.float64)

    xs, ys, zs, values = volume_to_world_points_and_values(
        vol_values, height_values, camera_pose, cam_K,
        step_xy=args.step_xy, step_h=args.step_h,
    )
    if len(xs) == 0:
        print("No valid 3D points (all behind camera?). Try smaller step_xy.")
        return

    v_min, v_max = values.min(), values.max()
    values_norm = (values - v_min) / (v_max - v_min + 1e-9)

    # Highest raw-logit cell, marked with a sphere.
    h_bin, row, col = np.unravel_index(np.argmax(vol_logits_t), vol_logits_t.shape)
    pt_3d_kp = recover_3d_from_direct_keypoint_and_height(
        np.array([float(col), float(row)], dtype=np.float64),
        height_values[h_bin],
        camera_pose,
        cam_K,
    )

    points = np.stack([xs, ys, zs], axis=1)
    mesh = pv.PolyData(points)
    mesh["values"] = values_norm
    opacity_kw = _opacity_kwargs(args, mesh, values_norm)

    pl = pv.Plotter()
    pl.add_mesh(
        mesh,
        scalars="values",
        cmap="coolwarm",
        render_points_as_spheres=True,
        point_size=args.point_size,
        show_scalar_bar=True,
        scalar_bar_args={"title": "intensity"},
        **opacity_kw,
    )

    if pt_3d_kp is not None:
        sphere = pv.Sphere(center=pt_3d_kp, radius=args.sphere_radius)
        pl.add_mesh(sphere, color="lime", opacity=1.0)

    if not args.no_image:
        _add_image_quad(pl, points, camera_pose, cam_K_rgb, inf["rgb"], args.image_depth)

    pl.background_color = "gray"
    print(f"✓ Volume point cloud: {len(points)} points (timestep {args.timestep})")
    if args.opacity_by_intensity:
        print("✓ Per-point opacity by intensity enabled")
    print("✓ PyVista window: close the window to exit.")
    pl.show()


if __name__ == "__main__":
    main()
