
from openpi.training import config as config_pi
from openpi.policies import policy_config
from openpi_client import image_tools
# from openpi.shared import download

import numpy as np


from accelerate import Accelerator
import torch
from diffusers import StableVideoDiffusionPipeline
import numpy as np
# import cv2
import torch
import torch.nn.functional as F
import torch.nn as nn
import einops
from accelerate import Accelerator
import datetime
import os
from accelerate.logging import get_logger
from tqdm.auto import tqdm
import wandb
import json
from decord import VideoReader, cpu
import swanlab
import mediapy
import sys
from scipy.spatial.transform import Rotation as R

import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.pipeline_ctrl_world import CtrlWorldDiffusionPipeline
from models.ctrl_world import CrtlWorld
from models.utils import key_board_control, get_fk_solution
    

class agent:
    """Closed-loop rollout helper pairing a pi DROID policy with the Ctrl-World model.

    Per interaction step:

    1. ``forward_policy`` feeds the current (predicted) camera frames and joint
       state to the pi policy, which returns a joint-velocity action chunk. A
       learned action adapter maps it to future joint positions and forward
       kinematics converts those into cartesian pose actions.
    2. ``forward_wm`` conditions the world model on those cartesian actions plus
       a history of latents and returns the predicted frames.
    """

    def __init__(self, args):
        """Load the pi policy, the world model, state statistics and the action adapter.

        Args:
            args: config object. Fields read here: ``ckpt_path``, ``dtype``,
                ``policy_type``, ``pi_ckpt``, ``data_stat_path`` and
                ``action_adapter``. As a side effect ``args.val_model_path`` is
                set to ``args.ckpt_path``.

        Raises:
            ValueError: if ``args.policy_type`` names no known pi policy.
        """
        args.val_model_path = args.ckpt_path
        self.args = args
        self.accelerator = Accelerator()
        self.device = self.accelerator.device
        self.dtype = args.dtype

        # Load the pi policy. 'pi05' and 'pi0fast' are tested before 'pi0'
        # because both strings contain 'pi0' as a substring.
        if 'pi05' in args.policy_type:
            config = config_pi.get_config("pi05_droid")
        elif 'pi0fast' in args.policy_type:
            config = config_pi.get_config("pi0fast_droid")
        elif 'pi0' in args.policy_type:
            config = config_pi.get_config("pi0_droid")
        else:
            raise ValueError(f"Unknown policy type: {args.policy_type}")
        self.policy = policy_config.create_trained_policy(config, args.pi_ckpt)

        # Load the Ctrl-World model.
        self.model = CrtlWorld(args)
        self.model.load_state_dict(torch.load(args.val_model_path))
        self.model.to(self.accelerator.device).to(self.dtype)
        self.model.eval()
        print("load world model success")

        # 1st/99th percentile state statistics used to normalize action conditions.
        with open(args.data_stat_path, 'r') as f:
            data_stat = json.load(f)
        self.state_p01 = np.array(data_stat['state_01'])[None, :]
        self.state_p99 = np.array(data_stat['state_99'])[None, :]

        # The official pi DROID policies output joint velocities, while Ctrl-World
        # is trained on cartesian poses, so a lightweight adapter converts the
        # joint-velocity chunk into joint positions (then FK gives poses).
        # Without it forward_policy cannot run; it checks for None explicitly.
        self.dynamics_model = None
        if args.action_adapter is not None:
            from models.action_adapter.train2 import Dynamics
            self.dynamics_model = Dynamics(action_dim=7, action_num=15, hidden_size=512).to(self.device)
            self.dynamics_model.load_state_dict(torch.load(args.action_adapter, map_location=self.device))
            self.dynamics_model.eval()

    def normalize_bound(
        self,
        data: np.ndarray,
        data_min: np.ndarray,
        data_max: np.ndarray,
        clip_min: float = -1,
        clip_max: float = 1,
        eps: float = 1e-8,
    ) -> np.ndarray:
        """Linearly map ``data`` from ``[data_min, data_max]`` to ``[-1, 1]`` and clip.

        ``eps`` guards against division by zero when ``data_min == data_max``.
        The result is clipped to ``[clip_min, clip_max]``.
        """
        ndata = 2 * (data - data_min) / (data_max - data_min + eps) - 1
        return np.clip(ndata, clip_min, clip_max)

    def get_traj_info(self, id, start_idx=0, steps=8, skip=1):
        """Load ground-truth states, joints, frames and VAE latents of one validation trajectory.

        Args:
            id: trajectory id; reads ``{val_dataset_dir}/annotation/val/{id}.json``.
            start_idx: first frame index to sample.
            steps: number of frames to sample.
            skip: stride between sampled frames.

        Frame ids beyond the end of the trajectory are clamped to the last
        frame, so the tail of a long request repeats the final frame.

        Returns:
            ``(car_action, joint_pos, video_dict, video_latent, instruction)``:
            rows of ``anno['states']`` and ``anno['joints']`` at the sampled
            frames, one channels-last frame array per camera view, the matching
            per-view VAE latents (scaled by ``vae.config.scaling_factor``), and
            ``anno['texts'][0]``.
        """
        val_dataset_dir = self.args.val_dataset_dir
        annotation_path = f"{val_dataset_dir}/annotation/val/{id}.json"
        with open(annotation_path) as f:
            anno = json.load(f)
        # Use the number of actions when present, otherwise the stored video length.
        try:
            length = len(anno['action'])
        except (KeyError, TypeError):
            length = anno["video_length"]

        frames_ids = np.arange(start_idx, start_idx + steps * skip, skip)
        frames_ids = np.minimum(frames_ids, length - 1).astype(int)
        print("Ground truth frames ids", frames_ids)

        # get action and joint pos
        instruction = anno['texts'][0]
        car_action = np.array(anno['states'])[frames_ids]
        joint_pos = np.array(anno['joints'])[frames_ids]

        # get videos from all views and encode them with the world model's VAE
        vae = self.model.pipeline.vae
        encode_batch = 32
        video_dict = []
        video_latent = []
        for view in anno['videos']:
            video_path = f"{val_dataset_dir}/{view['video_path']}"
            vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
            frames = vr.get_batch(range(length))
            # decord's NDArray exposes .asnumpy(); fall back to .numpy() for
            # backends that return a tensor instead.
            try:
                true_video = frames.asnumpy()
            except AttributeError:
                true_video = frames.numpy()
            true_video = true_video[frames_ids]
            video_dict.append(true_video)

            # (N, H, W, C) pixels in [0, 255] -> (N, C, H, W) in [-1, 1] for the VAE
            x = torch.from_numpy(true_video).to(self.dtype).to(self.device)
            x = x.permute(0, 3, 1, 2) / 255.0 * 2 - 1
            with torch.no_grad():
                latents = [
                    vae.encode(x[i:i + encode_batch]).latent_dist.sample().mul_(vae.config.scaling_factor)
                    for i in range(0, len(x), encode_batch)
                ]
            video_latent.append(torch.cat(latents, dim=0))

        return car_action, joint_pos, video_dict, video_latent, instruction

    def _decode_latents(self, latents):
        """Decode (bsz, frames, c, h, w) VAE latents to uint8 frames of shape (bsz, frames, H, W, 3).

        Decoding runs in chunks of ``args.decode_chunk_size`` under
        ``torch.no_grad()``; the frames are only used for visualization and as
        the next policy observation, so no autograd graph is needed.
        """
        vae = self.model.pipeline.vae
        chunk_size = self.args.decode_chunk_size
        bsz, frame_num = latents.shape[:2]
        flat = latents.flatten(0, 1)
        decoded = []
        with torch.no_grad():
            for i in range(0, flat.shape[0], chunk_size):
                chunk = flat[i:i + chunk_size] / vae.config.scaling_factor
                decoded.append(vae.decode(chunk, num_frames=chunk.shape[0]).sample)
        frames = torch.cat(decoded, dim=0)
        frames = frames.reshape(bsz, frame_num, *frames.shape[1:])
        # [-1, 1] -> [0, 255]; the uint8 cast truncates like the original pipeline.
        frames = (frames / 2.0 + 0.5).clamp(0, 1) * 255
        return frames.detach().to(torch.float32).cpu().numpy().transpose(0, 1, 3, 4, 2).astype(np.uint8)

    def forward_wm(self, action_cond, video_latent_true, video_latent_cond, his_cond=None, text=None):
        """Predict future frames with the world model and decode them next to ground truth.

        Args:
            action_cond: (num_history + num_frames, action_dim) cartesian actions,
                un-normalized; normalized here with the p01/p99 state stats.
            video_latent_true: list of per-view ground-truth latents
                (frames, c, h, w), decoded for visualization only.
            video_latent_cond: current latent; trailing shape must be
                (4, 72, 40), i.e. the three views stacked along height.
            his_cond: history latents forwarded to the pipeline.
            text: instruction passed to the action encoder when not None.

        Returns:
            ``(videos_cat, true_video, videos, latents)``: ``true_video`` and
            ``videos`` are uint8 arrays (views, frames, H, W, 3); ``videos_cat``
            puts GT above prediction and tiles views horizontally, giving
            (frames, 2H, views*W, 3); ``latents`` are the predicted latents split
            per view, (views, frames, c, h, w).
        """
        args = self.args
        image_cond = video_latent_cond

        # action conditions are normalized to [-1, 1] like during training
        action_cond = self.normalize_bound(action_cond, self.state_p01, self.state_p99, clip_min=-1, clip_max=1)
        action_cond = torch.tensor(action_cond).unsqueeze(0).to(self.device).to(self.dtype)
        assert image_cond.shape[1:] == (4, 72, 40)
        assert action_cond.shape[1:] == (args.num_frames + args.num_history, args.action_dim)

        # predict future frames
        pipeline = self.model.pipeline
        with torch.no_grad():
            if text is not None:
                text_token = self.model.action_encoder(action_cond, text, self.model.tokenizer, self.model.text_encoder)
            else:
                text_token = self.model.action_encoder(action_cond)
            _, latents = CtrlWorldDiffusionPipeline.__call__(
                pipeline,
                image=image_cond,
                text=text_token,
                width=args.width,
                height=int(args.height * 3),
                num_frames=args.num_frames,
                history=his_cond,
                num_inference_steps=args.num_inference_steps,
                decode_chunk_size=args.decode_chunk_size,
                max_guidance_scale=args.guidance_scale,
                fps=args.fps,
                motion_bucket_id=args.motion_bucket_id,
                mask=None,
                output_type='latent',
                return_dict=False,
                frame_level_cond=True,
            )
        # split the three height-stacked views into separate batch entries
        latents = einops.rearrange(latents, 'b f c (m h) (n w) -> (b m n) f c h w', m=3, n=1)

        true_video = self._decode_latents(torch.stack(video_latent_true, dim=0))
        videos = self._decode_latents(latents)

        # GT above prediction (height axis), then views side by side (width axis)
        videos_cat = np.concatenate([true_video, videos], axis=-3)
        videos_cat = np.concatenate(list(videos_cat), axis=-2).astype(np.uint8)

        return videos_cat, true_video, videos, latents

    @staticmethod
    def _resize_for_policy(image):
        """Bilinearly resize a (192, 320, 3) uint8 frame to (180, 320, 3) uint8."""
        image = torch.from_numpy(image).to(torch.uint8)
        assert image.shape == (192, 320, 3), "Policy image shape should be (192, 320, 3), got {}".format(image.shape)
        image = torch.nn.functional.interpolate(
            image.permute(2, 0, 1).unsqueeze(0).float(), size=(180, 320), mode='bilinear', align_corners=False
        )
        return image.squeeze(0).permute(1, 2, 0).to(torch.uint8).numpy()

    def forward_policy(self, videos, state, joints, text, time_step=1):
        """Query the pi policy and convert its joint-velocity chunk into cartesian actions.

        Args:
            videos: per-view frames of the current step; views 1 and 2 are fed
                as the exterior and wrist cameras and must be (192, 320, 3).
            state: unused, kept for interface compatibility.
            joints: current joint vector; ``[:7]`` joint positions, the
                remaining entry the gripper position.
            text: language instruction for the policy.
            time_step: unused.

        Returns:
            ``(policy_in_out, joint_pos_skip, state_fk_skip)``: a dict with the
            first ``policy_skip_step * (pred_step - 1)`` joint positions, joint
            velocities and FK states; every ``policy_skip_step``-th joint
            position with the gripper appended; and every ``policy_skip_step``-th
            FK state (xyz, euler 'xyz', gripper), both truncated to ``pred_step`` rows.

        Raises:
            RuntimeError: if no action adapter was loaded (``args.action_adapter`` was None).
        """
        # Fail before the (expensive) policy call if the adapter is missing.
        if getattr(self, "dynamics_model", None) is None:
            raise RuntimeError(
                "forward_policy needs the joint-velocity -> cartesian action adapter; "
                "set args.action_adapter to its checkpoint path."
            )

        # inference policy
        image1 = self._resize_for_policy(videos[1])
        image2 = self._resize_for_policy(videos[2])
        example = {
            "observation/exterior_image_1_left": image_tools.resize_with_pad(image1, 224, 224),
            "observation/wrist_image_left": image_tools.resize_with_pad(image2, 224, 224),
            "observation/joint_position": joints[:7],
            "observation/gripper_position": joints[-1:],
            "prompt": text,
        }
        action_chunk = self.policy.infer(example)["actions"]  # joint velocity + gripper position

        # action adapter
        current_joint = joints[None, :7]    # (1, 7)
        current_gripper = joints[None, 7:]  # (1, 1) assuming an 8-d joint vector
        # The adapter consumes 15 steps; chunks of non-pi05 policies only have
        # 10 usable steps, so the last one is repeated.
        if 'pi05' in self.args.policy_type:
            idx = list(range(15))
        else:
            idx = list(range(10)) + [9] * 5
        joint_vel = action_chunk[:, :7][idx]    # (15, 7)
        gripper_pos = action_chunk[:, 7:][idx]  # (15, 1)
        gripper_pos = np.clip(gripper_pos, 0, self.args.gripper_max)

        # future joint positions from the adapter, prefixed by the current state
        with torch.no_grad():
            joint_pos = self.dynamics_model(current_joint, joint_vel, None, training=False)
        joint_pos = np.concatenate([current_joint, joint_pos], axis=0)[:15]          # (15, 7)
        gripper_pos = np.concatenate([current_gripper, gripper_pos], axis=0)[:15]    # (15, 1)

        # forward kinematics -> (xyz, euler 'xyz', gripper) per step
        z_min = self.args.z_min
        state_fk = []
        for joint_i, gripper_i in zip(joint_pos, gripper_pos):
            fk = get_fk_solution(joint_i[:7])
            xyz = fk[:3, 3].copy()
            # clip z axis to avoid collision with the table
            xyz[2] = np.clip(xyz[2], z_min, None)
            euler = R.from_matrix(fk[:3, :3]).as_euler('xyz')
            state_fk.append(np.concatenate([xyz, euler, gripper_i], axis=0))
        state_fk = np.array(state_fk)  # (15, 7)

        # prepare output
        skip = self.args.policy_skip_step
        pred_step = self.args.pred_step
        valid_num = int(skip * (pred_step - 1))
        policy_in_out = {
            'joint_pos': joint_pos[:valid_num],
            'joint_vel': joint_vel[:valid_num],
            'state_fk': state_fk[:valid_num],
        }
        state_fk_skip = state_fk[::skip][:pred_step]
        joint_pos_skip = joint_pos[::skip][:pred_step]
        # append the gripper column so joint actions are 8-d
        joint_pos_skip = np.concatenate([joint_pos_skip, state_fk_skip[:, -1:]], axis=-1)

        return policy_in_out, joint_pos_skip, state_fk_skip

    
