"""Simple PointNet model for predicting 6D joint states from pointclouds."""
import argparse
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import viser
from viser.extras import ViserUrdf
import yourdfpy
from scipy.spatial.transform import Rotation as R
import trimesh
from network_utils import PointNet

# Parse args early to know prediction mode
parser = argparse.ArgumentParser()
parser.add_argument("--mode", type=str, choices=["train", "test", "test_with_train"], default="train")
parser.add_argument("--prediction_type", type=str, choices=["joint_state", "end_effector"], default="joint_state", help="Predict joint states or end-effector pose")
parser.add_argument("--no_color", action="store_true", help="Ignore color and use only xyz coordinates (default: use color)")
parser.add_argument("--with_dino", action="store_true", help="Use DINO features (requires xyz+rgb to be enabled)")
args = parser.parse_args()
args.use_color = not args.no_color  # Default to True unless --no_color is specified

if args.with_dino and not args.use_color:
    raise ValueError("--with_dino requires --use_color (xyz+rgb must be enabled)")

def pose_4x4_to_6d(pose_4x4):
    """Convert 4x4 pose matrix to 6D (position + axis-angle)."""
    pos = pose_4x4[:3, 3]
    rot = pose_4x4[:3, :3]
    axis_angle = R.from_matrix(rot).as_rotvec()
    return np.concatenate([pos, axis_angle])

def pose_6d_to_4x4(pose_6d):
    """Convert 6D (position + axis-angle) to 4x4 pose matrix."""
    pos = pose_6d[:3]
    axis_angle = pose_6d[3:]
    rot = R.from_rotvec(axis_angle).as_matrix()
    pose = np.eye(4)
    pose[:3, :3] = rot
    pose[:3, 3] = pos
    return pose

# Load dataset
print("Loading dataset...")
processed_dir = "scratch/processed_grasp_dataset_keyboard"
sequences = sorted([d for d in os.listdir(processed_dir) if os.path.isdir(os.path.join(processed_dir, d))])

train_points, train_colors, train_targets = [], [], []
val_points, val_colors, val_targets = [], [], []
train_sequences, val_sequences = [], []
train_dino_features, val_dino_features = [], []

for i, seq_id in enumerate(sequences):
    seq_dir = os.path.join(processed_dir, seq_id)
    pointmap_path = os.path.join(seq_dir, "pointmap_start_cropped_fps.pt")
    
    if args.prediction_type == "end_effector":
        target_path = os.path.join(seq_dir, "gripper_pose_grasp.npy")
        if not os.path.exists(target_path):
            continue
        pose_4x4 = np.load(target_path)
        target = pose_4x4_to_6d(pose_4x4)
    else:
        target_path = os.path.join(seq_dir, "joint_states_grasp.npy")
        if not os.path.exists(target_path):
            continue
        target = np.load(target_path)
    
    if not os.path.exists(pointmap_path):
        continue
    
    pointmap = torch.load(pointmap_path)
    points = pointmap["points"].numpy()  # (N, 3)
    colors = pointmap["colors"].numpy()  # (N, 3) uint8 [0-255]
    
    # Use xyz only or xyz+rgb based on flag
    if args.use_color:
        point_features = np.concatenate([points, colors / 255.0], axis=1)  # (N, 6)
    else:
        point_features = points  # (N, 3)
    
    # Load DINO features if requested
    dino_features = None
    if args.with_dino:
        dino_path = os.path.join(seq_dir, "dino_features_fps.pt")
        if not os.path.exists(dino_path):
            print(f"  ⚠ Warning: DINO features not found for {seq_id}, skipping")
            continue
        dino_features = torch.load(dino_path).numpy()  # (N, 32)
    
    # Hold out the last 5 sequences (in sorted order) for validation
    if i >= len(sequences) - 5:
        val_points.append(point_features)
        val_colors.append(colors)
        val_targets.append(target)
        val_sequences.append(seq_id)
        if args.with_dino:
            val_dino_features.append(dino_features)
    else:
        train_points.append(point_features)
        train_colors.append(colors)
        train_targets.append(target)
        train_sequences.append(seq_id)
        if args.with_dino:
            train_dino_features.append(dino_features)

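# Stack per-sequence arrays into batched tensors (np.stack requires every sequence
# to have the same number of points, which the FPS-downsampled clouds provide).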
train_points = torch.from_numpy(np.stack(train_points)).float()  # (B, N, 3 or 6)
train_targets = torch.from_numpy(np.stack(train_targets)).float()  # (B, 6)
val_points = torch.from_numpy(np.stack(val_points)).float()  # (B, N, 3 or 6)
val_targets = torch.from_numpy(np.stack(val_targets)).float()  # (B, 6)

if args.with_dino:
    train_dino_features = torch.from_numpy(np.stack(train_dino_features)).float()  # (B, N, 32)
    val_dino_features = torch.from_numpy(np.stack(val_dino_features)).float()  # (B, N, 32)
else:
    train_dino_features = None
    val_dino_features = None

input_dim = 6 if args.use_color else 3
print(f"Train: {len(train_points)} samples, Val: {len(val_points)} samples")
print(f"Prediction type: {args.prediction_type}")
if args.with_dino:
    print(f"Input features: xyz+rgb (projected to 32D) + DINO (32D) = 64D")
else:
    print(f"Input features: {'xyz+rgb' if args.use_color else 'xyz only'}")

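# Full-batch training: the whole dataset is moved onto the device at once, so each
# "epoch" below is a single gradient step over all training samples.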
def train(model, train_data, val_data, prediction_type, use_color, with_dino=False, lr=4e-3):
    device = torch.device("mps")
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    
    if with_dino:
        train_points, train_dino, train_targets = train_data
        val_points, val_dino, val_targets = val_data
        train_points = train_points.to(device)
        train_dino = train_dino.to(device)
        train_targets = train_targets.to(device)
        val_points = val_points.to(device)
        val_dino = val_dino.to(device)
        val_targets = val_targets.to(device)
    else:
        train_points, train_targets = train_data
        val_points, val_targets = val_data
        train_points, train_targets = train_points.to(device), train_targets.to(device)
        val_points, val_targets = val_points.to(device), val_targets.to(device)
        train_dino = None
        val_dino = None
    
    # Setup real-time plotting
    plt.ion()
    fig, ax = plt.subplots(figsize=(8, 5))
    train_losses, val_losses = [], []
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training Progress')
    
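    # No stopping criterion: training runs until interrupted manually, relying on the
    # periodic checkpoint inside the loop to persist the model.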
    epoch = 0
    while True:
        model.train()
        optimizer.zero_grad()
        if with_dino:
            pred = model(train_points, train_dino)
        else:
            pred = model(train_points)
        loss = criterion(pred, train_targets)
        loss.backward()
        optimizer.step()
        train_loss = loss.item()
        epoch += 1
        
        model.eval()
        with torch.no_grad():
            if with_dino:
                val_pred = model(val_points, val_dino)
            else:
                val_pred = model(val_points)
            val_loss = criterion(val_pred, val_targets).item()
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        
        # Update the plot only every 5 epochs so the window does not keep stealing focus
        if epoch % 5 == 0:
            ax.clear()
            ax.plot(train_losses, label='Train', alpha=0.7)
            ax.plot(val_losses, label='Val', alpha=0.7)
            ax.legend()
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.set_title('Training Progress')
            fig.canvas.draw_idle()
            fig.canvas.flush_events()  # Update display without raising window
        
        # epoch was already incremented above, so it counts completed iterations
        if epoch % 50 == 0:
            print(f"Epoch {epoch}: Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}")
        elif epoch % 101 == 0:
            color_suffix = "rgb" if use_color else "xyz"
            dino_suffix = "_dino" if with_dino else ""
            model_name = f"pointnet_predictor_{prediction_type}_{color_suffix}{dino_suffix}.pt"
            print(f"saving model to {model_name}")
            torch.save(model.state_dict(), f"scratch/policies/{model_name}")

if __name__ == "__main__":
    # Create model
    color_suffix = "rgb" if args.use_color else "xyz"
    dino_suffix = "_dino" if args.with_dino else ""
    model = PointNet(num_points=1024, input_dim=input_dim, output_dim=6, with_dino=args.with_dino)
    model_name = f"pointnet_predictor_{args.prediction_type}_{color_suffix}{dino_suffix}.pt"
    
    if args.mode == "train":
        os.makedirs("scratch/policies", exist_ok=True)
        if args.with_dino:
            train(model, (train_points, train_dino_features, train_targets), 
                  (val_points, val_dino_features, val_targets), 
                  args.prediction_type, args.use_color, with_dino=True)
        else:
            train(model, (train_points, train_targets), (val_points, val_targets), 
                  args.prediction_type, args.use_color, with_dino=False)
        torch.save(model.state_dict(), f"scratch/policies/{model_name}")
        print(f"Model saved to scratch/policies/{model_name}")
    else:
        os.makedirs("scratch/policies", exist_ok=True)
        model_path = f"scratch/policies/{model_name}"
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
        model = model.to(device).eval()
        
        # Choose sequences and points based on mode
        sequences_to_use = train_sequences if args.mode == "test_with_train" else val_sequences
        points_to_use = train_points if args.mode == "test_with_train" else val_points
        targets_to_use = train_targets if args.mode == "test_with_train" else val_targets
        
        if args.with_dino:
            dino_to_use = train_dino_features if args.mode == "test_with_train" else val_dino_features
            points_gpu = points_to_use.to(device)
            dino_gpu = dino_to_use.to(device)
            with torch.no_grad(): pred = model(points_gpu, dino_gpu)
        else:
            points_gpu = points_to_use.to(device)
            with torch.no_grad(): pred = model(points_gpu)
        loss = nn.MSELoss()(pred, targets_to_use.to(device)).item()
        print(f"Test Loss: {loss:.6f}")
        
        # Viser visualization
        print("\nLaunching Viser visualization...")
        server = viser.ViserServer()
        
        # Load gripper mesh if in end_effector mode
        gripper_mesh = None
        if args.prediction_type == "end_effector":
            fixed_gripper_stl_path = "robot_models/so100_model/assets/Fixed_Jaw.stl"
            gripper_mesh = trimesh.load(fixed_gripper_stl_path)
            if isinstance(gripper_mesh, trimesh.Scene):
                gripper_mesh = list(gripper_mesh.geometry.values())[0]
            bounds = gripper_mesh.bounds
            max_extent = np.max(bounds[1] - bounds[0])
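            # Heuristic: a mesh spanning more than 1 m is presumably modelled in
            # millimetres, so rescale it to metres.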
            if max_extent > 1.0:
                gripper_mesh.apply_scale(0.001)
        else:
            urdf_path = "/Users/cameronsmith/Projects/robotics_testing/calibration_testing/so_100_arm/urdf/so_100_arm.urdf"
            urdf = yourdfpy.URDF.load(urdf_path)
            viser_urdf = ViserUrdf(server, urdf_or_path=urdf, load_meshes=True, load_collision_meshes=False)
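            # Offset (radians) between the joint convention of the recorded targets
            # (MuJoCo, judging by the name) and this URDF's zero configuration.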
            mujoco_so100_offset = np.array([0, -1.57, 1.57, 1.57, -1.57, 0])
        
        # Load pointclouds
        pc_full, pc_cropped, pc_cropped_fps = [], [], []
        for seq_id in sequences_to_use:
            seq_dir = os.path.join(processed_dir, seq_id)
            pointmap_full = torch.load(os.path.join(seq_dir, "pointmap_start.pt"))
            pointmap_cropped = torch.load(os.path.join(seq_dir, "pointmap_start_cropped.pt"))
            pointmap_cropped_fps = torch.load(os.path.join(seq_dir, "pointmap_start_cropped_fps.pt"))
            pc_full.append((pointmap_full["points"].numpy(), pointmap_full["colors"].numpy()))
            pc_cropped.append((pointmap_cropped["points"].numpy(), pointmap_cropped["colors"].numpy()))
            pc_cropped_fps.append((pointmap_cropped_fps["points"].numpy(), pointmap_cropped_fps["colors"].numpy()))
        
        def update_vis(idx):
            idx = int(idx)
            pred_6d = pred[idx].cpu().numpy()
            
            if args.prediction_type == "end_effector":
                # Convert 6D prediction to 4x4 pose and display gripper mesh
                pred_pose = pose_6d_to_4x4(pred_6d)
                pos = pred_pose[:3, 3]
                rot = pred_pose[:3, :3]
                quat = R.from_matrix(rot).as_quat()  # (x, y, z, w)
                
                if gripper_mesh is not None:
                    vertices_homogeneous = np.hstack([gripper_mesh.vertices, np.ones((gripper_mesh.vertices.shape[0], 1))])
                    transformed_vertices = (pred_pose @ vertices_homogeneous.T).T[:, :3]
                    
                    try:
                        server.scene.add_mesh_trimesh(
                            name="/predicted_gripper",
                            mesh=gripper_mesh,
                            wxyz=quat[[3, 0, 1, 2]],  # (w, x, y, z)
                            position=pos,
                        )
                    except Exception:
                        server.scene.add_mesh(
                            name="/predicted_gripper",
                            vertices=transformed_vertices.astype(np.float32),
                            faces=gripper_mesh.faces.astype(np.int32),
                            color=(100, 150, 200, 255),
                        )
            else:
                # Update robot URDF with joint state prediction
                viser_urdf.update_cfg((pred_6d - mujoco_so100_offset).astype(np.float32))
            
            server.scene.add_point_cloud("/full", pc_full[idx][0].astype(np.float32), pc_full[idx][1].astype(np.uint8), point_size=0.001)
            server.scene.add_point_cloud("/cropped", pc_cropped[idx][0].astype(np.float32), pc_cropped[idx][1].astype(np.uint8), point_size=0.002)
            server.scene.add_point_cloud("/cropped_fps", pc_cropped_fps[idx][0].astype(np.float32), pc_cropped_fps[idx][1].astype(np.uint8), point_size=0.003)
        
        slider = server.gui.add_slider("sample", 0, len(sequences_to_use)-1, initial_value=0, step=1)
        slider.on_update(lambda _: update_vis(slider.value))
        update_vis(0)
        
        print(f"Viser running at http://localhost:8080")
        import time
        while True:
            time.sleep(0.1)
        

