import torch,torchvision
from torch import nn
import kornia
import functools
from einops import rearrange, repeat
#import torchvision.ops as ops
from torch.nn import functional as F
import numpy as np
import sys, random, time, os
from copy import deepcopy
from matplotlib import cm
import wandb
from tqdm import tqdm
from typing import Callable, List, Optional, Tuple, Generator, Dict
from collections import defaultdict
#import torchvision.transforms as T

def make_net(dims):
    """Build a ReLU multilayer perceptron from a sequence of layer widths.

    For consecutive widths ``(d_in, d_out)`` in ``dims`` an ``nn.Linear(d_in, d_out)``
    is created, with an ``nn.ReLU`` between consecutive Linear layers and no
    activation after the last one. Every Linear weight is then re-initialized with
    Kaiming-normal (``mode="fan_in"``, ReLU gain); biases keep the default
    ``nn.Linear`` initialization.

    Args:
        dims: indexable sequence of layer widths, input width first. Fewer than two
            entries produce an empty ``nn.Sequential``.

    Returns:
        nn.Sequential: the assembled network.
    """
    def _kaiming_init(module):
        # Only Linear layers carry weights; ReLU and the Sequential container are skipped.
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, a=0.0, nonlinearity="relu", mode="fan_in")

    modules = []
    for fan_in, fan_out in zip(dims[:-1], dims[1:]):
        modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
    if modules:
        # The output layer should stay linear, so drop the trailing activation.
        modules.pop()

    network = nn.Sequential(*modules)
    network.apply(_kaiming_init)
    return network

class PositionalEncodingNoFreqFactor(nn.Module):
    """Sinusoidal positional encoding ``phi(v)``.

    The output concatenates, along the last dimension,
    ``[v, sin(f_0 v), cos(f_0 v), sin(f_1 v), cos(f_1 v), ...]``. The ``v`` term is
    optional; see ``forward``. Each frequency multiplies ``v`` directly, and no
    extra constant (such as pi) is applied.

    The frequencies are ``2 ** e`` where ``e`` runs over ``N_freqs`` evenly spaced
    values in ``[1, N_freqs - 1]`` (``torch.linspace``). The exponents are therefore
    generally not integers; e.g. for ``N_freqs=10`` the step is 8/9.
    NOTE(review): this differs from an integer-octave schedule 2^0..2^(N-1);
    confirm it is intended before changing, as trained weights depend on it.

    Arguments:
        i_dim (int): size of the last dimension of ``v``.
        N_freqs (int): number of frequency bands (default: 10).

    Attributes:
        out_dim (int): ``i_dim * (1 + 2 * N_freqs)``. This is the output width only
            when ``forward`` is called with ``start_with_v=True``; otherwise the
            output is ``i_dim`` narrower.
        freq_bands (Tensor): the ``N_freqs`` frequencies. This is a plain tensor
            attribute, not a registered buffer, so it is not part of ``state_dict``.
    """

    def __init__(
            self,
            i_dim: int,
            N_freqs: int = 10,
    ) -> None:
        super().__init__()

        self.i_dim = i_dim
        self.N_freqs = N_freqs
        self.out_dim = i_dim * (1 + 2 * N_freqs)

        exponents = torch.linspace(1, N_freqs - 1, N_freqs)
        self.freq_bands = 2 ** exponents

    def forward(self, v, start_with_v=True):
        """Encode ``v``; the result has ``2 * N_freqs`` (plus one if ``start_with_v``)
        chunks of ``v``'s last-dim size, concatenated along ``dim=-1``."""
        prefix = [v] if start_with_v else []
        # Iterating the 1-D frequency tensor yields one 0-dim tensor per band.
        scaled_inputs = [band * v for band in self.freq_bands]
        # Interleave per band: sin then cos, band by band.
        harmonics = [trig(s) for s in scaled_inputs for trig in (torch.sin, torch.cos)]
        return torch.cat(prefix + harmonics, dim=-1)


