import wandb
import dataclasses
import numpy as np
import tyro
import torch
import torch.nn as nn
from pathlib import Path
from openpi.training.wds_dataset import WDSDataset, WDSDatasetConfig, _load_manifest


@dataclasses.dataclass
class TrainConfig:
    """Settings for the QPNet training script.

    An instance is built from the command line by ``tyro.cli`` in the
    ``__main__`` guard and handed to ``main``.
    """

    # Training
    num_epochs: int = 10  # Epoch loop is range(start_epoch, num_epochs)
    batch_size: int = 256  # Forwarded to WDSDataset for both train and val
    lr: float = 1e-3  # Learning rate of the Adam optimizer
    log_interval: int = 10  # Log running train loss every N batches
    exp_name: str = "qp_net"  # Used as the wandb run name
    wandb_enabled: bool = True  # Gates wandb.init and every wandb.log call
    checkpoint_dir: str = "checkpoints/qp_net"  # Created if missing; checkpoints are written here
    save_every_n_epochs: int = 1  # Checkpoint period; the final epoch is always saved
    resume_from: str | None = None  # Path to checkpoint to resume from (jit or state_dict); None -> <checkpoint_dir>/latest.pt

    # Data
    manifest_path: str = "s3://tri-ml-datasets-uw2/vla_foundry_datasets/v0.4.2-sim/manifest.jsonl"  # Shard manifest to load and split
    num_workers: int = 16  # Forwarded to WDSDataset
    shuffle_buffer: int = 2000  # Forwarded to WDSDataset
    val_frac: float = 0.05  # Fraction of manifest entries held out for validation (at least one entry)


# Dataset configuration shared by the train and validation datasets.
#
# All fields go into action_fields so we get the full [B, T=2, dim] window.
# With past=0, future=1, we get timesteps [t, t+1].
# NOTE(review): extract_batch relies on this layout -- it reads index 0 of the
# second axis as the input step and index 1 as the target step. The exact
# semantics of the WDSDatasetConfig fields live in openpi.training.wds_dataset;
# confirm there before changing the timestep settings.
WDS_CONFIG = WDSDatasetConfig(
    action_fields=[
        # ee commanded (what we condition on at t)
        "robot__action__poses__left::panda__xyz",       # (3,)
        "robot__action__poses__right::panda__xyz",      # (3,)
        "robot__action__poses__left::panda__rot_6d",    # (6,)
        "robot__action__poses__right::panda__rot_6d",   # (6,)
        # q commanded (input at t, target at t+1)
        "robot__desired__joint_position__left::panda",  # (7,)
        "robot__desired__joint_position__right::panda",  # (7,)
    ],
    # No proprioception or camera streams are requested for this model.
    proprioception_fields=[],
    camera_names=[],
    image_indices=[],
    lowdim_past_timesteps=0,
    lowdim_future_timesteps=1,
)

# Input/output schema: ordered lists of (field_name, dim).
# The order is significant: QPNet concatenates inputs and slices outputs in
# exactly this order, so reordering entries invalidates existing checkpoints.
_Q_LEFT: tuple[str, int] = ("robot__desired__joint_position__left::panda", 7)
_Q_RIGHT: tuple[str, int] = ("robot__desired__joint_position__right::panda", 7)
_XYZ_LEFT: tuple[str, int] = ("robot__action__poses__left::panda__xyz", 3)
_XYZ_RIGHT: tuple[str, int] = ("robot__action__poses__right::panda__xyz", 3)
_ROT6D_LEFT: tuple[str, int] = ("robot__action__poses__left::panda__rot_6d", 6)
_ROT6D_RIGHT: tuple[str, int] = ("robot__action__poses__right::panda__rot_6d", 6)

# Network input: commanded joint positions followed by commanded end-effector poses.
INPUT_SCHEMA: list[tuple[str, int]] = [
    _Q_LEFT,
    _Q_RIGHT,
    _XYZ_LEFT,
    _XYZ_RIGHT,
    _ROT6D_LEFT,
    _ROT6D_RIGHT,
]

