import torch,torchvision
from torch import nn
import kornia
import functools
from einops import rearrange, repeat
#import torchvision.ops as ops
import matplotlib.pyplot as plt 
from torch.nn import functional as F
import numpy as np
import sys, random, time, os
from copy import deepcopy
from matplotlib import cm
import wandb
from tqdm import tqdm
from typing import Callable, List, Optional, Tuple, Generator, Dict
from collections import defaultdict
#import torchvision.transforms as T
#from torchcubicspline import(natural_cubic_spline_coeffs, NaturalCubicSpline)

import mlp_helpers

# Layout helpers converting between (channels, grid) and (tokens, channels).
def ch_sec(x):
    """Return ``x`` with channels last and the spatial grid flattened.

    Maps ``... c x y`` to ``... (x y) c`` using einops.
    """
    return rearrange(x, "... c x y -> ... (x y) c")


def ch_fst(src, x=None):
    """Undo ``ch_sec``: map ``... (x y) c`` back to ``... c x y``.

    ``x`` is the grid height. When it is omitted the grid is assumed square,
    with side ``int(sqrt(n))`` for ``n = src.size(-2)`` tokens.
    """
    if x is None:
        x = int(src.size(-2) ** (0.5))
    return rearrange(src, "... (x y) c -> ... c x y", x=x)

