"""Analyze 2view model's per-view confidence — is there a signal that says 'wrist is OOD'?

For each randomly drawn sample (up to --max_samples) from --cache_root (default: the vp_train split):
  1. Compute the BEV-only and wrist-only volume distributions separately (re-derive from the
     same q_F_bev / q_F_wrist that go into the fused model — but each view alone).
  2. Compute per-view metrics:
       entropy_bev, entropy_wrist             — uncertainty over (Z, H, W)
       maxprob_bev, maxprob_wrist             — peak softmax prob
       kl_bev_to_wrist, kl_wrist_to_bev       — view disagreement
       wrist_feat_norm                        — mean L2 of projected wrist features (low = out of frustum)
       in_frustum_frac                        — fraction of (Z, H, W) voxels whose sampled wrist feature is non-zero
  3. Compute the fused model's prediction error: distance from argmax to GT EEF pixel.

Then look for correlations between confidence metrics and prediction error. If the wrist's entropy
spikes / maxprob drops in the cases where the model is wrong, we can use that as a gate.

Output (in --out_dir): stats.csv (one row per sample) + scatter.png (each confidence metric vs. pixel error).
"""
import os, sys, argparse, json
sys.path.insert(0, "/data/cameron/para/libero")
sys.path.insert(0, "/data/cameron/LIBERO")
os.environ.setdefault("DINO_REPO_DIR", "/data/cameron/keygrip/dinov3")
os.environ.setdefault("DINO_WEIGHTS_PATH", "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")
os.environ.setdefault("MUJOCO_GL", "osmesa")
os.environ.setdefault("PYOPENGL_PLATFORM", "osmesa")

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pathlib import Path