# Network output: commanded joint positions for both arms.
OUTPUT_SCHEMA: list[tuple[str, int]] = [_Q_LEFT, _Q_RIGHT]


def extract_batch(batch: dict) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
    """Split a WDS batch into model inputs and regression targets.

    Reads ``batch["actions"]``, a mapping from field name to a tensor that is
    indexed here as ``[:, step]`` (assumed [B, T=2, dim] per WDS_CONFIG).

    Returns:
        ``(inputs, targets)`` where ``inputs`` holds every INPUT_SCHEMA field
        at step 0 and ``targets`` holds every OUTPUT_SCHEMA field at step 1.
    """
    window = batch["actions"]

    # Step 0 is what the network sees, step 1 is what it must predict.
    model_inputs: dict[str, torch.Tensor] = {
        name: window[name][:, 0] for name, _ in INPUT_SCHEMA
    }
    model_targets: dict[str, torch.Tensor] = {
        name: window[name][:, 1] for name, _ in OUTPUT_SCHEMA
    }
    return model_inputs, model_targets


class QPNet(nn.Module):
    """QP network with dict-based I/O.

    Stores input/output key names and dimensions so the checkpoint is
    self-describing. Use torch.jit.script to save a self-contained
    checkpoint loadable without this source file.

    The network is a 4-layer ReLU MLP: inputs are concatenated along the last
    axis in ``input_keys`` order, and the output vector is sliced back into
    per-key tensors in ``output_keys`` order.
    """

    def __init__(
        self,
        input_keys: list[str],
        input_dims: list[int],
        output_keys: list[str],
        output_dims: list[int],
        hidden_dim: int = 256,
    ):
        """Build the MLP for the given schema.

        Args:
            input_keys: Names of the input tensors, in concatenation order.
            input_dims: Last-axis size of each input, parallel to ``input_keys``.
            output_keys: Names of the output tensors, in slicing order.
            output_dims: Last-axis size of each output, parallel to ``output_keys``.
            hidden_dim: Width of the three hidden layers.

        Raises:
            ValueError: If a keys list and its dims list differ in length.
        """
        super().__init__()
        # A length mismatch would otherwise surface much later as a shape
        # error, an IndexError in forward(), or silently dropped outputs.
        if len(input_keys) != len(input_dims):
            raise ValueError(
                f"input_keys ({len(input_keys)}) and input_dims ({len(input_dims)}) must have the same length"
            )
        if len(output_keys) != len(output_dims):
            raise ValueError(
                f"output_keys ({len(output_keys)}) and output_dims ({len(output_dims)}) must have the same length"
            )

        # Store schema as plain lists (TorchScript-compatible). Copies are
        # taken so later mutation of the caller's lists cannot change it.
        self.input_keys = list(input_keys)
        self.input_dims = list(input_dims)
        self.output_keys = list(output_keys)
        self.output_dims = list(output_dims)

        input_dim = sum(self.input_dims)
        output_dim = sum(self.output_dims)

        # Layer order and count determine the state_dict keys (net.0 ... net.6);
        # keep them stable so existing checkpoints remain loadable.
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Run the MLP on a dict of inputs and return a dict of outputs.

        Args:
            inputs: Must contain every key in ``self.input_keys``; extra keys
                are ignored.

        Returns:
            One tensor per entry of ``self.output_keys``, each a slice of the
            network output along the last axis.
        """
        # Concat inputs in schema order. Explicit loops (rather than
        # comprehensions) keep this method simple for torch.jit.script.
        parts: list[torch.Tensor] = []
        for key in self.input_keys:
            parts.append(inputs[key])
        x = torch.cat(parts, dim=-1)

        y = self.net(x)

        # Split outputs by schema
        outputs: dict[str, torch.Tensor] = {}
        offset: int = 0
        for i, key in enumerate(self.output_keys):
            dim = self.output_dims[i]
            outputs[key] = y[..., offset:offset + dim]
            offset += dim
        return outputs


def build_model() -> QPNet:
    """Instantiate a QPNet whose I/O layout follows INPUT_SCHEMA / OUTPUT_SCHEMA."""
    in_names: list[str] = []
    in_widths: list[int] = []
    for name, width in INPUT_SCHEMA:
        in_names.append(name)
        in_widths.append(width)

    out_names: list[str] = []
    out_widths: list[int] = []
    for name, width in OUTPUT_SCHEMA:
        out_names.append(name)
        out_widths.append(width)

    return QPNet(
        input_keys=in_names,
        input_dims=in_widths,
        output_keys=out_names,
        output_dims=out_widths,
    )


def split_manifest(manifest_path: str, val_frac: float, seed: int = 42) -> tuple[list[dict], list[dict]]:
    """Load the manifest and partition its entries into train and val lists.

    Entries are permuted with a generator seeded by ``seed``, so the split is
    reproducible. The first ``max(1, int(len(manifest) * val_frac))`` permuted
    entries become the validation set -- i.e. at least one entry is always
    held out -- and the remainder become the training set.

    Returns:
        ``(train_entries, val_entries)``.
    """
    entries = _load_manifest(manifest_path)
    total = len(entries)

    order = np.random.default_rng(seed).permutation(total)
    val_count = max(1, int(total * val_frac))

    val_entries = [entries[idx] for idx in order[:val_count]]
    train_entries = [entries[idx] for idx in order[val_count:]]

    # Per-split sequence totals are reported only; they do not affect the split.
    train_samples = sum(entry["num_sequences"] for entry in train_entries)
    val_samples = sum(entry["num_sequences"] for entry in val_entries)
    summary = (
        f"Split: {len(train_entries)} train shards ({train_samples} samples), "
        f"{len(val_entries)} val shards ({val_samples} samples)"
    )
    print(summary)
    return train_entries, val_entries


def make_dataset(config: TrainConfig, manifest_entries: list[dict], shuffle: bool) -> WDSDataset:
    """Create a WDSDataset over ``manifest_entries`` using the shared WDS_CONFIG.

    Batch size, worker count, shuffle buffer and manifest path come from
    ``config``; ``shuffle`` is passed through unchanged.
    """
    dataset_kwargs = {
        "manifest_path": config.manifest_path,
        "batch_size": config.batch_size,
        "num_workers": config.num_workers,
        "shuffle": shuffle,
        "shuffle_buffer": config.shuffle_buffer,
        "config": WDS_CONFIG,
        "manifest_entries": manifest_entries,
    }
    return WDSDataset(**dataset_kwargs)


def dict_mse_loss(pred: dict[str, torch.Tensor], target: dict[str, torch.Tensor]) -> torch.Tensor:
    """Mean squared error over every field present in ``target``.

    Fields are concatenated along the last axis in sorted-key order before
    the mean is taken, so each scalar element is weighted equally regardless
    of which field it belongs to. Keys that appear only in ``pred`` are ignored.
    """
    ordered_keys = sorted(target.keys())
    predicted: list[torch.Tensor] = []
    expected: list[torch.Tensor] = []
    for name in ordered_keys:
        predicted.append(pred[name])
    for name in ordered_keys:
        expected.append(target[name])
    return nn.functional.mse_loss(torch.cat(predicted, dim=-1), torch.cat(expected, dim=-1))


@torch.no_grad()
def validate(model: nn.Module, config: TrainConfig, val_entries: list[dict], device: torch.device) -> float:
    """Compute the mean per-batch MSE of ``model`` over the validation shards.

    The model is put in eval mode for the duration of the pass and is then
    returned to whatever mode it was in on entry, even if iteration raises.

    Args:
        model: Network mapping an input dict to an output dict.
        config: Supplies the dataset settings used by ``make_dataset``.
        val_entries: Manifest entries to validate on.
        device: Device the batch tensors are moved to.

    Returns:
        The unweighted average of the per-batch losses, or 0.0 when the
        dataset yields no batches.
    """
    # Remember the caller's mode instead of assuming it was training.
    was_training = model.training
    model.eval()
    total_loss = 0.0
    n = 0
    try:
        val_dataset = make_dataset(config, val_entries, shuffle=False)
        for batch in val_dataset:
            inputs, targets = extract_batch(batch)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            targets = {k: v.to(device) for k, v in targets.items()}
            pred = model(inputs)
            total_loss += dict_mse_loss(pred, targets).item()
            n += 1
    finally:
        model.train(was_training)
    return total_loss / max(n, 1)


def _load_weights(model: nn.Module, state: dict, source: str) -> None:
    """Copy ``state`` into ``model`` non-strictly and report any key mismatch."""
    result = model.load_state_dict(state, strict=False)
    # strict=False tolerates mismatches; surface them so a checkpoint that
    # matched nothing is not silently reported as "loaded".
    if result.missing_keys:
        print(f"  WARNING: {source} is missing keys: {result.missing_keys}")
    if result.unexpected_keys:
        print(f"  WARNING: {source} has unexpected keys: {result.unexpected_keys}")


def _restore_checkpoint(
    resume_path: Path,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> int:
    """Restore weights (and optimizer/epoch when available); return the start epoch."""
    print(f"Resuming from {resume_path}")
    start_epoch = 0
    loaded_weights = False
    optimizer_restored = False

    # Try JIT first
    try:
        loaded = torch.jit.load(str(resume_path), map_location=device)
        jit_sd = loaded.state_dict()
        print(f"  JIT state_dict keys: {list(jit_sd.keys())[:5]}...")
        _load_weights(model, jit_sd, "TorchScript checkpoint")
        loaded_weights = True
        print("  loaded weights from TorchScript")
    except Exception as e:
        # Deliberately broad: any failure here means "not a usable TorchScript
        # archive", and the plain-checkpoint path below is tried instead.
        print(f"  JIT load failed ({e}), trying plain checkpoint...")

    # Fall back to plain state_dict
    if not loaded_weights:
        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # resume from checkpoints you trust.
        ckpt = torch.load(resume_path, map_location=device, weights_only=False)
        if isinstance(ckpt, dict) and "model" in ckpt:
            state = ckpt["model"]
            print(f"  plain ckpt keys: {list(state.keys())[:5]}...")
            _load_weights(model, state, "plain checkpoint")
            if "optimizer" in ckpt:
                optimizer.load_state_dict(ckpt["optimizer"])
                optimizer_restored = True
            if "epoch" in ckpt:
                start_epoch = ckpt["epoch"] + 1
        else:
            # Raw state_dict
            state = ckpt if isinstance(ckpt, dict) else ckpt.state_dict()
            print(f"  raw state_dict keys: {list(state.keys())[:5]}...")
            _load_weights(model, state, "raw state_dict")
        print("  loaded weights from plain checkpoint")

    # Optimizer / epoch state (saved alongside JIT checkpoints)
    opt_path = resume_path.parent / (resume_path.stem + "_optim.pt")
    if opt_path.exists():
        # NOTE(security): same trust requirement as the checkpoint above.
        opt_ckpt = torch.load(opt_path, map_location=device, weights_only=False)
        optimizer.load_state_dict(opt_ckpt["optimizer"])
        if "epoch" in opt_ckpt:
            start_epoch = opt_ckpt["epoch"] + 1
        print(f"  loaded optimizer from {opt_path.name}")
    elif not optimizer_restored:
        print(f"  no optimizer state found (looked for {opt_path.name}); optimizer starts fresh")

    print(f"  starting at epoch {start_epoch}")
    return start_epoch


def _save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    ckpt_dir: Path,
) -> None:
    """Write the TorchScript model and its optimizer sidecar for ``epoch``."""
    stem = f"epoch_{epoch:04d}"

    # Save model as TorchScript (self-contained: class + weights + schema)
    scripted = torch.jit.script(model)
    scripted.save(str(ckpt_dir / f"{stem}.pt"))
    scripted.save(str(ckpt_dir / "latest.pt"))

    # Save optimizer state separately (not scriptable). Resume looks for
    # "<checkpoint stem>_optim.pt", so every model file gets a matching
    # sidecar -- not just latest.pt.
    optim_state = {"optimizer": optimizer.state_dict(), "epoch": epoch}
    torch.save(optim_state, ckpt_dir / f"{stem}_optim.pt")
    torch.save(optim_state, ckpt_dir / "latest_optim.pt")
    print(f"Saved checkpoint: epoch_{epoch:04d}.pt")


def main(config: TrainConfig):
    """Train QPNet according to ``config``.

    Steps: optionally start a wandb run, split the manifest into train/val
    shards, build the model and Adam optimizer, resume from
    ``config.resume_from`` (or ``<checkpoint_dir>/latest.pt``) when that file
    exists, then run the epoch loop with per-epoch validation and periodic
    checkpointing. Each checkpoint consists of a TorchScript model file plus
    a ``*_optim.pt`` sidecar holding the optimizer state and epoch index.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if config.wandb_enabled:
        wandb.init(project="sim-improvement", name=config.exp_name)

    ckpt_dir = Path(config.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # Split shards into train/val (deterministic, shard-level)
    train_entries, val_entries = split_manifest(config.manifest_path, config.val_frac)

    # Model
    model = build_model().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    start_epoch = 0

    # Resume from checkpoint
    print(f"Resume from: {config.resume_from}")
    resume_path = Path(config.resume_from) if config.resume_from else ckpt_dir / "latest.pt"
    if resume_path.exists():
        start_epoch = _restore_checkpoint(resume_path, model, optimizer, device)
    else:
        print(f"No checkpoint found at {resume_path}, training from scratch")

    print(f"QPNet: {sum(p.numel() for p in model.parameters())} params")
    print(f"Inputs:  {list(zip(model.input_keys, model.input_dims))}")
    print(f"Outputs: {list(zip(model.output_keys, model.output_dims))}")

    for epoch in range(start_epoch, config.num_epochs):
        # Recreate dataset each epoch (WDS iterates shards once)
        dataset = make_dataset(config, train_entries, shuffle=True)

        epoch_loss = 0.0
        num_batches = 0

        for batch in dataset:
            inputs, targets = extract_batch(batch)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            targets = {k: v.to(device) for k, v in targets.items()}

            pred = model(inputs)
            loss = dict_mse_loss(pred, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once; each .item() call forces a device sync.
            loss_value = loss.item()
            epoch_loss += loss_value
            num_batches += 1

            # A non-positive interval disables step logging instead of
            # raising ZeroDivisionError on the modulo.
            if config.log_interval > 0 and num_batches % config.log_interval == 0:
                avg = epoch_loss / num_batches
                print(f"  epoch {epoch} | batch {num_batches} | loss {loss_value:.6f} | avg {avg:.6f}")
                if config.wandb_enabled:
                    wandb.log({"train/loss": loss_value, "train/avg_loss": avg})

        avg_loss = epoch_loss / max(num_batches, 1)

        # Validation
        val_loss = validate(model, config, val_entries, device)
        print(f"Epoch {epoch}: train_loss={avg_loss:.6f} | val_loss={val_loss:.6f} ({num_batches} batches)")
        if config.wandb_enabled:
            wandb.log({"epoch": epoch, "train/epoch_loss": avg_loss, "val/loss": val_loss})

        # Periodic saves are skipped for a non-positive period; the final
        # epoch is always saved.
        is_periodic = config.save_every_n_epochs > 0 and (epoch + 1) % config.save_every_n_epochs == 0
        is_last = epoch == config.num_epochs - 1
        if is_periodic or is_last:
            _save_checkpoint(model, optimizer, epoch, ckpt_dir)


if __name__ == "__main__":
    # Parse TrainConfig from the command line and start training.
    main(tyro.cli(TrainConfig))