class MnistModel(nn.Module):
    """Work-in-progress patch-based variant of the dataset-weighting model.

    NOTE(review): this class looks unfinished.
      * ``patchify``, ``self_attn``, ``attn`` and ``cat`` are not defined or
        imported in the visible part of this file.
      * ``self.pos_emb`` and ``self.out`` are never assigned in ``__init__``.
      * ``forward`` cannot run as written (see the notes inside it).
    ``forward_`` holds the same inverse-distance weighting as ``MoonsModel.forward``.
    """
    def __init__(self, dataset, args=None):
        """Store the training set and create the learnable embeddings.

        Args:
            dataset: object with ``x`` (training inputs) and ``y`` (training
                targets); both are moved to CUDA here.
            args: optional run configuration, stored as ``self.args``.
        """
        super().__init__()
        self.args = args

        print("loading model")
        latent_dim=64
        self.mlp_baseline = mlp_helpers.make_net([2,latent_dim,latent_dim,1])

        # NOTE(review): `patchify` is not defined or imported in the visible source -- confirm where it comes from.
        self.dataset_inp,self.dataset_output = patchify(dataset.x.cuda()),dataset.y.cuda()
        # NOTE(review): unused below; it also measures the feature width of x, not of the outputs.
        output_dim = 1 if len(dataset.x.shape)==1 else dataset.x.size(-1)
        # TODO Random projection initialization of embedding
        #self.data_embeddings_temp   =nn.Parameter(torch.randn(1,1)) # replace with channel dimension after unsqueezing it
        #self.inp_emb_temp   =nn.Parameter(torch.ones(1,2)) # replace with channel dimension after unsqueezing it
        hidden_dims=[2,32]
        with torch.no_grad():
            # One learnable per-feature scale per layer, initialised to ones.
            self.emb_temps =nn.ParameterList([nn.Parameter(torch.ones(1,y)) for y in hidden_dims])
            # Per-sample hidden embeddings: a random nn.Linear(2, y) projection of the raw (un-patchified) inputs.
            # NOTE(review): this requires dataset.x's last dim to be 2 -- confirm this holds for the MNIST data.
            self.hiddens=nn.ParameterList([ nn.Parameter(nn.Linear(2,y).cuda()(dataset.x.cuda().float())) for y in hidden_dims])
        #self.data_embeddings_hidden=nn.Parameter(torch.randn(dataset.x.size(0),16)) # todo expand this into key,val, this just happens to be last layer
        #nn.init.xavier_uniform_(self.data_embeddings_hidden)
        #nn.embedding(dim) for dim in [], dataset.x,dataset.y
        print("Add more layers")

        #self.forward=self.forward_mlp_baseline
    def forward_vit_baseline( self, model_input, out={}): # vit baseline
        """Draft ViT baseline.

        NOTE(review): the patch-based ``y`` computed below is discarded; the
        returned prediction is the MLP baseline's sigmoid output on the raw
        ``model_input["x"]``. ``self.pos_emb``, ``self.out`` and ``self_attn``
        are not set/defined in the visible source.

        Returns ``out`` merged with ``{"y": ...}``; ``out`` itself is not mutated.
        """
        patches = patchify(model_input["x"])+self.pos_emb
        patches = self_attn(patches)
        y = self.out(patches).mean()
        return out | {  "y": self.mlp_baseline(model_input["x"]).sigmoid(), }
    def forward( self, model_input, out={}): # draft: patch/attention dataset-weighting model (WIP)
        """Draft patch/attention version of the dataset-weighting forward pass.

        NOTE(review): this cannot run as written.
          * ``data_inp_patches`` is read before it is assigned in this function
            (UnboundLocalError). Presumably ``self.dataset_inp`` (already
            patchified in ``__init__``) was meant -- confirm.
          * ``idxs``, ``emb_temp`` and ``sq_inv_dist`` are not defined in this
            method; ``sq_inv_dist`` exists only as a local of ``forward_``.
          * ``curr_patches[..., :.5]`` uses float slice bounds, which tensor
            indexing rejects. Presumably a half split of the channel dim was
            intended.
          * ``attn`` and ``cat`` are not defined or imported in the visible
            source.
        """

        inp_patches = patchify(model_input["x"]) + self.pos_emb
        data_inp_patches = data_inp_patches + self.pos_emb

        scores    = (sq_inv_dist(inp_patches*self.emb_temps[0],data_inp_patches[idxs]*self.emb_temps[0])).softmax(dim=-1)

        for hidden in self.hiddens:
            curr_patches  = (scores[...,None] * hidden[idxs][None]).sum(1) 
            attn_response_emb = attn(k=curr_patches[...,:.5],v=curr_patches[...,.5:])
            hidden_state = cat(curr_patches[...,.5:],attn_response_emb)
            scores    = (sq_inv_dist(hidden_state[:,None]*emb_temp,hidden[idxs][None]*emb_temp)).softmax(dim=-1)
        # NOTE(review): averages with .mean over the last axis, unlike the weighted .sum in forward_ -- confirm intended.
        y = (self.dataset_output[idxs][:,None,None]*scores).mean(-1,keepdim=True) # just unflatten dim 
        return out | {  "y": y }

    def forward_( self, model_input, out={}): # our dataset-weighting model
        """Inverse-distance dataset-weighting forward pass (same logic as MoonsModel.forward).

        Each layer turns inverse distances to the retained training samples
        into softmax weights. The output is the weighted sum of those samples'
        labels. While grad is enabled, training samples whose indices appear in
        ``model_input["idx"]`` are dropped from the weighting set.

        Returns ``out`` merged with ``{"y": y}``; ``out`` itself is not mutated.
        """

        # TODO 's: add random projections above; exclude training dataset index (maybe not even necessary bc smoothness? might be annoying to implement or could just drop all the training indices in the current batch if easier; add attention for set input like image patches); 

        # NOTE(review): for a 2-D x of shape (B, F>1), unsqueeze(-1) gives (B, F, 1), which likely fails to broadcast
        # against the (M, F) training matrix below; unsqueeze(-2) was probably intended (identical when F == 1).
        inp_x = model_input["x"].unsqueeze(-1) if len(model_input["x"].shape)==2 else model_input["x"] # todo unsqueeze in dataloader
        dataset_x = self.dataset_inp.unsqueeze(-1) if len(self.dataset_inp.shape)==1 else self.dataset_inp # todo unsqueeze in dataloader

        # exclude training idxs from dataset
        # NOTE(review): the condition re-evaluates .squeeze(-1).tolist() for every i (O(N*B), one device sync per i).
        # An idx of shape (1,) squeezes to a scalar, so `in` is applied to an int and raises.
        if torch.is_grad_enabled(): idxs=[i for i in list(range(len(dataset_x))) if i not in model_input["idx"].squeeze(-1).tolist()]
        # Eval: keep every training sample. The nested list appears to rely on torch treating [[...]] like a tuple index.
        else: idxs=[list(range(len(dataset_x)))]

        # forward pass through layers, todo add more layers before moving to higher dim input / attn
        # 1 / (eps + ||x - y||): inverse (not squared, despite the name) Euclidean distance over the last dim.
        sq_inv_dist   = lambda x,y:1/(1e-5+(x-y).norm(dim=-1))
        scores    = (sq_inv_dist(inp_x*self.emb_temps[0],dataset_x[idxs]*self.emb_temps[0])).softmax(dim=-1)
        
        # Layer 0 used the raw inputs; each later layer re-weights in its learned hidden-embedding space.
        for emb_temp,hidden in zip(self.emb_temps[1:],self.hiddens[1:]):
            hidden_state  = (scores[...,None] * hidden[idxs][None]).sum(1) 
            scores    = (sq_inv_dist(hidden_state[:,None]*emb_temp,hidden[idxs][None]*emb_temp)).softmax(dim=-1)
        y = (self.dataset_output[idxs]*scores).sum(-1,keepdim=True)

        return out | {  "y": y }
