"""
Action-conditioned replay that saves only the first camera view (output width is 1/3:
one view instead of three side by side). Uses the same 3-view forward pass.
"""
import numpy as np
import torch
import sys
import os
import datetime

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import wm_args
from argparse import ArgumentParser
import mediapy

# Pull every definition from the sibling replay script (everything that precedes its
# `if __name__` guard) into this module's namespace, without running its entry point.
# NOTE(review): this executes another file's source verbatim — keep that file trusted.
_replay_script_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "rollout_replay_traj.py"
)
# Context manager so the file handle is closed deterministically.
with open(_replay_script_path, encoding="utf-8") as _replay_script:
    _replay_source = _replay_script.read().split("if __name__")[0]
# Compiling with the real path makes tracebacks point at rollout_replay_traj.py
# instead of an anonymous "<string>".
exec(compile(_replay_source, _replay_script_path, "exec"))


if __name__ == "__main__":
    # Entry point: replay recorded trajectories through the world model, conditioning on
    # the recorded end-effector poses, and save a video of the first camera view only.
    parser = ArgumentParser(description="Ctrl-World replay, first view only")
    parser.add_argument("--svd_model_path", type=str, default=None)
    parser.add_argument("--clip_model_path", type=str, default=None)
    parser.add_argument("--ckpt_path", type=str, default=None)
    parser.add_argument("--dataset_root_path", type=str, default=None)
    parser.add_argument("--dataset_meta_info_path", type=str, default=None)
    parser.add_argument("--dataset_names", type=str, default=None)
    parser.add_argument("--data_stat_path", type=str, default=None)
    parser.add_argument("--task_type", type=str, default="replay")
    parser.add_argument("--traj_id", type=str, default=None, help="Override: single traj id (e.g. 899)")
    parser.add_argument("--start_idx", type=int, default=None)
    parser.add_argument("--interact_num", type=int, default=None)
    args_new = parser.parse_args()

    # Start from the task's config defaults; only CLI options that were actually
    # supplied (non-None) override them.
    args = wm_args(task_type=args_new.task_type)
    for k, v in vars(args_new).items():
        if v is not None:
            setattr(args, k, v)

    # `agent` comes from the exec'd rollout_replay_traj.py source.
    Agent = agent(args)
    interact_num = args.interact_num
    pred_step = args.pred_step
    # Not used below; kept as module-level names in case the exec'd code reads them.
    num_history = args.num_history
    num_frames = args.num_frames

    if args_new.traj_id is not None:
        # Single-trajectory override from the command line.
        val_ids = [args_new.traj_id]
        start_idxs = [args_new.start_idx if args_new.start_idx is not None else 8]
        instructions = [""]
    else:
        val_ids = args.val_id
        instructions = args.instruction
        if args_new.start_idx is not None:
            # The override loop above replaced args.start_idx with a scalar int, which
            # cannot be zipped; apply the CLI start index to every configured trajectory.
            start_idxs = [args_new.start_idx] * len(val_ids)
        else:
            start_idxs = args.start_idx

    print("Replay with actions — first view only (1/3 width)")

    # Positions in the history buffers used as conditioning for each step.
    history_idx = [0, 0, -8, -6, -4, -2]
    # Characters mapped for the filename: spaces become underscores; punctuation and
    # path separators are dropped so the instruction text cannot alter the output path.
    filename_table = str.maketrans(
        {" ": "_", ",": None, ".": None, "'": None, '"': None, "/": None, "\\": None}
    )
    save_dir = getattr(args, "save_dir", "synthetic_traj")
    video_dir = f"{save_dir}/Rollouts_replay_first_view/video"

    # The configured instruction from the zip is not used: the text condition comes
    # from the instruction returned by get_traj_info.
    for val_id_i, _, start_idx_i in zip(val_ids, instructions, start_idxs):
        eef_gt, joint_pos_gt, video_dict, video_latents, instruction = Agent.get_traj_info(
            val_id_i, start_idx=start_idx_i, steps=int(pred_step * interact_num + 8)
        )
        text_i = instruction if Agent.args.text_cond else None
        print("Instruction:", instruction)

        # Seed both history buffers with copies of the first frame's latent / pose.
        first_latent = torch.cat([v[0] for v in video_latents], dim=1).unsqueeze(0)
        seed_len = Agent.args.num_history * 4
        his_cond = [first_latent] * seed_len
        his_eef = [eef_gt[0:1]] * seed_len

        # video_latents is iterated per view above, so its length is the view count.
        num_views = len(video_latents)

        video_to_save = []
        for i in range(interact_num):
            # Consecutive windows share one frame: window i starts at i * (pred_step - 1).
            start_id = int(i * (pred_step - 1))
            end_id = start_id + pred_step
            video_latent_true = [v[start_id:end_id] for v in video_latents]
            cartesian_pose = eef_gt[start_id:end_id]
            his_pose = np.concatenate([his_eef[idx] for idx in history_idx], axis=0)
            action_cond = np.concatenate([his_pose, cartesian_pose], axis=0)
            his_cond_input = torch.cat([his_cond[idx] for idx in history_idx], dim=0).unsqueeze(0)
            current_latent = his_cond[-1]

            print(f"Step {i + 1}/{interact_num}")
            videos_cat, _, video_dict_pred, predicted_latents = Agent.forward_wm(
                action_cond, video_latent_true, current_latent, his_cond=his_cond_input, text=text_i
            )
            # Keep only the first view. The original note gives videos_cat as
            # (T, 512, 768, 3) with the views side by side along the width, i.e. a
            # 256-px first view; deriving the width keeps the crop at exactly one view
            # for other resolutions too.
            view_width = videos_cat.shape[2] // num_views
            videos_cat = videos_cat[:, :, :view_width, :].astype(np.uint8)
            # Extend the history with the last pose / predicted latent of this window.
            his_eef.append(cartesian_pose[pred_step - 1 : pred_step])
            his_cond.append(torch.cat([v[pred_step - 1] for v in predicted_latents], dim=1).unsqueeze(0))
            if i == interact_num - 1:
                video_to_save.append(videos_cat)
            else:
                # Drop the final frame of non-final windows; presumably it is the frame
                # shared with the start of the next window.
                video_to_save.append(videos_cat[: pred_step - 1])

        video = np.concatenate(video_to_save, axis=0)
        os.makedirs(video_dir, exist_ok=True)
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        text_safe = (text_i or "")[:30].translate(filename_table)
        filename = f"{video_dir}/time_{timestamp}_traj_{val_id_i}_first_view_{text_safe}.mp4"
        mediapy.write_video(filename, video, fps=4)
        print(f"Saved: {filename}")