from torch import einsum
class CrossAttn_(nn.Module):
    """Pre-norm multi-head cross-attention followed by a residual feed-forward MLP.

    Queries come from ``y`` and keys/values from ``x`` (for example, ``y`` = cls
    tokens attending over ``x`` = image patches). The computation is::

        h   = y + proj(MultiHeadAttn(q=ln_1(y), kv=ln_1(x)))
        out = h + MLP(ln_2(h))          # MLP: ch -> 4*ch -> GELU -> ch

    Args:
        ch: feature width of ``x`` and ``y``; also the output width.
        heads: number of attention heads.
        dim_head: per-head width; q/k/v are projected to ``heads * dim_head``.
    """

    def __init__(self, ch, heads=8, dim_head=64):
        super().__init__()
        width = heads * dim_head
        self.scale = dim_head ** -0.5  # 1/sqrt(d) scaling of the dot products
        self.heads = heads
        self.ch = ch

        # Submodules are registered in this exact order so that parameters() /
        # state_dict() ordering (and seeded initialization) stay the same.
        self.to_q = nn.Linear(ch, width, bias=False)
        self.to_kv = nn.Linear(ch, 2 * width, bias=False)
        self.proj = nn.Linear(width, ch)

        hidden = int(4 * ch)
        self.out = nn.Sequential(nn.Linear(ch, hidden), nn.GELU(), nn.Linear(hidden, ch))

        # NOTE(review): ln_1 normalizes BOTH x and y in forward; ln_2 is used only
        # before the MLP. This may be deliberate weight sharing; confirm. Adding a
        # separate norm would change the parameter set / checkpoints.
        self.ln_1 = nn.LayerNorm([self.ch])
        self.ln_2 = nn.LayerNorm([self.ch])

    def _split_heads(self, t):
        """(b, n, heads*d) -> (b*heads, n, d), batch-major (einops '(b h)')."""
        b, n, _ = t.shape
        per_head = t.reshape(b, n, self.heads, -1).permute(0, 2, 1, 3)
        return per_head.reshape(b * self.heads, n, -1)

    def _merge_heads(self, t):
        """(b*heads, n, d) -> (b, n, heads*d), inverse of ``_split_heads``."""
        bh, n, d = t.shape
        b = bh // self.heads
        per_token = t.reshape(b, self.heads, n, d).permute(0, 2, 1, 3)
        return per_token.reshape(b, n, self.heads * d)

    # x is the image patches and y is the cls tokens, for ex.
    def forward(self, x, y, attn_mask=None, return_heads=False, return_attn=False, softmax_axis=-1):
        """Let the ``y`` tokens attend to ``x`` and return the updated ``y`` tokens.

        Args:
            x: context tokens supplying keys/values, indexed as (batch, n_kv, ch).
            y: query tokens, (batch, n_q, ch); also the residual stream.
            attn_mask, return_heads, return_attn: accepted for call compatibility
                but IGNORED. No masking is applied and only the output tensor is
                returned. NOTE(review): callers passing a mask currently get
                unmasked attention; confirm this is acceptable.
            softmax_axis: axis of the (batch*heads, n_q, n_kv) score matrix that
                is normalized. -1 (default) normalizes over keys (standard
                attention); -2 normalizes over queries instead.

        Returns:
            Tensor of shape (batch, n_q, ch).
        """
        context = self.ln_1(x)
        queries_in = self.ln_1(y)

        q = self._split_heads(self.to_q(queries_in))
        k, v = (self._split_heads(part) for part in self.to_kv(context).chunk(2, dim=-1))

        scores = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        weights = scores.softmax(dim=softmax_axis)
        attended = self._merge_heads(torch.einsum('b i j, b j d -> b i d', weights, v))

        # Residual around attention (on the un-normalized y), then around the MLP.
        residual = self.proj(attended) + y
        return self.out(self.ln_2(residual)) + residual


