"""
Image-to-video rollout without action control: same as replay but feed zero actions
so the model generates video from the initial frame(s) only.
"""
import numpy as np
import torch
import sys
import os
import datetime
import json

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from accelerate import Accelerator
from decord import VideoReader, cpu
import mediapy

# Reuse the same agent and config as replay
from config import wm_args
from argparse import ArgumentParser

# Import agent class from replay script
# HACK: executes the replay script's source text (everything before its
# `if __name__ == "__main__"` guard) to pull its module-level definitions --
# including the `agent` class used below -- into this module's namespace.
# NOTE(review): a plain `import rollout_replay_traj` would not run the guarded
# main either and would be safer than exec-ing file contents read from disk;
# confirm the replay script imports cleanly and switch to a normal import.
exec(open(os.path.join(os.path.dirname(__file__), "rollout_replay_traj.py")).read().split("if __name__")[0])


if __name__ == "__main__":
    parser = ArgumentParser(description="Ctrl-World image-to-video rollout (zero actions)")
    parser.add_argument("--svd_model_path", type=str, default=None)
    parser.add_argument("--clip_model_path", type=str, default=None)
    parser.add_argument("--ckpt_path", type=str, default=None)
    parser.add_argument("--dataset_root_path", type=str, default=None)
    parser.add_argument("--dataset_meta_info_path", type=str, default=None)
    parser.add_argument("--dataset_names", type=str, default=None)
    parser.add_argument("--data_stat_path", type=str, default=None)
    parser.add_argument("--task_type", type=str, default="replay")
    parser.add_argument("--traj_id", type=str, default="899", help="Trajectory id from val set (e.g. 899)")
    parser.add_argument("--start_idx", type=int, default=8, help="Start frame index")
    parser.add_argument("--interact_num", type=int, default=6, help="Number of rollout steps (fewer = faster)")
    args_new = parser.parse_args()

    # Base config for the task type, then overlay every CLI flag that was
    # actually provided (unprovided optional flags stay None and are skipped).
    # --traj_id/--start_idx/--interact_num have non-None defaults, so they
    # always pass through this loop; the old extra interact_num re-assignment
    # was dead code and has been removed.
    args = wm_args(task_type=args_new.task_type)
    for key, value in vars(args_new).items():
        if value is not None:
            setattr(args, key, value)

    Agent = agent(args)
    interact_num = args.interact_num
    pred_step = args.pred_step
    val_id_i = args_new.traj_id
    start_idx_i = args_new.start_idx

    print("Image-to-video rollout (ZERO actions)")
    print(f"Traj {val_id_i}, start_idx {start_idx_i}, interact_num {interact_num}")

    # Load one trajectory: only its initial frames/latents are used for
    # conditioning; the recorded actions are deliberately ignored below.
    eef_gt, joint_pos_gt, video_dict, video_latents, instruction = Agent.get_traj_info(
        val_id_i, start_idx=start_idx_i, steps=int(pred_step * interact_num + 8)
    )
    text_i = instruction if Agent.args.text_cond else None
    print("Instruction (for cond only):", instruction)

    # Warm up the history buffers by repeating the first latent/pose so the
    # negative history_idx offsets used below always index valid entries.
    his_cond = []
    his_eef = []
    first_latent = torch.cat([v[0] for v in video_latents], dim=1).unsqueeze(0)
    for _ in range(Agent.args.num_history * 4):
        his_cond.append(first_latent)
        his_eef.append(eef_gt[0:1])

    video_to_save = []
    for i in range(interact_num):
        # Consecutive windows overlap by one frame (the last predicted frame
        # seeds the next window), hence the (pred_step - 1) stride.
        start_id = int(i * (pred_step - 1))
        end_id = start_id + pred_step
        video_latent_true = [v[start_id:end_id] for v in video_latents]

        # Zero actions instead of the recorded trajectory: the model must
        # generate motion from the visual/text conditioning alone.
        cartesian_pose = np.zeros((pred_step, 7), dtype=np.float32)
        # History offsets (same scheme as the replay script); negative
        # indices select entries counted from the end of the growing buffers.
        history_idx = [0, 0, -8, -6, -4, -2]
        his_pose = np.concatenate([his_eef[idx] for idx in history_idx], axis=0)
        action_cond = np.concatenate([his_pose, cartesian_pose], axis=0)

        his_cond_input = torch.cat([his_cond[idx] for idx in history_idx], dim=0).unsqueeze(0)
        current_latent = his_cond[-1]

        print(f"Step {i + 1}/{interact_num} (action = zeros)")
        videos_cat, true_videos, video_dict_pred, predicted_latents = Agent.forward_wm(
            action_cond, video_latent_true, current_latent, his_cond=his_cond_input, text=text_i
        )

        # Roll the history forward with the last predicted latent and the
        # (zero) pose that produced it.
        his_eef.append(cartesian_pose[pred_step - 1 : pred_step])
        his_cond.append(torch.cat([v[pred_step - 1] for v in predicted_latents], dim=1).unsqueeze(0))
        # Drop the overlapping last frame on all but the final chunk so the
        # concatenated video contains no duplicated frames.
        if i == interact_num - 1:
            video_to_save.append(videos_cat)
        else:
            video_to_save.append(videos_cat[: pred_step - 1])

    video = np.concatenate(video_to_save, axis=0)
    save_dir = getattr(args, "save_dir", "synthetic_traj")
    os.makedirs(f"{save_dir}/Rollouts_img2vid_no_action/video", exist_ok=True)
    # Timestamp keeps repeated runs from clobbering each other's output
    # (renamed from `uuid`, which shadowed the stdlib module name).
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{save_dir}/Rollouts_img2vid_no_action/video/time_{timestamp}_traj_{val_id_i}_no_action.mp4"
    mediapy.write_video(filename, video, fps=4)
    # BUG FIX: originally printed the literal "(unknown)" instead of the path.
    print(f"Saved: {filename}")
