#!/usr/bin/env python3
"""gen_distribution_viz.py — Polar plot + sample frames for train vs test."""
import cv2, numpy as np, os, sys
sys.path.insert(0, "/data/cameron/LIBERO")
os.environ.setdefault("LIBERO_DATA_PATH", "/data/libero")
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation as ScipyR
from libero.libero import benchmark as bm_lib, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
import h5py

EXPERIMENT_NAME = "act_baseline"

# Viewpoint grid: n_views elevation angles (theta, degrees away from the
# default camera direction) crossed with n_views azimuths (phi, degrees).
# The azimuths cover a full turn without duplicating 0 as 360.
n_views = 8
thetas = np.linspace(0, 25, num=n_views)
phis = np.linspace(0, 360 * (1 - 1 / n_views), num=n_views)

# All artifacts of this script are written under results/<experiment>.
out_dir = "results/" + EXPERIMENT_NAME
os.makedirs(out_dir, exist_ok=True)

# --- Polar plot ---
# Viewpoint index vi enumerates the n_views x n_views (theta, phi) grid as
# (theta index, phi index) = divmod(vi, n_views). Training uses only the
# theta = 0 ring (the default camera); every other viewpoint is test.
n_grid = n_views * n_views
train_vis = [vi for vi in range(n_grid) if (vi // n_views) == 0]  # theta=0
test_vis = [vi for vi in range(n_grid) if (vi // n_views) != 0]

fig, ax = plt.subplots(1, 1, figsize=(4, 4), subplot_kw=dict(projection='polar'))
fig.patch.set_facecolor('#1a1a1a')
ax.set_facecolor('#1a1a1a')
train_set = set(train_vis)
tl = fl = False  # whether the Train / Test legend entries have been emitted yet
for vi in range(n_grid):
    ti, pi = divmod(vi, n_views)
    is_train = vi in train_set
    color = '#66ff66' if is_train else '#6496ff'
    # Only the first point of each group carries a label, so the legend
    # gets exactly one Train and one Test entry.
    legend_label = None
    if is_train and not tl:
        legend_label, tl = f'Train (n={len(train_vis)})', True
    elif not is_train and not fl:
        legend_label, fl = f'Test (n={len(test_vis)})', True
    ax.scatter(np.radians(phis[pi]), thetas[ti], c=color, s=100,
               edgecolors='white', linewidths=0.5, label=legend_label)
# Radial axis is theta in degrees: ticks every 5 degrees up to the largest
# theta, with 5 degrees of headroom (0..30 for the default 25-degree grid).
max_theta = float(thetas.max())
theta_ticks = np.arange(0, max_theta + 1e-9, 5)
ax.set_ylim(0, max_theta + 5)
ax.set_yticks(theta_ticks)
ax.set_yticklabels([f'{t:g}°' for t in theta_ticks], fontsize=8, color='white')
ax.set_title('Default → All Viewpoints', color='white', fontsize=11, pad=15)
ax.tick_params(colors='white')
ax.grid(True, alpha=0.3, color='gray')
ax.legend(fontsize=7, facecolor='#2a2a2a', edgecolor='gray', labelcolor='white')
plt.tight_layout()
fig.savefig(f"{out_dir}/polar_plot.png", dpi=150,
            bbox_inches='tight', facecolor='#1a1a1a')
plt.close(fig)
print("Saved polar plot")

# --- Sample frames from sim at train/test viewpoints ---
bench = bm_lib.get_benchmark_dict()["libero_spatial"]()
task = bench.get_task(0)
demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(0))
bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
# First recorded sim state of demo_0; capture() starts every frame from a copy of it.
with h5py.File(demo_path, "r") as f:
    init_state = f["data/demo_0/states"][0]

env = OffScreenRenderEnv(bddl_file_name=bddl_file, camera_heights=448, camera_widths=448,
                          camera_names=["agentview"])
env.seed(0)
env.reset()
# Very large horizon so the repeated env.step() calls in capture() never end the episode.
env.env.horizon = 100000
sim = env.env.sim

# Clean scene: move the cabinet and stove bodies below the floor, and make every
# geom of the remaining distractor objects fully transparent (alpha = 0).
for name in ["wooden_cabinet_1_main", "flat_stove_1_main"]:
    sim.model.body_pos[sim.model.body_name2id(name)] = np.array([0, 0, -5.0])
for dn in ["akita_black_bowl_2_main", "cookies_1_main", "glazed_rim_porcelain_ramekin_1_main"]:
    bid = sim.model.body_name2id(dn)
    sim.model.geom_rgba[sim.model.geom_bodyid == bid, 3] = 0.0
sim.forward()

cam_id = sim.model.camera_name2id("agentview")
default_pos = sim.data.cam_xpos[cam_id].copy()
cam_xmat = sim.data.cam_xmat[cam_id].reshape(3, 3)
fwd = -cam_xmat[:, 2]  # MuJoCo cameras look along their local -z axis
TABLE_Z = 0.90  # world-z of the plane the orbit target lies on (table top, presumably)

# Orbit target: where the default camera's optical axis meets z = TABLE_Z.
# Fail loudly rather than silently orbiting a point behind the camera (or
# dividing by ~0) when that intersection is not in front of the camera.
if fwd[2] >= -1e-8:
    raise RuntimeError(f"agentview camera is not looking down (forward z = {fwd[2]:.3g})")
t_hit = (TABLE_Z - default_pos[2]) / fwd[2]
if t_hit <= 0:
    raise RuntimeError(f"table plane z={TABLE_Z} is behind the agentview camera")
look_at = default_pos + t_hit * fwd
radius = np.linalg.norm(default_pos - look_at)
default_dir = (default_pos - look_at) / radius  # unit vector from target toward camera

# Orbit basis around default_dir. NOTE: because default_dir points from the
# target to the camera, cross(default_dir, up) is the opposite of the camera
# x axis capture() builds at theta = 0; it only fixes where phi = 0 lies.
up = np.array([0, 0, 1.0])
right = np.cross(default_dir, up)
right_norm = np.linalg.norm(right)
if right_norm < 1e-8:
    raise RuntimeError("agentview camera points straight down; orbit basis is undefined")
right /= right_norm
true_up = np.cross(right, default_dir)

# Indices into the flattened sim state for the bowl and plate; capture() edits
# entries i and i+1, presumably the x/y of their free-joint positions — TODO
# confirm against the LIBERO state layout.
bowl_i, plate_i = 10, 38
# Adding these in capture() puts the bowl at exactly (dx, dy) and shifts the
# plate by the same amount, preserving their relative offset.
center_dx = -init_state[bowl_i]
center_dy = -init_state[bowl_i + 1]
DISTRACTOR_POS = np.array([10.0, 10.0, 0.9])  # parking spot for distractors (presumably off-screen)
rng = np.random.RandomState(55)
thumb = 120  # thumbnail side length in pixels

def camera_look_at_quat(cam_pos, target):
    """Return the MuJoCo (w, x, y, z) quaternion for a camera at `cam_pos` looking at `target`.

    MuJoCo cameras look along their local -z axis with +y as image up, so the
    local z axis is aimed from the target back toward the camera. World +z is
    the up hint; world +y is used instead when the view is nearly vertical,
    where +z would make the cross product degenerate.
    """
    fwd_dir = np.asarray(target, dtype=float) - np.asarray(cam_pos, dtype=float)
    fwd_dir = fwd_dir / (np.linalg.norm(fwd_dir) + 1e-12)
    cz = -fwd_dir
    up_hint = np.array([0., 0., 1.])
    if abs(np.dot(fwd_dir, up_hint)) > 0.99:
        up_hint = np.array([0., 1., 0.])
    cx = np.cross(up_hint, cz)
    cx /= (np.linalg.norm(cx) + 1e-12)
    cy = np.cross(cz, cx)
    # scipy returns scalar-last (x, y, z, w); MuJoCo expects scalar-first.
    x, y, z, w = ScipyR.from_matrix(np.stack([cx, cy, cz], axis=-1)).as_quat()
    return np.array([w, x, y, z])


def capture(theta_deg, phi_deg, dx, dy):
    """Render the agentview image for one viewpoint and one object placement.

    Resets the sim to a copy of ``init_state`` with these changes:
      * the bowl's state entries (bowl_i, bowl_i+1) are set to (dx, dy);
      * the plate is shifted by the same amount;
      * the three distractor slots are moved to DISTRACTOR_POS.
    The agentview camera is then moved onto the orbit of ``radius`` around
    ``look_at``, ``theta_deg`` degrees away from the default viewing
    direction at azimuth ``phi_deg`` in the (right, true_up) plane. It is
    aimed at ``look_at``.

    Note: this mutates ``sim.model.cam_pos``/``cam_quat`` persistently; every
    call sets them again, so call order does not matter.

    Returns:
        The agentview observation flipped vertically and converted from RGB to
        BGR (OpenCV channel order).
    """
    s = init_state.copy()
    s[bowl_i] += center_dx + dx
    s[bowl_i + 1] += center_dy + dy
    s[plate_i] += center_dx + dx
    s[plate_i + 1] += center_dy + dy
    # Presumably the position slots of three distractor free joints — TODO confirm.
    for qps in (17, 24, 31):
        s[qps:qps + 3] = DISTRACTOR_POS
    env.set_init_state(s)
    sim.forward()
    # Reset episode bookkeeping, presumably so the env.step() calls below are
    # not treated as stepping a finished episode.
    env.env.timestep = 0
    env.env.done = False

    # Spherical offset from the orbit centre: theta away from default_dir,
    # phi around it measured from `right` toward `true_up`.
    theta, phi = np.radians(theta_deg), np.radians(phi_deg)
    offset = (np.sin(theta) * np.cos(phi) * right
              + np.sin(theta) * np.sin(phi) * true_up
              + np.cos(theta) * default_dir)
    new_pos = look_at + radius * offset
    sim.model.cam_pos[cam_id] = new_pos
    sim.model.cam_quat[cam_id] = camera_look_at_quat(new_pos, look_at)
    sim.forward()

    # A few zero-action steps before reading the observation.
    for _ in range(3):
        env.step(np.zeros(7, dtype=np.float32))
    obs = env.env._get_observations()
    return cv2.cvtColor(np.flipud(obs["agentview_image"]).copy(), cv2.COLOR_RGB2BGR)

# Thumbnail/caption colours in OpenCV's BGR channel order, chosen to match the
# polar plot: TEST_BGR is '#6496ff' (R=100, G=150, B=255). Writing the RGB
# triple (100, 150, 255) directly would be read as BGR and render orange.
TRAIN_BGR = (100, 255, 100)
TEST_BGR = (255, 150, 100)


def draw_border(img, color, width=2):
    """Paint a solid ``width``-pixel frame of ``color`` around ``img`` in place and return it."""
    img[:width, :] = color
    img[-width:, :] = color
    img[:, :width] = color
    img[:, -width:] = color
    return img


# Capture all thumbnails, then release the environment even if a capture fails
# part-way. RNG draws keep their order (dx then dy per frame, train frames
# first) so the sampled placements stay reproducible.
try:
    # Train frames (default viewpoint, varied positions)
    train_frames = []
    for _ in range(6):
        dx = rng.uniform(-0.40, -0.01)
        dy = rng.uniform(-0.30, 0.30)
        img = cv2.resize(capture(0, 0, dx, dy), (thumb, thumb))
        train_frames.append(draw_border(img, TRAIN_BGR))

    # Test frames (varied viewpoints, varied positions)
    test_frames = []
    for theta_d in [7, 14, 25]:
        for phi_d in [0, 180]:
            dx = rng.uniform(-0.40, -0.01)
            dy = rng.uniform(-0.30, 0.30)
            img = cv2.resize(capture(theta_d, phi_d, dx, dy), (thumb, thumb))
            cv2.putText(img, f"{theta_d},{phi_d}", (3, 13),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1)
            test_frames.append(draw_border(img, TEST_BGR))
finally:
    env.close()


# Combine into image
def stack(frames, cols=3):
    """Tile equally-sized images into a grid with ``cols`` images per row.

    The last row is padded with black images shaped like ``frames[0]``.
    ``frames`` itself is not modified.

    Raises:
        ValueError: if ``frames`` is empty.
    """
    if not frames:
        raise ValueError("stack() needs at least one frame")
    blank = np.zeros_like(frames[0])
    rows = []
    for r in range(0, len(frames), cols):
        row = list(frames[r:r + cols])
        row += [blank] * (cols - len(row))
        rows.append(np.concatenate(row, axis=1))
    return np.concatenate(rows, axis=0)


lbl_h = 22  # height in pixels of the caption strip added by label()


def label(grid, text, color):
    """Return ``grid`` with a black ``lbl_h``-pixel caption strip showing ``text`` on top."""
    strip = np.zeros((lbl_h, grid.shape[1], 3), dtype=np.uint8)
    cv2.putText(strip, text, (4, lbl_h - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.35, color, 1)
    return np.vstack([strip, grid])


def pad_to_height(img, height):
    """Pad ``img`` at the bottom with black rows up to ``height`` pixels (no-op if already tall enough)."""
    extra = height - img.shape[0]
    if extra <= 0:
        return img
    return np.vstack([img, np.zeros((extra,) + img.shape[1:], dtype=img.dtype)])


tp = label(stack(train_frames), "TRAIN: default view, random positions", TRAIN_BGR)
ep = label(stack(test_frames), "TEST: all viewpoints, random positions", TEST_BGR)
h = max(tp.shape[0], ep.shape[0])
tp = pad_to_height(tp, h)
ep = pad_to_height(ep, h)

polar_path = f"{out_dir}/polar_plot.png"
polar = cv2.imread(polar_path)  # default flag loads a 3-channel BGR image
if polar is None:
    raise FileNotFoundError(f"could not read {polar_path}")
# Scale the polar plot to the panel height, keeping its aspect ratio.
ph = h
pw = int(polar.shape[1] * ph / polar.shape[0])
polar = cv2.resize(polar, (pw, ph))
sep = np.zeros((h, 8, 3), dtype=np.uint8)
combined = np.concatenate([polar, sep, tp, sep, ep], axis=1)
overview_path = f"{out_dir}/distribution_overview.png"
if not cv2.imwrite(overview_path, combined):
    raise OSError(f"failed to write {overview_path}")
print("Saved distribution overview")