from data_libero_2view import CachedTrajectory2ViewDataset
from model_dino_volume_query_2view import (
    DinoVolumeQuery2View, PRED_SIZE, build_bev_world_xyz_table_batched,
    project_world_to_wrist_uv_grid,
)


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", required=True)
    p.add_argument("--cache_root", default="/data/libero/ood_viewpoint_v3_splits_2view_qmlp/vp_train")
    p.add_argument("--task_id", type=int, default=0)
    p.add_argument("--max_samples", type=int, default=200)
    p.add_argument("--out_dir", default="/tmp/wrist_conf")
    args = p.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = torch.load(args.checkpoint, map_location=device, weights_only=False)
    n_window = int(ckpt.get("n_window", 8))
    image_size = int(ckpt.get("image_size", 448))
    n_rot_bins = int(ckpt["n_rot_bins"])
    min_h, max_h = float(ckpt["min_height"]), float(ckpt["max_height"])
    fusion_mode = str(ckpt.get("fusion_mode", "sum"))
    print(f"ckpt: epoch={ckpt['epoch']}, fusion={fusion_mode}, n_window={n_window}")

    model = DinoVolumeQuery2View(
        n_window=n_window, n_height_bins=32, n_gripper_bins=32, n_rot_bins=n_rot_bins,
        image_size=image_size, pred_size=PRED_SIZE,
        rotation_mode='1d_pca', fusion_mode=fusion_mode,
    ).to(device).eval()
    model.load_state_dict(ckpt["model_state_dict"], strict=False)

    ds = CachedTrajectory2ViewDataset(
        cache_root=args.cache_root, task_ids=[args.task_id],
        image_size=image_size, n_window=n_window, frame_stride=3, max_demos=0,
    )
    print(f"Dataset: {len(ds)} samples")

    out_dir = Path(args.out_dir); out_dir.mkdir(parents=True, exist_ok=True)
    rng = np.random.RandomState(42)
    sample_idxs = rng.choice(len(ds), min(args.max_samples, len(ds)), replace=False)

    rows = []
    for k, sidx in enumerate(sample_idxs):
        s = ds[int(sidx)]
        with torch.no_grad():
            rgb_bev = s["rgb_bev"].unsqueeze(0).to(device)
            rgb_wrist = s["rgb_wrist"].unsqueeze(0).to(device)
            start_pix = s["trajectory_2d_bev"][0:1].to(device)
            wrist_K = s["wrist_K_norm"].unsqueeze(0).to(device)
            wrist_ext = s["wrist_extrinsic"].unsqueeze(0).to(device)
            bev_K = s["bev_K_norm"].unsqueeze(0).to(device)
            bev_ext = s["bev_extrinsic"].unsqueeze(0).to(device)
            bev_xyz = build_bev_world_xyz_table_batched(
                bev_K, bev_ext, 32, min_h, max_h, PRED_SIZE, PRED_SIZE, image_size,
            )
            out = model(rgb_bev, rgb_wrist, start_pix, bev_xyz, wrist_K, wrist_ext)
            vol_fused = out["volume_logits"][0]                                      # (T, Z, H, W)

            # To get per-view distributions, redo the einsums explicitly (mirror model internals).
            # Use the public pixel_feats / pixel_feats_wrist + the query stream.
            # Re-run model with hooks to capture intermediates — easier: just access vol_fused
            # and try to back out per-view contributions by zeroing one view.
            # Trick: set wrist_K to a degenerate value so the projection produces all out-of-frustum.
            # Then volume = bev-only.
            # Cleaner: re-call the model with wrist_extrinsic shifted way out so all projects fail.

            # Simpler: compute per-view scores by hand using model's internal weights.
            # The model exposes pixel_feats and pixel_feats_wrist after refinement.
            # The query stream (q_spatial) is internal — capture it via a forward hook.
            pass

        # Re-run with hook to grab q_spatial
        q_holder = {}
        def hook(_mod, _in, out_tensor):
            # out is (B, T, d_F*2 + d_z + d_t) from the spatial head
            q_holder['q_spatial'] = out_tensor.detach()
        h = model.q_head.register_forward_hook(hook)
        with torch.no_grad():
            _ = model(rgb_bev, rgb_wrist, start_pix, bev_xyz, wrist_K, wrist_ext)
        h.remove()

        q_spatial = q_holder['q_spatial'][0]  # (T, dd)
        d_F = model.d_feat; d_z = model.d_sin_z; d_t = model.d_sin_t
        q_F_bev = q_spatial[..., :d_F]
        q_F_wrist = q_spatial[..., d_F:2*d_F]
        q_z = q_spatial[..., 2*d_F:2*d_F+d_z]
        q_t = q_spatial[..., 2*d_F+d_z:]

        F_bev = out["pixel_feats"][0]      # (d, H, W)
        F_wrist = out["pixel_feats_wrist"][0]  # (d, H, W)

        # Get F_w_sampled at all (Z, H, W) voxels by re-running projection
        with torch.no_grad():
            uv_grid = project_world_to_wrist_uv_grid(bev_xyz, wrist_K, wrist_ext, image_size)
            Bv, Zv, Hv, Wv, _ = uv_grid.shape
            grid_flat = uv_grid.view(Bv, Zv * Hv, Wv, 2)
            F_w_sampled = torch.nn.functional.grid_sample(
                F_wrist.unsqueeze(0), grid_flat, mode='bilinear', padding_mode='zeros', align_corners=True
            )
            F_w_sampled = F_w_sampled.view(Bv, F_wrist.shape[0], Zv, Hv, Wv)[0]  # (d, Z, H, W)

        # Compute per-view scores
        score_bev_yx = torch.einsum('tc,chw->thw', q_F_bev, F_bev)              # (T, H, W)
        score_wrist_zyx = torch.einsum('tc,czhw->tzhw', q_F_wrist, F_w_sampled) # (T, Z, H, W)
        score_z = torch.einsum('tc,zc->tz', q_z, model.z_sin)                  # (T, Z)
        score_t = torch.einsum('tc,tc->t', q_t, model.t_sin)                   # (T,)
        z_term = score_z.unsqueeze(-1).unsqueeze(-1)                            # (T, Z, 1, 1)
        t_term = score_t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)              # (T, 1, 1, 1)

        vol_bev_only   = score_bev_yx.unsqueeze(1) + z_term + t_term            # (T, Z, H, W)
        vol_wrist_only = score_wrist_zyx + z_term + t_term                      # (T, Z, H, W)

        T_, Z_, H_, W_ = vol_bev_only.shape
        p_bev   = vol_bev_only.reshape(T_, -1).softmax(-1)                       # (T, Z*H*W)
        p_wrist = vol_wrist_only.reshape(T_, -1).softmax(-1)
        p_fused = vol_fused.reshape(T_, -1).softmax(-1)

        ent_bev   = -(p_bev   * (p_bev.clamp_min(1e-12)).log()).sum(-1)          # (T,)
        ent_wrist = -(p_wrist * (p_wrist.clamp_min(1e-12)).log()).sum(-1)
        ent_fused = -(p_fused * (p_fused.clamp_min(1e-12)).log()).sum(-1)
        maxprob_bev   = p_bev.max(-1)[0]
        maxprob_wrist = p_wrist.max(-1)[0]
        maxprob_fused = p_fused.max(-1)[0]
        # KL(BEV || wrist)
        kl_b_to_w = (p_bev * ((p_bev.clamp_min(1e-12) / p_wrist.clamp_min(1e-12)).log())).sum(-1)
        kl_w_to_b = (p_wrist * ((p_wrist.clamp_min(1e-12) / p_bev.clamp_min(1e-12)).log())).sum(-1)
        # Wrist feature norm — average over voxels (high = lots of in-frustum signal)
        wrist_feat_norm_per_voxel = F_w_sampled.pow(2).sum(0).sqrt()             # (Z, H, W)
        wrist_feat_norm_mean = wrist_feat_norm_per_voxel.mean()
        # Fraction of voxels that are in-frustum (non-zero feature)
        in_frustum_frac = (wrist_feat_norm_per_voxel > 1e-4).float().mean()

        # Pred error per timestep — argmax voxel → (Z, H, W) → BEV pixel — compare to GT pix
        flat = vol_fused.reshape(T_, -1).argmax(-1)
        z_b = flat // (H_ * W_)
        yx = flat % (H_ * W_)
        py_g = yx // W_; px_g = yx % W_
        scale = image_size / H_
        pred_pix = torch.stack([px_g.float() * scale + scale/2, py_g.float() * scale + scale/2], dim=-1)  # (T, 2)
        gt_pix = s["trajectory_2d_bev"].to(device)                              # (T, 2)
        pix_err = (pred_pix - gt_pix).norm(dim=-1)                              # (T,)

        # Aggregate per-sample (mean over T)
        rows.append({
            "sample_idx": int(sidx),
            "ent_bev":   float(ent_bev.mean().cpu()),
            "ent_wrist": float(ent_wrist.mean().cpu()),
            "ent_fused": float(ent_fused.mean().cpu()),
            "mp_bev":    float(maxprob_bev.mean().cpu()),
            "mp_wrist":  float(maxprob_wrist.mean().cpu()),
            "mp_fused":  float(maxprob_fused.mean().cpu()),
            "kl_b_to_w": float(kl_b_to_w.mean().cpu()),
            "kl_w_to_b": float(kl_w_to_b.mean().cpu()),
            "wrist_feat_norm": float(wrist_feat_norm_mean.cpu()),
            "in_frustum_frac": float(in_frustum_frac.cpu()),
            "pix_err_mean": float(pix_err.mean().cpu()),
            "pix_err_max":  float(pix_err.max().cpu()),
        })
        if (k+1) % 25 == 0: print(f"  [{k+1}/{len(sample_idxs)}]")

    # Save raw
    import csv
    with open(out_dir / "stats.csv", "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        w.writeheader(); w.writerows(rows)

    # Correlation analysis: how does each confidence metric relate to pix_err_mean?
    pix_err = np.array([r["pix_err_mean"] for r in rows])
    print(f"\nPixel error stats: mean={pix_err.mean():.1f} med={np.median(pix_err):.1f} 90%={np.percentile(pix_err, 90):.1f}")

    print("\nCorrelation with pix_err_mean (Pearson):")
    for key in ["ent_bev","ent_wrist","mp_bev","mp_wrist","kl_b_to_w","kl_w_to_b","wrist_feat_norm","in_frustum_frac"]:
        vals = np.array([r[key] for r in rows])
        if vals.std() < 1e-9: continue
        corr = np.corrcoef(vals, pix_err)[0,1]
        print(f"  {key:24s} corr={corr:+.3f}  range=[{vals.min():.3f}, {vals.max():.3f}]  mean={vals.mean():.3f}")

    # Plot: each metric vs pix_err scatter
    fig, axes = plt.subplots(2, 4, figsize=(20, 9))
    for ax, key in zip(axes.flat, ["ent_bev","ent_wrist","mp_bev","mp_wrist","kl_b_to_w","kl_w_to_b","wrist_feat_norm","in_frustum_frac"]):
        vals = np.array([r[key] for r in rows])
        ax.scatter(vals, pix_err, s=8, alpha=0.6)
        ax.set_xlabel(key); ax.set_ylabel("pix_err_mean (px)")
        if vals.std() > 1e-9:
            corr = np.corrcoef(vals, pix_err)[0,1]
            ax.set_title(f"{key}  corr={corr:+.3f}")
        else:
            ax.set_title(f"{key}  (constant)")
    plt.tight_layout()
    plt.savefig(out_dir / "scatter.png", dpi=100, bbox_inches='tight')
    print(f"\nSaved → {out_dir}/stats.csv and scatter.png")


# Script entry point: run the analysis only when executed directly, not when imported.
if __name__ == "__main__":
    main()