def make_text_id(text, max_len=40):
    """Return a filename-safe slug of ``text``.

    Spaces become ``_``; commas, periods, single and double quotes are
    dropped; the result is truncated to ``max_len`` characters.
    """
    table = str.maketrans({' ': '_', ',': None, '.': None, "'": None, '"': None})
    return text.translate(table)[:max_len]


def main():
    """Roll the pi policy out inside the Ctrl-World model and save a video + info per trajectory.

    For every ``(val_id, instruction, start_idx)`` triple of the config, policy
    and world model alternate for ``args.interact_num`` steps starting from the
    ground-truth first frame. The GT/prediction comparison video is written to
    ``{save_dir}/{task_name}/video/`` and the per-step policy outputs to
    ``{save_dir}/{task_name}/info/`` as JSON.

    Raises:
        ValueError: if ``args.interact_num`` is smaller than 1 (nothing to roll out).
    """
    from argparse import ArgumentParser

    from config_eval import wm_args

    parser = ArgumentParser()
    parser.add_argument('--svd_model_path', type=str, default=None)
    parser.add_argument('--clip_model_path', type=str, default=None)
    parser.add_argument('--ckpt_path', type=str, default=None)
    parser.add_argument('--dataset_root_path', type=str, default=None)
    parser.add_argument('--dataset_meta_info_path', type=str, default=None)
    parser.add_argument('--dataset_names', type=str, default=None)
    parser.add_argument('--task_type', type=str, default=None)
    parser.add_argument('--pi_ckpt', type=str, default='/cephfs/shared/llm/openpi/openpi-assets-preview/checkpoints/pi05_droid')
    cli_args = parser.parse_args()

    args = wm_args(task_type=cli_args.task_type)
    # Overlay every CLI flag that was given. NOTE: --pi_ckpt has a non-None
    # default, so it always overrides whatever wm_args configured.
    for key, value in vars(cli_args).items():
        if value is not None:
            setattr(args, key, value)

    interact_num = args.interact_num
    pred_step = args.pred_step
    history_idx = args.history_idx  # e.g. [0, 0, -12, -9, -6, -3]
    if interact_num < 1:
        raise ValueError(f"interact_num must be >= 1 to produce a rollout, got {interact_num}")

    # create agent
    rollout_agent = agent(args)

    # run one rollout per configured trajectory
    for val_id_i, text_i, start_idx_i in zip(args.val_id, args.instruction, args.start_idx):

        # get initial state and ground truth
        eef_gt, joint_pos_gt, video_dict, video_latents, _ = rollout_agent.get_traj_info(
            val_id_i, start_idx=start_idx_i, steps=int(pred_step * interact_num + 8))
        print("text_i:", text_i, "eef pose at t=0", eef_gt[0], "joint at t=0", joint_pos_gt[0])

        # Initialize all history buffers with the first ground-truth step so
        # that negative history indices are valid from the start.
        video_to_save, info_to_save = [], []
        first_latent = torch.cat([v[0] for v in video_latents], dim=1).unsqueeze(0)  # (1, 4, 72, 40)
        assert first_latent.shape == (1, 4, 72, 40), f"Expected first_latent shape (1, 4, 72, 40), got {first_latent.shape}"
        history_len = args.num_history * 4
        his_cond = [first_latent] * history_len
        his_joint = [joint_pos_gt[0:1]] * history_len  # (1, joint_dim) each
        his_eef = [eef_gt[0:1]] * history_len          # (1, 7) each
        video_dict_pred = [v[0:1] for v in video_dict]

        # start rollout
        for i in range(interact_num):
            # consecutive windows share one frame (the last frame of the previous step)
            start_id = int(i * (pred_step - 1))
            end_id = start_id + pred_step
            video_latent_true = [v[start_id:end_id] for v in video_latents]

            print("################ policy forward ####################")
            current_joint = his_joint[-1][0]  # (joint_dim,)
            current_pose = his_eef[-1][0]     # (7,)
            current_obs = [v[-1] for v in video_dict_pred]
            policy_in_out, joint_pos, cartesian_pose = rollout_agent.forward_policy(
                current_obs, current_pose, current_joint, text=text_i)
            print("cartesian space action", cartesian_pose[0])  # output xyz and gripper for debug
            print("cartesian space action", cartesian_pose[-1])  # output xyz and gripper for debug

            print("################ world model forward ################")
            print(f'task: {text_i}, traj_id: {val_id_i}, interact step: {i}/{interact_num}')
            action_cond = np.concatenate([his_eef[idx] for idx in history_idx], axis=0)
            action_cond = np.concatenate([action_cond, cartesian_pose], axis=0)  # (num_history+num_frames, 7)
            his_latent = torch.cat([his_cond[idx] for idx in history_idx], dim=0).unsqueeze(0)
            current_latent = his_cond[-1]  # (1, 4, 72, 40)
            videos_cat, _, video_dict_pred, predict_latents = rollout_agent.forward_wm(
                action_cond, video_latent_true, current_latent, his_cond=his_latent,
                text=text_i if args.text_cond else None)

            print("################ record information ################")
            # push the last predicted step to the history buffers
            his_joint.append(joint_pos[pred_step - 1][None, :])   # (1, 8)
            his_eef.append(cartesian_pose[pred_step - 1][None, :])  # (1, 7)
            his_cond.append(torch.cat([v[pred_step - 1] for v in predict_latents], dim=1).unsqueeze(0))  # (1, 4, 72, 40)
            # drop the last frame: it is the first frame of the next window
            video_to_save.append(videos_cat[:pred_step - 1])
            info_to_save.append(policy_in_out)

        # save rollout video and info with parameters
        print("##########################################################################")
        video = np.concatenate(video_to_save, axis=0)
        text_id = make_text_id(text_i)
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        # NOTE(review): the video name encodes policy_skip_step while the info
        # name encodes pred_step; kept as-is in case downstream tooling parses them.
        filename_video = f"{args.save_dir}/{args.task_name}/video/{args.task_type}_time_{timestamp}_traj_{val_id_i}_{start_idx_i}_{args.policy_skip_step}_{text_id}.mp4"
        os.makedirs(os.path.dirname(filename_video), exist_ok=True)
        mediapy.write_video(filename_video, video, fps=4)
        print(f"Saving video to {filename_video}")

        info = {'success': 1, 'start_idx': 0, 'end_idx': video.shape[0] - 1, 'instructions': text_i}
        for key in info_to_save[0]:
            info[key] = [row for step_info in info_to_save for row in step_info[key].tolist()]
        filename_info = f"{args.save_dir}/{args.task_name}/info/{args.task_type}_time_{timestamp}_traj_{val_id_i}_{start_idx_i}_{pred_step}_{text_id}.json"
        os.makedirs(os.path.dirname(filename_info), exist_ok=True)
        with open(filename_info, 'w') as f:
            json.dump(info, f, indent=4)
        print(f"Saving trajectory info to {filename_info}")
        print("##########################################################################")


if __name__ == "__main__":
    main()


# CUDA_VISIBLE_DEVICES=0 XLA_PYTHON_CLIENT_MEM_FRACTION=0.4 python rollout_interact_pi.py --task_type pickplace
        
        
# CUDA_VISIBLE_DEVICES=0 XLA_PYTHON_CLIENT_MEM_FRACTION=0.4 python scripts/rollout_interact_pi.py --dataset_root_path dataset_example --dataset_meta_info_path dataset_meta_info --dataset_names droid_subset --task_type pickplace