class MoonsModel(nn.Module):
    """Dataset-weighting classifier for the 2-D "moons" toy problem.

    The prediction for a query is a softmax-weighted sum of the stored training
    labels.

    * The first layer weights training points by the inverse distance between
      the per-feature-scaled query and the raw training inputs.
    * Every further layer forms a weighted hidden state from learnable
      per-sample embeddings, then re-weights by inverse distance in that
      embedding space.

    ``forward_mlp_baseline`` is an MLP alternative kept for comparison.
    """

    def __init__(self, dataset, args=None):
        """Build the model around a fixed training set.

        Args:
            dataset: object exposing ``x`` (training inputs) and ``y`` (training
                targets); both are copied to CUDA. ``nn.Linear(2, ...)`` is
                applied to ``x``, so its last dimension must be 2.
            args: optional run configuration, stored as ``self.args``.
        """
        super().__init__()
        self.args = args

        print("loading model")
        latent_dim = 64
        self.mlp_baseline = mlp_helpers.make_net([2, latent_dim, latent_dim, 1])

        # Plain tensor attributes, not registered buffers: Module.to()/state_dict()
        # do not track them, so they are placed on CUDA explicitly.
        self.dataset_inp, self.dataset_output = dataset.x.cuda(), dataset.y.cuda()

        # TODO Random projection initialization of embedding
        hidden_dims = [2, 32]
        with torch.no_grad():
            # One learnable per-feature scale per layer, applied before each distance.
            self.emb_temps = nn.ParameterList(
                [nn.Parameter(torch.ones(1, dim)) for dim in hidden_dims]
            )
            # Per-sample hidden embeddings, initialised as a random linear projection
            # of the training inputs. hiddens[0] is not read by forward(): layer 0
            # measures distances on the raw inputs instead.
            self.hiddens = nn.ParameterList(
                [nn.Parameter(nn.Linear(2, dim).cuda()(dataset.x.cuda().float())) for dim in hidden_dims]
            )
        print("Add more layers")

    def forward_mlp_baseline(self, model_input, out=None):  # mlp baseline
        """MLP baseline: sigmoid of ``mlp_baseline`` applied to ``model_input["x"]``.

        Returns ``out`` (if given) merged with ``{"y": prediction}``; ``out`` is not mutated.
        """
        out = {} if out is None else out
        return out | {"y": self.mlp_baseline(model_input["x"]).sigmoid()}

    def forward(self, model_input, out=None):  # our dataset-weighting model
        """Predict by softmax-weighting the stored training labels.

        Args:
            model_input: dict with the following keys.
                ``"x"``: query points. A 2-D ``(B, F)`` tensor is reshaped to
                ``(B, 1, F)`` so it broadcasts against the ``(M, F)`` training
                inputs.
                ``"idx"``: read only while grad is enabled. Dataset indices of
                the batch, of any shape. Those samples are excluded from the
                weighting set so a training point cannot attend to itself.
            out: optional dict merged into the result; it is not mutated.

        Returns:
            ``out | {"y": y}``, with ``y`` of shape ``(B, 1)`` for 1-D labels.
        """
        # TODO: add random projections; attention for set inputs such as image patches.
        out = {} if out is None else out
        x = model_input["x"]
        inp_x = x.unsqueeze(-2) if x.dim() == 2 else x  # todo unsqueeze in dataloader
        dataset_x = self.dataset_inp.unsqueeze(-1) if self.dataset_inp.dim() == 1 else self.dataset_inp

        # Boolean mask over the training set selecting the samples used as context.
        # One vectorised pass replaces a per-sample membership test.
        n_data = dataset_x.size(0)
        keep = torch.ones(n_data, dtype=torch.bool, device=dataset_x.device)
        if torch.is_grad_enabled():
            # Leave-batch-out while training. Indices outside [0, n_data) are ignored,
            # matching a membership test against range(n_data).
            batch_idx = model_input["idx"].reshape(-1).long().to(keep.device)
            batch_idx = batch_idx[(batch_idx >= 0) & (batch_idx < n_data)]
            keep[batch_idx] = False

        def inv_dist(a, b):
            # 1 / (eps + ||a - b||): inverse (not squared) Euclidean distance over the last dim.
            return 1 / (1e-5 + (a - b).norm(dim=-1))

        ctx_x = dataset_x[keep]
        scores = inv_dist(inp_x * self.emb_temps[0], ctx_x * self.emb_temps[0]).softmax(dim=-1)

        for emb_temp, hidden in zip(self.emb_temps[1:], self.hiddens[1:]):
            ctx_hidden = hidden[keep]
            hidden_state = (scores[..., None] * ctx_hidden[None]).sum(1)
            scores = inv_dist(hidden_state[:, None] * emb_temp, ctx_hidden[None] * emb_temp).softmax(dim=-1)

        y = (self.dataset_output[keep] * scores).sum(-1, keepdim=True)
        return out | {"y": y}
