#!/usr/bin/env python3
"""Compare norm stats between simulation data and DROID checkpoint.

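Loads actions and joint-position states directly from an HDF5 rollout file and
applies the DeltaActions transform manually (action - state for the 7 arm
joints; the gripper dimension stays absolute) to match what the training
pipeline does, then compares the resulting statistics against the DROID
checkpoint's norm stats and saves histogram plots.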
"""

import argparse
import json
import h5py
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Parse arguments
parser = argparse.ArgumentParser(description="Compare norm stats between simulation data and DROID checkpoint.")
parser.add_argument(
    "--hdf5-file",
    type=str,
    default="/home/arhanjain/.cache/openpi/tri-ml-datasets-uw2/arhanjain/rollout_datasets/cube_on_plate_1k.hdf5",
    help="HDF5 rollout file containing a top-level 'data' group of demo_<N> groups",
)
parser.add_argument(
    "--droid-norm-stats",
    type=str,
    default="/home/arhanjain/.cache/openpi/openpi-assets/checkpoints/pi05_droid_jointpos/assets/droid/norm_stats.json",
    help="JSON file with a top-level 'norm_stats' key (DROID checkpoint assets)",
)
parser.add_argument("--no-states", action="store_true", help="Disable state statistics comparison")
parser.add_argument("--sample-every", type=int, default=1, help="Sample every Nth demo (1 = all demos)")
args = parser.parse_args()

# --sample-every is used as a slice step: 0 makes the slice raise ValueError
# and a negative value silently reverses the demo order, so reject both here
# with a readable CLI error instead.
if args.sample_every < 1:
    parser.error("--sample-every must be >= 1")

# Load DROID norm stats. The file is expected to hold a top-level
# "norm_stats" mapping whose entries ("actions", "state") carry per-dimension
# "q01"/"q99" lists that the comparisons below index into.
print(f"Loading DROID norm stats from {args.droid_norm_stats}...")
# JSON is UTF-8 by spec; don't let the platform locale pick the decoder.
with open(args.droid_norm_stats, encoding="utf-8") as f:
    droid_stats = json.load(f)["norm_stats"]

# Load data from HDF5
print(f"Loading data from {args.hdf5_file}...")

all_actions = []  # Raw absolute actions, one array per demo
all_delta_actions = []  # After DeltaActions transform, one array per demo
all_states = []  # State paired with each action, one array per demo

with h5py.File(args.hdf5_file, "r") as f:
    # Demo groups are named "demo_<int>"; sort numerically so demo_10 comes after demo_9.
    demo_keys = sorted([k for k in f["data"].keys() if k.startswith("demo_")],
                       key=lambda x: int(x.split("_")[1]))
    print(f"Found {len(demo_keys)} demos")

    sample_demos = demo_keys[::args.sample_every]
    print(f"Sampling from {len(sample_demos)} demos...")

    for demo_key in sample_demos:
        demo = f[f"data/{demo_key}"]

        # Load actions from the first of the known keys that is present.
        actions = None
        for action_key in ["action/droid_joint_pos_action", "actions"]:
            if action_key in demo:
                actions = np.array(demo[action_key])
                break

        if actions is None:
            print(f"Warning: Could not find actions for {demo_key}, skipping")
            continue

        # Load state (joint positions) - check different possible keys
        state = None
        for state_key in ["states/articulation/robot/joint_position"]:
            if state_key in demo:
                # NOTE(review): assumed (T+1, 15) full robot articulation
                # state, arm joints first — confirm against the dataset.
                joint_pos = np.array(demo[state_key])
                # DROID-style state: [arm_jp (7), gripper_pos (1)]
                arm_jp = joint_pos[:, :7]  # First 7 joints are arm
                gripper_pos = joint_pos[:, 7:8]  # Joint 7 taken as the gripper
                state = np.concatenate([arm_jp, gripper_pos], axis=-1)  # (T+1, 8)
                break

        if state is None:
            # Alternate layout: arm joints and gripper stored as separate datasets.
            for arm_key in ["obs/vision/arm_joint_pos"]:
                if arm_key in demo:
                    arm_jp = np.array(demo[arm_key])
                    gripper_key = arm_key.replace("arm_joint_pos", "gripper_pos")
                    if gripper_key in demo:
                        gripper_pos = np.array(demo[gripper_key])
                        if gripper_pos.ndim == 1:
                            gripper_pos = gripper_pos[:, np.newaxis]
                        state = np.concatenate([arm_jp, gripper_pos], axis=-1)
                    break

        if state is None:
            print(f"Warning: Could not find state for {demo_key}, skipping")
            continue

        # Pair action[t] with state[t]. Trajectory states may carry one extra
        # trailing entry (the state reached after the final action); taking
        # the min drops it, and it also keeps every sample when the state
        # array is exactly as long as the action array.
        T = min(actions.shape[0], state.shape[0])
        actions = actions[:T]
        action_states = state[:T]

        # Apply DeltaActions transform: delta = action - state for dims 0-6
        # Gripper (dim 7) stays absolute
        delta_actions = actions.copy()
        delta_actions[:, :7] = actions[:, :7] - action_states[:, :7]

        all_actions.append(actions)
        all_delta_actions.append(delta_actions)
        all_states.append(action_states)

# Concatenate the per-demo arrays into single arrays with one row per sample.
# np.concatenate on an empty list raises an opaque ValueError, so fail with a
# clear message when no demo yielded both actions and state.
if not all_actions:
    raise SystemExit(
        f"No demos with both actions and state were found in {args.hdf5_file}; nothing to compare."
    )
all_actions = np.concatenate(all_actions, axis=0)
all_delta_actions = np.concatenate(all_delta_actions, axis=0)
all_states = np.concatenate(all_states, axis=0)

print("\nData shapes:")
print(f"  Raw Actions (absolute): {all_actions.shape}")
print(f"  Delta Actions (action - state): {all_delta_actions.shape}")
print(f"  States: {all_states.shape}")
print("\nNote: DROID norm stats are computed on DELTA actions (dims 0-6 are action-state)")

# Compute statistics
def compute_stats(data):
    """Return per-column summary statistics of a 2-D sample array.

    Every reduction runs over axis 0 (the sample axis), so each value in the
    returned dict holds one entry per column of ``data``.

    Keys, in order: "mean", "std" (population, ddof=0), "q01" and "q99"
    (1st / 99th percentiles), "min", "max".
    """
    reductions = {
        "mean": np.mean,
        "std": np.std,
        "q01": lambda values, axis: np.percentile(values, 1, axis=axis),
        "q99": lambda values, axis: np.percentile(values, 99, axis=axis),
        "min": np.min,
        "max": np.max,
    }
    return {name: reduce(data, axis=0) for name, reduce in reductions.items()}

# Summary statistics of the SIM data (delta actions first, then states), each
# a dict as returned by compute_stats.
sim_action_stats, sim_state_stats = map(compute_stats, (all_delta_actions, all_states))

# Print comparison. Cap each dimension count at 8, at the SIM data width, and
# at the length of the DROID q01 list, so every DROID index used below (here
# and in the later sections) stays in range.
n_action_dims = min(8, all_delta_actions.shape[-1], len(droid_stats["actions"]["q01"]))
if args.no_states:
    n_state_dims = min(8, all_states.shape[-1])
else:
    n_state_dims = min(8, all_states.shape[-1], len(droid_stats["state"]["q01"]))


def _print_stats_comparison(title, droid_entry, sim_entry, n_dims):
    """Print DROID vs. SIM quantiles (plus SIM min/max) for the first n_dims dims.

    Args:
        title: Banner line printed above the table.
        droid_entry: One entry of the DROID norm stats, indexed by "q01"/"q99".
        sim_entry: compute_stats() output for the matching SIM data.
        n_dims: Number of leading dimensions to print.
    """
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)

    print(f"\n{'Dim':<5} {'DROID q01':>12} {'SIM q01':>12} {'DROID q99':>12} {'SIM q99':>12} {'SIM min':>12} {'SIM max':>12}")
    print("-" * 80)
    for i in range(n_dims):
        print(
            f"{i:<5} {droid_entry['q01'][i]:>12.4f} {sim_entry['q01'][i]:>12.4f} "
            f"{droid_entry['q99'][i]:>12.4f} {sim_entry['q99'][i]:>12.4f} "
            f"{sim_entry['min'][i]:>12.4f} {sim_entry['max'][i]:>12.4f}"
        )


_print_stats_comparison(
    f"DELTA ACTION STATISTICS COMPARISON (first {n_action_dims} dims)",
    droid_stats["actions"],
    sim_action_stats,
    n_action_dims,
)

if not args.no_states:
    _print_stats_comparison(
        f"STATE STATISTICS COMPARISON (first {n_state_dims} dims)",
        droid_stats["state"],
        sim_state_stats,
        n_state_dims,
    )

# Normalized value ranges: banner explaining the DROID quantile normalization
# applied in the tables that follow.
_banner = "=" * 80
print("\n".join([
    "\n" + _banner,
    "NORMALIZED VALUE RANGES (using DROID quantile normalization)",
    "Formula: (x - q01) / (q99 - q01) * 2.0 - 1.0",
    "Expected range: [-1, 1]",
    _banner,
]))

def quantile_normalize(x, q01, q99):
    """Linearly map ``x`` so that ``q01`` lands near -1 and ``q99`` near +1.

    A 1e-6 term is added to the denominator so a zero-width quantile range
    does not divide by zero. Works elementwise on scalars or NumPy arrays.
    """
    span = q99 - q01 + 1e-6
    unit = (x - q01) / span
    return unit * 2.0 - 1.0

def _print_normalized_ranges(heading, droid_entry, sim_entry, n_dims):
    """Print the SIM min/max of each dim after DROID quantile normalization.

    A dim is marked "PROBLEMATIC!" when its normalized min is <= -5 or its
    normalized max is >= 5; otherwise it is "OK".
    """
    print(heading)
    print(f"{'Dim':<5} {'Norm Min':>12} {'Norm Max':>12} {'Status':>15}")
    print("-" * 50)
    for dim in range(n_dims):
        lo, hi = droid_entry["q01"][dim], droid_entry["q99"][dim]
        lowest = quantile_normalize(sim_entry["min"][dim], lo, hi)
        highest = quantile_normalize(sim_entry["max"][dim], lo, hi)
        within = lowest > -5 and highest < 5
        verdict = "OK" if within else "PROBLEMATIC!"
        print(f"{dim:<5} {lowest:>12.2f} {highest:>12.2f} {verdict:>15}")


_print_normalized_ranges("\nDELTA ACTIONS normalized with DROID stats:", droid_stats["actions"], sim_action_stats, n_action_dims)

if not args.no_states:
    _print_normalized_ranges("\nSTATES normalized with DROID stats:", droid_stats["state"], sim_state_stats, n_state_dims)

# Create visualization: one row of delta-action histograms and, unless
# --no-states is set, a second row of state histograms, each overlaid with
# the DROID q01/q99 bounds.


def _plot_hist_with_bounds(ax, values, q01, q99, color, label):
    """Draw a histogram of ``values`` on ``ax`` with DROID q01/q99 overlaid.

    The bounds are drawn as dashed red lines, and the regions beyond them are
    shaded. The title reports the percentage of samples falling outside
    [q01, q99].
    """
    ax.hist(values, bins=50, alpha=0.7, color=color, edgecolor='black', linewidth=0.5)

    ax.axvline(x=q01, color='red', linestyle='--', linewidth=2, label=f'q01={q01:.2f}')
    ax.axvline(x=q99, color='red', linestyle='--', linewidth=2, label=f'q99={q99:.2f}')

    # Shade out to the current view edges, then restore the limits captured
    # before shading so the spans do not change the view.
    ylim = ax.get_ylim()
    xlim = ax.get_xlim()
    ax.axvspan(xlim[0], q01, alpha=0.2, color='red')
    ax.axvspan(q99, xlim[1], alpha=0.2, color='red')
    ax.set_ylim(ylim)
    ax.set_xlim(xlim)

    outside_bounds = np.sum((values < q01) | (values > q99))
    pct_outside = 100 * outside_bounds / len(values)

    ax.set_title(f'{label}\n({pct_outside:.1f}% outside)')
    ax.set_xlabel('Value')
    ax.legend(fontsize=6, loc='upper right')


n_rows = 2 if not args.no_states else 1
# The state row may have more dims than the action row, so size the grid for
# the wider of the two.
n_cols = n_action_dims if args.no_states else max(n_action_dims, n_state_dims)
# squeeze=False always returns a 2-D array of Axes, so axes[row, col] works
# even for a single row or a single column.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 4 * n_rows), squeeze=False)

# Row 1: Delta Action histograms
for i in range(n_action_dims):
    dim_label = f'Δ Action {i}' if i < 7 else 'Gripper'
    _plot_hist_with_bounds(
        axes[0, i],
        all_delta_actions[:, i],
        droid_stats["actions"]["q01"][i],
        droid_stats["actions"]["q99"][i],
        'steelblue',
        dim_label,
    )

# Row 2: State histograms
if not args.no_states:
    for i in range(n_state_dims):
        _plot_hist_with_bounds(
            axes[1, i],
            all_states[:, i],
            droid_stats["state"]["q01"][i],
            droid_stats["state"]["q99"][i],
            'seagreen',
            f'State {i}',
        )

for ax in axes[:, 0]:
    ax.set_ylabel('Count')

# Hide grid cells left unused when the two rows have different widths.
row_widths = [n_action_dims] if args.no_states else [n_action_dims, n_state_dims]
for row, n_used in enumerate(row_widths):
    for ax in axes[row, n_used:]:
        ax.set_visible(False)

plt.suptitle('SIM Data vs DROID Normalization Bounds\n(Delta Actions = action - state for dims 0-6)', fontsize=14, y=1.02)
plt.tight_layout()

output_path = Path(__file__).parent / 'norm_stats_comparison.png'
plt.savefig(output_path, dpi=150, bbox_inches='tight')
print(f"\nPlot saved to: {output_path}")
plt.show()
