import torch,torchvision
from torch import nn
import kornia
import functools
from einops import rearrange, repeat
#import torchvision.ops as ops
from torch.nn import functional as F
import numpy as np
import sys, random, time, os
from copy import deepcopy
from matplotlib import cm
import wandb
from tqdm import tqdm
from typing import Callable, List, Optional, Tuple, Generator, Dict
from collections import defaultdict
#import torchvision.transforms as T

import geometry,mlp_helpers

# Layout helpers for switching between channel-first feature maps and
# channel-last token sequences (both are thin wrappers over einops.rearrange).


def ch_sec(x):
    """Move channels last and flatten the spatial grid into one token axis.

    Pattern: ``"... c x y -> ... (x y) c"``.
    """
    return rearrange(x, "... c x y -> ... (x y) c")


def ch_fst(src, x=None):
    """Inverse of ``ch_sec``: unflatten a token axis back into a channel-first grid.

    Pattern: ``"... (x y) c -> ... c x y"``.

    Args:
        src: Tensor whose second-to-last axis holds the flattened grid
            (``.size(-2)`` is read, so a torch tensor is expected).
        x: First spatial extent of the grid. When omitted, the grid is assumed
            square and ``x`` is the integer part of ``sqrt(src.size(-2))``.
    """
    if x is None:
        x = int(src.size(-2) ** (0.5))
    return rearrange(src, "... (x y) c -> ... c x y", x=x)


# Number of future action steps each policy head below predicts.
n_timesteps=5
class TmpPolicy(nn.Module):
    """Image-only policy: ResNet-18 trunk followed by an MLP action head.

    Encodes ``model_input["rgb_img"]`` with a pretrained ResNet-18 whose final
    fully-connected layer is removed, then regresses ``n_timesteps`` future
    6-dimensional actions, each squashed into ``[-2, 2]`` via ``2 * tanh``.
    """

    def __init__(self, args=None):
        """Build the image encoder and action head.

        Args:
            args: Experiment config; stored on the module but not read by this class.
        """
        super().__init__()
        self.args = args

        import torchvision.models as models
        # NOTE(review): newer torchvision deprecates `pretrained=True` in favour of
        # `weights=models.ResNet18_Weights.DEFAULT`; kept as-is so older versions still work.
        resnet = models.resnet18(pretrained=True)
        # Drop the final classification layer; the remaining trunk ends in global
        # pooling and yields a 512-channel, 1x1 feature map per image.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.action_pred = mlp_helpers.make_net([512, 128, 128, 6 * n_timesteps])

    def forward(self, model_input, track_idxs=None, out=None):
        """Predict an action trajectory from an RGB image batch.

        Args:
            model_input: Mapping containing ``"rgb_img"``; presumably a normalized
                (B, 3, H, W) image batch — TODO confirm against the data loader.
            track_idxs: Unused; accepted for interface compatibility.
            out: Unused; accepted for interface compatibility. Defaults to ``None``
                rather than a shared mutable ``{}``.

        Returns:
            dict with ``"action_state"``: tensor whose last two axes are
            ``(n_timesteps, 6)``, values in ``[-2, 2]``.
        """
        # (B, 512, 1, 1) -> (B, 512); same result as squeezing the two trailing unit axes.
        global_feat = torch.flatten(self.resnet(model_input["rgb_img"]), 1)
        actions = self.action_pred(global_feat).tanh().unflatten(-1, (n_timesteps, 6)) * 2
        return {"action_state": actions}


class PolicyModel(nn.Module):
    """Policy that regresses a short action trajectory from an image or a point cloud.

    The input modality is selected by boolean flags on ``args``:

    * ``rob_pointcloud_inp`` or ``cam_pointcloud_inp``: a per-point MLP
      (6 -> 256 channels) is applied to ``model_input["rob_pc"]`` (when
      ``rob_pointcloud_inp``) or ``model_input["cam_pc"]``, then max- and
      mean-pooled over dim 1 and concatenated into a 512-d feature.
    * otherwise: a pretrained ResNet-18 trunk encodes ``model_input["rgb_img"]``
      (or ``model_input["pc_render_img"]`` when ``rob_rerender_inp``) into a
      512-d feature.

    An MLP head maps the feature to ``n_timesteps`` 6-dimensional actions in
    ``[-2, 2]``. Flags missing from ``args`` (including ``args=None``) are
    treated as ``False``, i.e. the image branch is used.
    """

    def __init__(self, args=None):
        """Build the encoder for the configured modality plus the action head.

        Args:
            args: Config object read via attribute access for the flags listed
                on the class; ``None`` or missing attributes mean "disabled".
        """
        super().__init__()
        self.args = args

        if self._use_pointcloud():
            # NOTE(review): assumes each point carries 6 input features — confirm their meaning.
            self.point_mlp = mlp_helpers.make_net([6, 64, 128, 256])
        else:
            import torchvision.models as models
            resnet = models.resnet18(pretrained=True)
            # Drop the final classification layer; the trunk yields a 512-channel, 1x1 map.
            self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        # Both branches produce a 512-d global feature (256 max + 256 mean, or ResNet-18 pooled).
        self.action_pred = mlp_helpers.make_net([512, 128, 64, 6 * n_timesteps])

    @staticmethod
    def _arg_flag(args, name):
        """Return the truthiness of ``args.<name>``, treating a missing attribute/config as False."""
        return bool(getattr(args, name, False))

    def _use_pointcloud(self):
        """Return True when either point-cloud input flag is enabled on ``self.args``."""
        return self._arg_flag(self.args, "rob_pointcloud_inp") or self._arg_flag(
            self.args, "cam_pointcloud_inp"
        )

    def forward(self, model_input, track_idxs=None, out=None):
        """Predict an action trajectory from the configured input modality.

        Args:
            model_input: Mapping holding the input selected by the ``args`` flags
                (see class docstring). Point clouds are presumably (B, N, 6) with
                points along dim 1 — TODO confirm against the data loader.
            track_idxs: Unused; accepted for interface compatibility.
            out: Unused; accepted for interface compatibility. Defaults to ``None``
                rather than a shared mutable ``{}``.

        Returns:
            dict with ``"action_state"``: tensor whose last two axes are
            ``(n_timesteps, 6)``, values in ``[-2, 2]``.
        """
        if self._use_pointcloud():
            key = "rob_pc" if self._arg_flag(self.args, "rob_pointcloud_inp") else "cam_pc"
            feat = self.point_mlp(model_input[key])
            # Pool over dim 1 with both max and mean, concatenating along channels.
            global_feat = torch.cat((feat.max(dim=1).values, feat.mean(dim=1)), -1)
        else:
            key = "pc_render_img" if self._arg_flag(self.args, "rob_rerender_inp") else "rgb_img"
            # (B, 512, 1, 1) -> (B, 512).
            global_feat = torch.flatten(self.resnet(model_input[key]), 1)

        actions = self.action_pred(global_feat).tanh().unflatten(-1, (n_timesteps, 6)) * 2
        return {"action_state": actions}