class LinRegressionModel(nn.Module):
    """Dataset-weighting regressor for 1-D inputs.

    Predicts a softmax-weighted sum of the stored training targets in two steps:

    1. Query-to-data inverse distances, scaled by a learnable temperature, give
       weights used to mix learnable per-sample hidden embeddings.
    2. Inverse distances in that hidden space then give the final target
       weights.

    ``forward_mlp_baseline`` is an MLP alternative kept for comparison.
    """

    def __init__(self, dataset, args=None):
        """Build the model around a fixed training set.

        Args:
            dataset: object exposing ``x`` (training inputs) and ``y`` (training
                targets); both are copied to CUDA.
            args: optional run configuration, stored as ``self.args``.
        """
        super().__init__()
        self.args = args

        print("loading model")
        latent_dim = 64
        self.mlp_baseline = mlp_helpers.make_net([1, latent_dim, latent_dim, latent_dim, 1])

        self.dataset_inp, self.dataset_output = dataset.x.cuda(), dataset.y.cuda()
        # Feature width of the inputs (1 for a flat (N,) tensor). It sizes the
        # temperature below, which multiplies a (B, N) score matrix in forward(),
        # so it must be 1 (or N) to broadcast.
        feat_dim = 1 if len(dataset.x.shape) == 1 else dataset.x.size(-1)
        # TODO Random projection initialization of embedding
        self.data_embeddings_temp = nn.Parameter(torch.randn(1, feat_dim))
        # One learnable 16-d embedding per training sample (todo: split into key/value).
        self.data_embeddings_hidden = nn.Parameter(torch.randn(dataset.x.size(0), 16))
        nn.init.xavier_uniform_(self.data_embeddings_hidden)

    def forward_mlp_baseline(self, model_input, out=None):  # mlp baseline
        """MLP baseline: ``mlp_baseline`` applied to ``model_input["x"]``.

        Returns ``out`` (if given) merged with ``{"y": prediction}``; ``out`` is not mutated.
        """
        out = {} if out is None else out
        return out | {"y": self.mlp_baseline(model_input["x"])}

    def forward(self, model_input, out=None):  # our dataset-weighting model
        """Predict by softmax-weighting the stored training targets.

        Args:
            model_input: dict with ``"x"``, the query points. A 2-D ``(B, F)``
                tensor is reshaped to ``(B, 1, F)``, which is identical to the
                old ``(B, F, 1)`` layout for the 1-feature case.
            out: optional dict merged into the result; it is not mutated.

        Returns:
            ``out | {"y": y}``, with ``y`` of shape ``(B, 1)`` for 1-D targets.
        """
        # TODO: add random projections; leave-batch-out exclusion as in MoonsModel.
        out = {} if out is None else out
        x = model_input["x"]
        inp_x = x.unsqueeze(-2) if x.dim() == 2 else x  # todo unsqueeze in dataloader
        dataset_x = self.dataset_inp.unsqueeze(-1) if self.dataset_inp.dim() == 1 else self.dataset_inp

        def inv_dist(a, b):
            # 1 / (eps + ||a - b||): inverse (not squared) Euclidean distance over the last dim.
            return 1 / (1e-5 + (a - b).norm(dim=-1))

        # |temp| / 50 scales the logits, i.e. acts as an inverse softmax temperature.
        inp_scores = (inv_dist(inp_x, dataset_x) * (self.data_embeddings_temp.abs() / 50)).softmax(dim=-1)
        hidden_state = (inp_scores[..., None] * self.data_embeddings_hidden[None]).sum(1)
        hidden_scores = inv_dist(hidden_state[:, None], self.data_embeddings_hidden[None]).softmax(dim=-1)
        # dataset_output was already placed on its device in __init__; no per-call .cuda().
        y = (self.dataset_output * hidden_scores).sum(-1, keepdim=True)
        return out | {"y": y}
