# Borrowed from https://github.com/openai/guided-diffusion
from abc import abstractmethod

import math
import numpy as np
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import kornia
from einops import rearrange,repeat
from torch import einsum
import diffusion 

# Module-level switch read by TimestepEmbedSequential.forward: when False, its
# children receive `scene_info[:, :0]` (a zero-length slice along dim 1)
# instead of the real scene conditioning.
use_cond = True

class GroupNorm32(nn.GroupNorm):
    """GroupNorm that always normalizes in float32, then casts back to the input dtype."""

    def forward(self, x):
        input_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(input_dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Build an ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Conv3d`` for ``dims`` 1 / 2 / 3.

    All remaining positional and keyword arguments go to the constructor.

    :raises ValueError: if ``dims`` is not 1, 2 or 3.
    """
    for rank, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Factory for ``nn.Linear``; every argument is forwarded unchanged.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer


def avg_pool_nd(dims, *args, **kwargs):
    """
    Build an ``nn.AvgPool1d`` / ``nn.AvgPool2d`` / ``nn.AvgPool3d`` for ``dims`` 1 / 2 / 3.

    All remaining positional and keyword arguments go to the constructor.

    :raises ValueError: if ``dims`` is not 1, 2 or 3.
    """
    for rank, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def update_ema(target_params, source_params, rate=0.99):
    """
    Blend each target parameter towards its source counterpart, in place:
    ``target <- rate * target + (1 - rate) * source``.

    :param target_params: sequence of tensors updated in place.
    :param source_params: sequence of tensors providing the new values
                          (paired with ``target_params`` via ``zip``).
    :param rate: EMA decay; values closer to 1 move the target more slowly.
    """
    blend = 1 - rate
    for target, source in zip(target_params, source_params):
        # detach() so the in-place update is allowed on leaf tensors that require grad.
        target.detach().mul_(rate).add_(source, alpha=blend)


def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and return the same module.
    """
    for weight in module.parameters():
        # detach() permits the in-place write on parameters that require grad.
        weight.detach().zero_()
    return module


def normalization(channels):
    """
    Return the default normalization layer: a float32-computing GroupNorm
    with 32 groups.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    num_groups = 32
    return GroupNorm32(num_groups, channels)


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Encode each timestep as a sinusoidal feature vector.

    The first ``dim // 2`` columns hold cosines and the next ``dim // 2`` hold
    sines of ``timesteps * freq``. The frequencies decay geometrically from 1
    to ``max_period ** -(1 - 2/dim)``. For odd ``dim`` a zero column is
    appended.

    :param timesteps: 1-D tensor with one (possibly fractional) value per batch element.
    :param dim: width of the output embedding.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] float tensor.
    """
    half = dim // 2
    scaled = -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32)
    freqs = th.exp(scaled / half).to(device=timesteps.device)
    angles = timesteps[:, None].float() * freqs[None]
    pieces = [th.cos(angles), th.sin(angles)]
    if dim % 2:
        pieces.append(th.zeros_like(angles[:, :1]))
    return th.cat(pieces, dim=-1)


def checkpoint(func, inputs, params, flag):
    """
    Run ``func(*inputs)``, optionally with gradient checkpointing.

    When checkpointing is on, intermediate activations are not kept; they are
    recomputed in the backward pass, trading extra compute for less memory.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if not flag:
        return func(*inputs)
    n_inputs = len(inputs)
    flat_args = (*inputs, *params)
    return CheckpointFunction.apply(func, n_inputs, *flat_args)


class CheckpointFunction(th.autograd.Function):
    """
    Autograd function behind :func:`checkpoint`.

    ``forward`` runs ``run_function`` under ``no_grad``, so no intermediate
    activations are stored. ``backward`` re-runs it with grad enabled and
    differentiates the recomputed outputs.

    Compared with the upstream guided-diffusion version:

    * The CPU RNG state (and the CUDA RNG states, if CUDA is initialised) is
      captured in ``forward`` and restored for the recomputation. Stochastic
      layers such as dropout therefore see the same random numbers in both
      passes, and the gradients match the non-checkpointed ones.
    * Frozen parameters, non-floating-point tensors and non-tensor inputs
      (e.g. ``None``) are skipped instead of making ``torch.autograd.grad``
      raise. They receive a ``None`` gradient.
    * Recomputed outputs that are not attached to the graph are skipped too.

    Call as ``CheckpointFunction.apply(run_function, length, *inputs, *params)``,
    where ``length`` is the number of leading entries that are passed to
    ``run_function``.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # Snapshot the RNG state *before* running so backward can replay it.
        ctx.fwd_cpu_rng_state = th.get_rng_state()
        ctx.had_cuda = th.cuda.is_initialized()
        if ctx.had_cuda:
            ctx.fwd_cuda_rng_states = th.cuda.get_rng_state_all()
        with th.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Fresh leaves to differentiate against. Only floating-point tensors
        # can require grad; anything else is passed through untouched.
        ctx.input_tensors = [
            x.detach().requires_grad_(x.is_floating_point()) if th.is_tensor(x) else x
            for x in ctx.input_tensors
        ]
        cuda_devices = list(range(th.cuda.device_count())) if ctx.had_cuda else []
        # fork_rng restores the caller's RNG state afterwards, so replaying the
        # forward-pass state here has no effect outside this block.
        with th.random.fork_rng(devices=cuda_devices):
            th.set_rng_state(ctx.fwd_cpu_rng_state)
            if ctx.had_cuda:
                th.cuda.set_rng_state_all(ctx.fwd_cuda_rng_states)
            with th.enable_grad():
                # Fixes a bug where the first op in run_function modifies the
                # Tensor storage in place, which is not allowed for detach()'d
                # Tensors.
                shallow_copies = [
                    x.view_as(x) if th.is_tensor(x) else x for x in ctx.input_tensors
                ]
                output_tensors = ctx.run_function(*shallow_copies)
        if th.is_tensor(output_tensors):
            output_tensors = (output_tensors,)

        # Only outputs that are part of the graph can be differentiated.
        live = [
            (out, grad)
            for out, grad in zip(output_tensors, output_grads)
            if th.is_tensor(out) and out.requires_grad
        ]
        candidates = ctx.input_tensors + ctx.input_params
        wants_grad = [th.is_tensor(t) and t.requires_grad for t in candidates]
        targets = [t for t, wanted in zip(candidates, wants_grad) if wanted]
        if live and targets:
            outs, grads = zip(*live)
            computed = th.autograd.grad(outs, targets, grads, allow_unused=True)
        else:
            computed = (None,) * len(targets)
        # Scatter the computed grads back to their positions; None elsewhere.
        computed_iter = iter(computed)
        input_grads = tuple(next(computed_iter) if wanted else None for wanted in wants_grad)

        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # No gradients for run_function and length.
        return (None, None) + input_grads


class AttentionPool2d(nn.Module):
    """
    Pool a feature map to a single vector per sample with one multi-head
    attention step. A mean-pooled token is prepended to the flattened
    positions, and its attended output is returned.

    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py

    :param spacial_dim: side length of the (square) input feature map.
    :param embed_dim: channel width of the input.
    :param num_heads_channels: channels per attention head.
    :param output_dim: output width; defaults to ``embed_dim``.
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        token_count = spacial_dim ** 2 + 1  # all positions + the pooled token
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, token_count) / embed_dim ** 0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        batch, channels, *_spatial = x.shape
        tokens = x.reshape(batch, channels, -1)  # N x C x (HW)
        pooled = tokens.mean(dim=-1, keepdim=True)
        tokens = th.cat([pooled, tokens], dim=-1)  # N x C x (HW + 1)
        tokens = tokens + self.positional_embedding[None, :, :].to(tokens.dtype)
        attended = self.attention(self.qkv_proj(tokens))
        projected = self.c_proj(attended)
        # Keep only the pooled token's output.
        return projected[:, :, 0]


class TimestepBlock(nn.Module):
    """
    Marker base class for modules whose ``forward`` takes the timestep
    embedding (and scene information) in addition to the main input.
    TimestepEmbedSequential dispatches on this type.
    """

    @abstractmethod
    def forward(self, x, emb, scene_info):
        """
        Process ``x`` given the timestep embedding ``emb`` and ``scene_info``.
        Subclasses are expected to override this; the base body does nothing.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    ``nn.Sequential`` that threads the timestep embedding and scene
    information through its children.

    Dispatch per child:

    * ``TimestepBlock`` instances are called as ``layer(x, emb, scene_info)``.
    * Children whose class name is ``AttentionBlock`` are called as
      ``layer(x, scene_info)`` and return ``(x, scene_info)``. The returned
      scene info is handed to the following children.
    * Every other child is called as ``layer(x)``.

    When the module-level ``use_cond`` flag is False, children get
    ``scene_info[:, :0]`` instead of the real scene info.

    Returns ``(x, scene_info)``. If the final ``scene_info`` is not 256 wide
    on its last axis, only the first half of that axis is returned.
    """

    def forward(self, x, emb, scene_info):
        if scene_info is None:
            # Empty placeholder derived from emb: zero length along dim 1 and,
            # for 3-D emb, twice emb's last-axis width.
            empty = emb[:, :0]
            scene_info = torch.zeros_like(torch.cat((empty, empty), -1))

        def gate(info):
            # Hide the conditioning from children when use_cond is off.
            return info if use_cond else info[:, :0]

        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb, gate(scene_info))
            elif layer._get_name() == 'AttentionBlock':
                x, scene_info = layer(x, gate(scene_info))
            else:
                x = layer(x)

        width = scene_info.size(-1)
        if width == 256:
            return x, scene_info
        return x, scene_info[..., :width // 2]  # todo refactor


class Upsample(nn.Module):
    """
    Nearest-neighbour 2x upsampling followed by an optional 3x3 convolution.

    For ``dims == 3`` all three trailing axes (including the leading/depth
    one) are doubled; otherwise ``scale_factor=2`` is applied to every
    spatial axis. A 3x3 map becomes 6x6 and is then zero-padded on the
    top/left to 7x7.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param out_channels: output channels of the optional conv; defaults to ``channels``.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims != 3:
            upsampled = F.interpolate(x, scale_factor=2, mode="nearest")
        else:
            # Double depth, height and width. (Keep depth unchanged when the
            # temporal resolution should be preserved.)
            target_size = (x.shape[2] * 2, x.shape[3] * 2, x.shape[4] * 2)
            upsampled = F.interpolate(x, target_size, mode="nearest")
        if x.shape[-1] == x.shape[-2] == 3:
            # 3x3 -> 6x6 after interpolation; pad one row/column at the start to reach 7x7.
            upsampled = F.pad(upsampled, (1, 0, 1, 0))
        return self.conv(upsampled) if self.use_conv else upsampled


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    With ``use_conv`` a 3-wide, stride-2 convolution is used; otherwise an
    average pool whose kernel equals the stride.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. For 3D the stride
                 is (2, 2, 2), so all three axes (including the leading one)
                 are halved.
    :param out_channels: output channels; defaults to ``channels`` and must
                         equal it when ``use_conv`` is False.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # Halve every axis; for dims == 3 this includes the leading axis too.
        stride = 2 if dims != 3 else (2, 2, 2)
        # A scalar kernel size expands to the right rank for every `dims`. The
        # previous `(3, 3, 3)` for dims == 1 gave nn.Conv1d an invalid 3-D kernel.
        kernel_size = 3
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, kernel_size, stride=stride, padding=1
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class SimpleResMLP(nn.Module):
    """
    Two-layer MLP on the last axis with a learned linear skip path:
    ``skip(x) + Lin(SiLU(Lin(SiLU(x))))``.

    :param channels: width of the last axis of both input and output.
    """

    def __init__(self, channels):
        super().__init__()
        # Parameter layout (in_layers.1, out_layers.1, skip_connection) is
        # kept so existing state_dicts still load.
        self.in_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, channels),
        )
        self.out_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, channels),
        )
        self.skip_connection = nn.Linear(channels, channels)

    def forward(self, x):
        h = self.out_layers(self.in_layers(x))
        return self.skip_connection(x) + h

class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.

    Differences from the upstream guided-diffusion ResBlock:
    * The projected embedding is reshaped with ``permute(0,2,1,3,4)`` and
      resized with ``F.interpolate`` to h's last three axes. This path
      therefore only works for 5-D feature maps (dims=3).
    * ``emb_layers_sc``, ``scene_info_attn`` and ``spatial_lin`` are created
      here but used only inside a branch of ``_forward`` that is disabled by
      ``and 0``. They still own parameters, so keep them for state_dict
      compatibility.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # Resampling applied to both the residual branch (h) and the skip input (x).
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # Embedding projection; doubled width provides (scale, shift) when
        # use_scale_shift_norm is set.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear( emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels,),
        )
        # Only used by the disabled scene-attention branch in _forward.
        self.emb_layers_sc = nn.Sequential(
            nn.SiLU(),
            linear( channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels,),
        )

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

        # Only used by the disabled scene-attention branch in _forward
        # (kv width 512, 1 head of width 256).
        self.scene_info_attn = SimpleCrossAttention(512,self.channels,1,256)
        self.spatial_lin = nn.Linear(2,self.channels)

    def forward(self, x, emb, scene_info=None):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
            NOTE(review): _forward permutes the projected embedding as a 5-D
            tensor, which suggests emb is actually [N x L x emb_channels]
            with a 5-D x; confirm against callers.
        :param scene_info: passed through to _forward; only read by a
            disabled branch there.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb,scene_info), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb, scene_info=None):
        if self.updown:
            # Norm + activation first, then resample, then the conv.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            # `and 1` is always truthy: the test is just x.size(2) == 1. A
            # singleton axis 2 is duplicated to length 2 before resampling.
            if x.size(2)==1 and 1: # only need this if doing 3d strided convs
                x=th.cat((x,x),2)
                h=th.cat((h,h),2)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        #emb = emb.permute(0,2,3,4,1)
        emb_out = self.emb_layers(emb).type(h.dtype)
        #emb_out = emb_out.permute(0,4,1,2,3)
        # Append trailing singleton axes until emb_out has h's rank.
        while len(emb_out.shape) < len(h.shape): emb_out = emb_out[..., None]

        # Move the projected channels to axis 1 and resize (nearest, the
        # F.interpolate default) to h's last three axes. Requires 5-D emb_out.
        emb_out = emb_out.permute(0,2,1,3,4)
        emb_out = F.interpolate(emb_out,h.shape[-3:]) # video data

        # Scene info embedding attention
        # Disabled: `and 0` makes this condition always false. NOTE(review):
        # if re-enabled, the `.cuda()` call below ties it to the current
        # CUDA device.
        if x.size(-1)<=16 and 0:
            uv_emb = self.spatial_lin(th.stack(th.meshgrid(th.linspace(-1,1,x.size(-2)),th.linspace(-1,1,x.size(-1))),-1).cuda())
            pos_emb_x = x + uv_emb.permute(2,0,1)[None,:,None]
            scene_emb = self.scene_info_attn(repeat(scene_info,"b s c -> b vxy s c",vxy=x.size(-1)*x.size(-2)*x.size(-3)),rearrange(pos_emb_x,"b c v x y -> b (v x y) 1 c"),use_skip=False)
            scene_emb = rearrange(self.emb_layers_sc(scene_emb),"b (v x y) 1 c -> b c v x y",x=x.size(-2),v=x.size(-3))
            emb_out = emb_out + scene_emb

        if self.use_scale_shift_norm:
            # FiLM-style conditioning: norm(h) * (1 + scale) + shift.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h

class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other
    and to a set of scene-information tokens.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

    ``forward(x, scene_info)`` returns ``(out, scene_info_out)``.

    * ``x`` is either ``[N x C x D x H x W]`` or a token list ``[N x T x C]``.
      A token list is viewed as ``[N x C x 1 x T x 1]`` and restored on
      output. A learned embedding of a [-1, 1]^2 grid over the last two axes
      is added. A zero-channel ``x`` means "scene tokens only".
    * ``scene_info`` is ``[N x S x 512]``. It is projected to ``C`` channels,
      given a learned 1-D position embedding, and appended to the feature
      tokens, so one self-attention pass covers both.
    * After attention, the last ``S`` tokens are split off again and projected
      to 256 channels. No output projection is applied when ``S == 0``.

    :param channels: channel width ``C`` of ``x`` and of the attention.
    :param num_heads: number of heads, used when ``num_head_channels == -1``.
    :param num_head_channels: if not -1, fixed width per head (overrides
        ``num_heads``); must divide ``channels``.
    :param use_checkpoint: stored only; ``forward`` always checkpoints.
    :param use_new_attention_order: use QKVAttention instead of QKVAttentionLegacy.
    :param channels_out: stored as ``self.channels_out``; not used by this block.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
        channels_out=None,
    ):
        super().__init__()
        self.channels = channels
        self.channels_out = channels if channels_out is None else channels_out
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

        # Scene tokens enter with 512 channels and leave with 256.
        self.scene_info_proj_in = nn.Linear(512, self.channels)
        self.scene_info_proj_out = nn.Linear(self.channels, 256)

        # Learned embedding of the 2-D (u, v) grid coordinates.
        self.spatial_lin = nn.Linear(2, self.channels)

        self.mlp_in = SimpleResMLP(self.channels)
        self.mlp_out = SimpleResMLP(self.channels)
        # Learned embedding of each scene token's position in [-1, 1].
        self.reg_pos_emb = nn.Linear(1, self.channels)

    def forward(self, x, scene_info):
        # NOTE: checkpointing is hard-wired on (as in upstream guided-diffusion);
        # self.use_checkpoint is not consulted here.
        return checkpoint(self._forward, (x, scene_info), self.parameters(), True)

    def _scene_pos_emb(self, n_tokens, device):
        """Return the learned scene-token position embedding, shaped [1 x C x n_tokens]."""
        positions = th.linspace(-1, 1, n_tokens, device=device, dtype=self.reg_pos_emb.weight.dtype)
        return self.reg_pos_emb(positions[:, None])[None].permute(0, 2, 1)

    # accepts just flattened list
    # accepts spatial input
    def _forward(self, x, scene_info):
        not_spatial = len(x.shape) == 3
        if not_spatial:
            # [N x T x C] -> [N x C x 1 x T x 1]
            x = x.permute(0, 2, 1)[..., None, :, None]
        b, c, *spatial = x.shape
        using_x = x.size(1) != 0
        # Number of scene tokens; scene_info is [N x S x 512] on entry. The
        # same count is used to split the joint sequence apart again below.
        n_scene = scene_info.size(-2)

        if using_x:
            # Build the grid on x's device (not a hard-coded .cuda()) so the
            # block also runs on CPU / non-default GPUs.
            grid_dtype = self.spatial_lin.weight.dtype
            grid = th.stack(
                th.meshgrid(
                    th.linspace(-1, 1, x.size(-2), device=x.device, dtype=grid_dtype),
                    th.linspace(-1, 1, x.size(-1), device=x.device, dtype=grid_dtype),
                ),
                -1,
            )
            uv_emb = self.spatial_lin(grid)
            x = x + uv_emb.permute(2, 0, 1)[None, :, None]

            # Note this is doing attention over temporal dimension right now
            x = x.reshape(b, c, -1)

            scene_info = self.scene_info_proj_in(scene_info).permute(0, 2, 1)
            scene_info = scene_info + self._scene_pos_emb(n_scene, scene_info.device)

            # Joint sequence: feature tokens followed by scene tokens.
            x = th.cat((x, scene_info), 2)
        else:
            x = self.scene_info_proj_in(scene_info).permute(0, 2, 1)
            x = x + self._scene_pos_emb(n_scene, scene_info.device)

        x = self.mlp_in(x.permute(0, 2, 1)).permute(0, 2, 1)  # mod

        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)

        out = x + h

        out = self.mlp_out(out.permute(0, 2, 1)).permute(0, 2, 1)  # mod

        # Split the scene tokens back off by token count. (Previously the
        # scene-only branch used scene_info's channel width, 512, here.)
        if n_scene != 0:
            scene_tokens = out[..., -n_scene:]
            out = out[..., :-n_scene]
            scene_info = self.scene_info_proj_out(scene_tokens.permute(0, 2, 1))

        out = out.reshape(b, c, *spatial)
        if not_spatial:
            out = out.squeeze(-1).squeeze(-2).permute(0, 2, 1)

        return out, scene_info


def count_flops_attn(model, _x, y):
    """
    ``thop`` hook that adds the multiply-accumulate count of an attention op
    to ``model.total_ops``.

    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    batch, channels, *spatial = y[0].shape
    tokens = int(np.prod(spatial))
    # Two equally sized matmuls: one builds the attention weights, the other
    # mixes the value vectors with them.
    per_matmul = batch * (tokens ** 2) * channels
    model.total_ops += th.DoubleTensor([2 * per_matmul])


class QKVAttentionLegacy(nn.Module):
    """
    QKV attention that splits heads *before* splitting q/k/v. The channel axis
    is laid out as [head][q | k | v][ch], matching the legacy QKVAttention
    input/output head shaping.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        head_ch = width // (3 * self.n_heads)
        per_head = qkv.reshape(batch * self.n_heads, head_ch * 3, length)
        q, k, v = per_head.split(head_ch, dim=1)
        # Scale q and k by ch**-0.25 each instead of the product by ch**-0.5;
        # this is more stable in float16.
        scale = 1 / math.sqrt(math.sqrt(head_ch))
        logits = th.einsum("bct,bcs->bts", q * scale, k * scale)
        weights = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        attended = th.einsum("bts,bcs->bct", weights, v)
        return attended.reshape(batch, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)

class SimpleCrossAttention(nn.Module):
    """
    Simple vanilla cross attention.

    Queries come from ``y`` and keys/values from ``x``. The attended values
    are projected back to ``ch_q`` channels, optionally added to ``y``
    (residual), and passed through a residual GELU MLP.

    :param ch_kv: channel width of ``x`` (source of keys and values).
    :param ch_q: channel width of ``y`` (source of queries) and of the output.
    :param heads: number of attention heads.
    :param dim_head: width of each head.
    """

    def __init__(self, ch_kv, ch_q, heads=8, dim_head=64):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(ch_q, inner_dim, bias=False)
        self.to_kv = nn.Linear(ch_kv, inner_dim * 2, bias=False)
        self.proj = nn.Linear(inner_dim, ch_q)

        self.out = nn.Sequential(
            nn.Linear(ch_q, int(4 * ch_q)),
            nn.GELU(),
            nn.Linear(int(4 * ch_q), ch_q),
        )

        self.ln_1 = nn.LayerNorm([ch_kv])
        # NOTE(review): ln_2 normalizes the queries and, in forward, also the
        # pre-MLP activations, so those two norms share parameters.
        self.ln_2 = nn.LayerNorm([ch_q])

    def _split_heads(self, t):
        """[B x N x (H*D)] -> [(B*H) x N x D], same layout as einops 'b n (h d) -> (b h) n d'."""
        b, n, width = t.shape
        d = width // self.heads
        return t.reshape(b, n, self.heads, d).permute(0, 2, 1, 3).reshape(b * self.heads, n, d)

    def _merge_heads(self, t):
        """[(B*H) x N x D] -> [B x N x (H*D)], the inverse of _split_heads."""
        bh, n, d = t.shape
        b = bh // self.heads
        return t.reshape(b, self.heads, n, d).permute(0, 2, 1, 3).reshape(b, n, self.heads * d)

    # x is the image patches and y is the cls tokens, for ex., or
    # x is the kv, y is the q
    def forward(self, x, y, softmax_axis=-1, use_skip=True):
        """
        :param x: [B x Nk x ch_kv] keys/values (extra leading axes allowed).
        :param y: [B x Nq x ch_q] queries, with the same leading axes as ``x``.
        :param softmax_axis: axis of the [B*H x Nq x Nk] scores to normalize over.
        :param use_skip: if True, add ``y`` to the projected attention output.
        :return: [B x Nq x ch_q] (or with ``x``'s extra leading axes restored).
        """
        if len(x.shape) > 3:
            # Fold extra leading axes into the batch, recurse, and unfold.
            # use_skip must be forwarded; it used to reset silently to True.
            folded = self(x.flatten(0, -3), y.flatten(0, -3), softmax_axis, use_skip)
            return folded.unflatten(0, x.shape[:-2])

        x_ln = self.ln_1(x)
        y_ln = self.ln_2(y)

        q = self._split_heads(self.to_q(y_ln))
        k, v = (self._split_heads(t) for t in self.to_kv(x_ln).chunk(2, dim=-1))

        # attention, what we cannot get enough of
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        attn = sim.softmax(dim=softmax_axis)

        out = self._merge_heads(einsum('b i j, b j d -> b i d', attn, v))

        out = self.proj(out)
        if use_skip:
            out = out + y
        out = self.out(self.ln_2(out)) + out

        return out

class QKVAttention(nn.Module):
    """
    QKV attention that splits q/k/v *before* splitting heads. The channel
    axis is laid out as [q | k | v][head][ch].
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        head_ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        # Scale q and k by ch**-0.25 each; more stable in float16 than
        # dividing the logits afterwards.
        scale = 1 / math.sqrt(math.sqrt(head_ch))
        grouped = (batch * self.n_heads, head_ch, length)
        logits = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(*grouped),
            (k * scale).view(*grouped),
        )
        weights = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        attended = th.einsum("bts,bcs->bct", weights, v.reshape(*grouped))
        return attended.reshape(batch, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)



class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param emb_dim: base dimension of timestep embedding.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
                               a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
                                    increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        time_emb_factor=4,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=3,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
    ):
        """
        Build the UNet backbone plus the scene-graph / camera heads.

        Arguments are described in the class docstring. Deviations visible in
        this constructor:

        * ``in_channels`` and ``out_channels`` are overwritten with 3 and 8.
        * ``time_emb_factor`` is not used; the embedding width is fixed at 256.
        * ``use_fp16`` only selects ``self.dtype`` here.
        * ``nn.Embedding(...).cuda()`` below needs a CUDA device at
          construction time.
        """
        super().__init__()

        # NOTE: overrides the constructor arguments of the same names.
        in_channels,out_channels=3,8 # hardcoded; todo refactor

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        # Fixed embedding width; time_emb_factor is ignored.
        time_embed_dim = 256#model_channels * time_emb_factor
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )
        self.frame_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # 6-d camera vector <-> 256-d embedding.
        #self.raymap_emb = linear(6, time_embed_dim)
        self.cam_emb = linear(6, time_embed_dim) 
        self.cam_decode = nn.Sequential(
            linear(time_embed_dim, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, 6),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        # Encoder: stem conv, then num_res_blocks ResBlocks per level (plus
        # attention where the downsample rate `ds` is in
        # attention_resolutions) and a downsample between levels.
        ch = input_ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
        )
        self._feature_size = ch
        input_block_chans = [ch]
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock( ch, time_embed_dim, dropout, out_channels=int(mult * model_channels), dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm,)
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock( ch, use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=num_head_channels, use_new_attention_order=use_new_attention_order,)
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock( ch, time_embed_dim, dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True)
                        if resblock_updown
                        else Downsample( ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlock( ch, time_embed_dim, dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm,),
            AttentionBlock( ch, use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=num_head_channels, use_new_attention_order=use_new_attention_order,),
            ResBlock( ch, time_embed_dim, dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm,),
        )
        self._feature_size += ch

        # Decoder: mirrors the encoder. Each ResBlock consumes the current
        # features concatenated with a skip (popped from input_block_chans).
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock( ch + ich, time_embed_dim, dropout, out_channels=int(model_channels * mult), dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm,)
                ]
                ch = int(model_channels * mult)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock( ch, use_checkpoint=use_checkpoint, num_heads=num_heads_upsample, num_head_channels=num_head_channels, use_new_attention_order=use_new_attention_order,)
                    )
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock( ch, time_embed_dim, dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, up=True)
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # After the last decoder level (mult = channel_mult[0]) ch == input_ch.
        self.out = nn.Sequential( normalization(ch), nn.SiLU(), zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),)

        # Scene-graph / camera heads: project 11-d scene-graph and 6-d camera
        # vectors into 256 channels and back out to twice their input width.
        self.cameras_proj_in = linear(6, 256)
        self.cameras_proj_out = linear(256,6*2)
        self.scene_graph_proj_in = linear(11, 256)
        self.scene_graph_proj_out = linear(256,11*2)
        self.scene_emb_proj = linear(256,256)
        #self.time_emb_proj = linear(256, 64)

        # Ten attention blocks of width ch*2. AttentionBlock stores
        # channels_out but does not use it, so their output width is ch*2.
        self.middle_attns = nn.ModuleList([
                        #nn.Sequential(
                            #SimpleResMLP(ch),
                            AttentionBlock( ch*2,channels_out=ch, use_checkpoint=use_checkpoint, num_heads=num_heads_upsample, num_head_channels=num_head_channels, use_new_attention_order=use_new_attention_order,)
                            #SimpleResMLP(ch),
                        #    )
                         for _ in range(10)])
        # One zero-initialized 128-channel linear and one 3-D conv per attention block.
        self.sg_attn_emb_projs_global = nn.ModuleList([zero_module(nn.Linear(128,128)) for _ in self.middle_attns])
        self.sg_attn_emb_projs_spatial = nn.ModuleList([zero_module( conv_nd(3, 128, 128, 3, padding=1)) for _ in self.middle_attns])
        self.img_to_emb = zero_module(nn.Linear(128,256))

        self.scene_info_embed = nn.Sequential(*([SimpleResMLP(512)]+[nn.Linear(512,256)]))
        self.bbox_embed = nn.Sequential(*([nn.Linear(6,256)]+[SimpleResMLP(256)]+[nn.Linear(256,256)]))
        self.clip_embed = nn.Sequential(*([SimpleResMLP(512)]+[nn.Linear(512,256)]))
        self.global_clip_emb = nn.Linear(256,256)
        self.global_bbox_emb= nn.Linear(256,256)

        # NOTE(review): .cuda() at construction makes the model unbuildable on
        # CPU-only machines.
        self.non_embedding = nn.Embedding(1, 256).cuda()

        # Autodecoder testing
        #self.scene_codes = nn.Embedding(100000, 256).cuda()
        #nn.init.normal_(self.scene_codes.weight, mean=0, std=0.01)

        # 1x1x1 conv reducing 192 channels to 128.
        self.img_tok_downproj = nn.Conv3d(192,128,1,1)

        # Diffusion process object from the project-local `diffusion` module.
        self.fwd_diffuser = diffusion.GaussianDiffusion()

    def denoise_scene_graph(self,sg_input,timesteps):
        """
        Denoise scene-graph and camera tokens with the ``middle_attns`` attention stack.

        :param sg_input: dict of tensors. Reads "noised_scene_graph" (projected by
            ``scene_graph_proj_in``) and "noised_cameras" (projected by ``cameras_proj_in``).
            If "img_tok" is present, its axes 2-4 are flattened into one token axis, projected
            by ``img_to_emb``, summed over tokens and added to the conditioning embedding.
        :param timesteps: a 1-D batch of diffusion timesteps.
        :return: ``(out_dict, (scene_graph_pred, cameras_pred))``. ``out_dict`` holds the first
            half of each prediction's last axis as "eps_scene_graph" / "eps_cameras" and the
            second half as "scene_graph" / "cameras".
        """
        scene_graph = self.scene_graph_proj_in(sg_input["noised_scene_graph"])
        camera_info = self.cameras_proj_in(sg_input["noised_cameras"])
        # Camera tokens are added onto the scene-graph tokens (by broadcasting; presumably a
        # single camera token — confirm against callers).
        org_scene_info = scene_info = scene_graph + camera_info
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))[:, None]

        if "img_tok" in sg_input:
            img_tok = sg_input["img_tok"]
            # Flatten axes 2..4 into a token axis, move it ahead of channels, project, then pool.
            emb = emb + self.img_to_emb(img_tok.flatten(2, 4).permute(0, 2, 1)).sum(1, keepdim=True)

        # The original input tokens are re-attached along the feature axis at every layer as a
        # pseudo skip connection. The leading zero row pads the slot taken by the prepended
        # ``emb`` token. This tensor does not change across layers, so it is built once.
        skip_tokens = th.cat((th.zeros_like(org_scene_info[:, :1]), org_scene_info), 1)

        # The per-layer image-conditioning path (sg_attn_emb_projs_*) was disabled with
        # ``if ... and 0`` and has been dropped; only the attention blocks run here.
        for attn in self.middle_attns:
            scene_info = th.cat((emb, scene_info), 1)
            scene_info = th.cat((scene_info, skip_tokens), -1)
            # Empty spatial input: attention runs over the emb + scene-graph tokens only.
            scene_info = attn(th.zeros(0, 0, 0, 0, 0).to(scene_info), scene_info)[1]
            scene_info = scene_info[:, 1:]  # drop the emb token again

        n_cams = sg_input["noised_cameras"].size(1)
        cameras_pred = self.cameras_proj_out(scene_info[:, -n_cams:])
        scene_graph_pred = self.scene_graph_proj_out(scene_info)

        # TODO: use the same input/output scene-graph projections (explicit scene graph / camera
        # bottleneck instead of an abstract latent) so clean inputs and predictions match.

        def _split_halves(pred):
            """Split *pred* along its last axis into (first half, second half)."""
            half = pred.size(-1) // 2
            return pred[..., :half], pred[..., half:]

        eps_sg, clean_sg = _split_halves(scene_graph_pred)
        eps_cam, clean_cam = _split_halves(cameras_pred)
        out_dict = {
            "eps_scene_graph": eps_sg,
            "scene_graph": clean_sg,
            "eps_cameras": eps_cam,
            "cameras": clean_cam,
        }
        return out_dict, (scene_graph_pred, cameras_pred)

    # Scene graph
    def forward(self, model_input, timesteps, y=None, autodecoder=False, use_skip=False, clip_global_latent=False, sample_rand_global=False, use_emb=False, teacher_forcing=True):
        """
        Denoise a batch of RGB frames conditioned on per-object bbox and CLIP embeddings.

        :param model_input: dict of tensors. Reads "noised_rgb", "bboxs" and "clip_embs".
            If "noised_rgb" is missing, zeros shaped like "rgb" are used and written back into
            ``model_input`` (side effect kept for compatibility with existing callers).
        :param timesteps: a 1-D batch of diffusion timesteps.
        :param y, autodecoder, clip_global_latent, sample_rand_global, use_emb, teacher_forcing:
            accepted for signature compatibility with the other forward variants; unused here.
        :param use_skip: if False, the encoder's outputs are discarded. This covers the
            bottleneck, skip activations and updated scene tokens. The decoder is then
            conditioned only on the embeddings.
        :return: dict with "eps_rgb" and "rgb" (tanh). 5-D tensors are permuted with
            (0, 2, 1, 3, 4) to undo the input permutation.
        """
        out_dict = {}

        # Development toggles; both halves of the UNet are active.
        use_encoder, use_decoder = True, True

        if "noised_rgb" not in model_input:
            model_input["noised_rgb"] = torch.zeros_like(model_input["rgb"])

        if use_encoder or use_decoder:
            # Swap axes 1 and 2; presumably (B, T, C, H, W) -> (B, C, T, H, W) — confirm.
            x = model_input["noised_rgb"].permute(0, 2, 1, 3, 4)
        if not use_encoder and use_decoder:
            x = th.zeros_like(x)

        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))[:, None]

        # Per-object scene tokens. The bbox tensor's last two axes are flattened to feed
        # ``bbox_embed``.
        bbox_embs = self.bbox_embed(model_input["bboxs"].flatten(-2, -1))
        clip_embs = self.clip_embed(model_input["clip_embs"])
        org_scene_info = scene_info = bbox_embs + clip_embs

        # Global conditioning: per-modality projections summed over objects.
        # Fixed: the CLIP projection previously received ``bbox_embs``. Checkpoints trained with
        # the old wiring will see a different global embedding.
        emb_global = (self.global_bbox_emb(bbox_embs) + self.global_clip_emb(clip_embs)).sum(1, keepdim=True)
        emb = emb + emb_global

        if use_encoder:
            h = x.type(self.dtype)
            for module in self.input_blocks:
                h, new_scene_info = module(h, emb, th.cat((scene_info, org_scene_info), -1))
                if use_skip:
                    scene_info = new_scene_info
                hs.append(h)
            h, new_scene_info = self.middle_block(h, emb, th.cat((scene_info, org_scene_info), -1))
            if use_skip:
                scene_info = new_scene_info
            else:
                h = th.zeros_like(h)

        if use_decoder:
            for module in self.output_blocks:
                hp = hs.pop()
                # Align axis 2 with the skip activation when their sizes disagree.
                if h.size(2) != hp.size(2):
                    h = h[:, :, :hp.size(2)]
                if not use_skip:
                    hp = th.zeros_like(hp)
                h = th.cat([h, hp], dim=1)
                # A block may hand back an empty scene-token tensor; fall back to the input tokens.
                if scene_info.size(-1) == 0:
                    scene_info = org_scene_info
                h, scene_info = module(h, emb, th.cat((scene_info, org_scene_info), -1))
            h = h.type(x.dtype)
            out = self.out(h)

            out_dict |= {
                "eps_rgb": out[:, [0, 1, 2]],
                "rgb": out[:, [3, 4, 5]].tanh(),
            }
        return {k: (v if len(v.shape) != 5 else v.permute(0, 2, 1, 3, 4)) for k, v in out_dict.items()}
    def forward_(self, model_input, timesteps, y=None, autodecoder=False, use_skip=False, clip_global_latent=False, sample_rand_global=False, use_emb=False, teacher_forcing=True):
        """
        Experimental pass: sample a scene graph from encoder features, then decode RGB from it.

        NOTE(review): a second ``def forward_`` appears later in this class body and rebinds the
        name, so this implementation is not reachable as ``self.forward_``. Kept for reference.

        Flow:
        1. Encode ``model_input["noised_rgb"]`` with the UNet encoder.
        2. Run ``self.fwd_diffuser.diff_sample_from_reverse_process``, conditioned on
           down-projected encoder tokens, to produce scene-graph / camera predictions.
        3. Project those predictions into a latent that is added to the time embedding.
        4. Decode with the UNet decoder.

        Returns the sampler outputs merged with "eps_rgb", "rgb", "seg" and "invdepth".
        5-D tensors are permuted with (0, 2, 1, 3, 4).
        """
        # Overrides the ``use_emb`` argument; the value is not read anywhere below.
        use_emb=True
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        out_dict = {}

        # Development toggles; both are hard-wired on, so the `not use_encoder` branch is dead.
        use_encoder,use_decoder=True,True#True,True

        #if use_encoder or use_decoder:
        #x=model_input["noised_rgb"] if "noised_rgb" in model_input else th.zeros(1,1,1,1,1) # replace with zeros of shape

        # # NOTE!!! using clean rgb for conditional scene graph testing  (stale: the code reads noised_rgb)
        if use_encoder or use_decoder:
            x=model_input["noised_rgb"]
            # Swap axes 1 and 2; presumably (B, T, C, H, W) -> (B, C, T, H, W) — confirm.
            x=x.permute(0,2,1,3,4)

        if not use_encoder and use_decoder:
            x=th.zeros_like(x)

        hs = []
        #emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))[:,None,None].expand(-1,model_input["noised_rgb"].size(1),-1).flatten(0,1)
        #emb = emb_time = self.time_embed(timestep_embedding(timesteps, self.model_channels))[:,None].expand(-1,x.size(2),-1)#[:,:,None,None].permute(0,4,1,2,3) 
        emb = emb_time = self.time_embed(timestep_embedding(timesteps, self.model_channels))[:,None]#.expand(-1,x.size(2),-1)#[:,:,None,None].permute(0,4,1,2,3) 
        #emb_frame =  self.frame_embed(timestep_embedding(th.arange(x.size(2)).cuda(), self.model_channels))[None]
        #emb = emb_time + emb_frame
        # add frame emb?
        #emb = emb_time = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        #scene_info = th.cat((self.time_emb_proj(emb_time[:,None]),self.scene_graph_proj_in(model_input["noised_scene_graph"])),1)
        #scene_info = self.scene_graph_proj_in(model_input["noised_scene_graph"])
        
        # TODO make tiny skip/residual mlp class to inverleave with attentions
        #for attn in self.middle_attns: scene_info=attn(scene_info)
        #scene_info = self.scene_graph_proj_out( scene_info[:,1:] )

        #out_dict = {"eps_scene_graph":scene_info[...,:scene_info.size(-1)//2],"scene_graph":scene_info[...,scene_info.size(-1)//2:]}
        #return out_dict 

        # Encoder: blocks are called without scene tokens; only the first returned element is kept.
        if use_encoder:
            h = x.type(self.dtype)
            for module in self.input_blocks:
                #h,scene_info_ = module(h, emb, th.cat((scene_info, org_scene_info),-1))
                h = module(h, emb, None)[0]#th.cat((scene_info, org_scene_info),-1))
                #if use_skip: scene_info=scene_info_
                hs.append(h)
            #h,scene_info_ = self.middle_block(h, emb, th.cat((scene_info, org_scene_info),-1))
            h= self.middle_block(h, emb,None)[0]# th.cat((scene_info, org_scene_info),-1))
            #if use_skip: scene_info=scene_info_
            if not use_skip: h=th.zeros_like(h) 

        # Middle scene graph attns

        #clean_sanity=False
        #cam_in = model_input["noised_cameras" if not clean_sanity else "cameras"]# if "noised_cameras" in model_input else th.zeros(len(x),1,6).to(x)
        #scene_in = model_input["noised_scene_graph" if not clean_sanity else "scene_graph"]# if "scene_graph" in model_input else th.zeros(len(x),6,11).to(x)
        #scene_graph = self.scene_graph_proj_in(scene_in)
        #camera_info = self.cameras_proj_in(cam_in)
        #org_scene_info = scene_info = scene_graph + camera_info  # = th.cat((scene_graph,camera_info),1)
        #if not use_emb: scene_info = org_scene_info = self.non_embedding(th.tensor([0]).cuda())[None].expand(len(emb),-1,-1) 

        # Conditioning tokens: the 4th-from-last encoder activation passed through
        # img_tok_downproj (a 1x1x1 Conv3d, 192 -> 128 channels). Assumes that activation
        # has 192 channels — confirm.
        if use_encoder: img_tok_cond = self.img_tok_downproj(hs[-4]) 

        # Forward-diffuse input scene graph
        # `if 1:` is always taken, so the `elif 1:` branch below is unreachable.
        if 1:
            #xT = {"rgb":model_input["rgb"],"scene_graph":model_input["scene_graph"],"cameras":model_input["cameras"]}
            # Starting state for the sampler: every "noised*" entry of model_input.
            xT = {k:v for k,v in model_input.items() if "noised" in k}
            # No noised scene graph supplied: draw Gaussian noise with hard-coded shapes
            # (6 x 11 scene graph, 1 x 6 camera) on the default CUDA device.
            if "noised_scene_graph" not in xT:
                xT |= {"noised_scene_graph":torch.randn(len(x), 6,11).float().cuda(),"noised_cameras":torch.randn(len(x), 1,6).float().cuda()}
            #xT = {"noised_"+k: torch.randn(len(x), *v.shape[1:]) .float() .to(v.device) for k,v in xT.items()}
            # The literal 5 is presumably the number of reverse-process steps — confirm in diffusion.GaussianDiffusion.
            gen_sg,gen_sg_intermeds,latent_eps_loss, (scene_in,cam_in)= self.fwd_diffuser.diff_sample_from_reverse_process(self, xT, 5, {"y": None}, 
                        False,scene_graph=True,cameras=xT["noised_cameras"],conditioning={"img_tok":img_tok_cond} if use_encoder and 1 else {},
                        use_direct=False, model_input=model_input,teacher_forcing=teacher_forcing)
            out_dict |= gen_sg | {k+"_intermed":v for k,v in gen_sg_intermeds.items()} | latent_eps_loss
            # Keep the second half of the last axis (the non-eps prediction).
            scene_in,cam_in=[x[...,x.size(-1)//2:] for x in (scene_in,cam_in)]
        elif 1:
            denoised_sg, (scene_in,cam_in) = self.denoise_scene_graph(model_input | ({"img_tok":img_tok_cond} if use_encoder and 1 else {}),timesteps)
            scene_in,cam_in=[x[...,x.size(-1)//2:] for x in (scene_in,cam_in)]
            out_dict |= denoised_sg

        # Overwrites the values computed above with the sampler's "scene_graph"/"cameras" outputs.
        # Relies on the `if 1:` branch having run (gen_sg is undefined otherwise).
        scene_in,cam_in=gen_sg["scene_graph"],gen_sg["cameras"]

        # Debug switch: feed ground-truth scene graph / cameras instead of sampled ones.
        clean_sanity=False
        if clean_sanity and "scene_graph" in model_input:
            print("using sg as input")
            cam_in = model_input["cameras"]
            scene_in = model_input["scene_graph"]
        scene_graph = self.scene_graph_proj_in(scene_in)
        camera_info = self.cameras_proj_in(cam_in)
        scene_latent = org_scene_info = scene_info = scene_graph + camera_info  # = th.cat((scene_graph,camera_info),1)

        #if use_encoder: scene_info = scene_info + self.sg_attn_emb_projs_spatial[0](img_tok_cond).sum(dim=[2,3,4])[:,None]

        # create diffusion scene info latents

        #for attn, global_img_emb_proj, local_img_emb_proj in zip(self.middle_attns,self.sg_attn_emb_projs_global,self.sg_attn_emb_projs_spatial): 
        #    #if use_encoder: scene_info = scene_info + global_img_emb_proj(img_tok_cond.sum(dim=[2,3,4])[:,None])
        #    scene_info = th.cat((emb,scene_info),1) # add emb to list of scene info; note we should redo this with the same film-style embedding conditioning instead of concatenating emb. TODO
        #    scene_info = th.cat((scene_info, th.cat((th.zeros_like(org_scene_info[:,:1]),org_scene_info),1) ),-1) # concat scene graph input at each layer as pseudo-skip skip connect
        #    if use_encoder and 1:
        #        #scene_info = attn(local_img_emb_proj(img_tok_cond),scene_info)[1]
        #        img_tok_cond,scene_info = attn(local_img_emb_proj(img_tok_cond),scene_info)
        #    else:
        #        scene_info = attn(th.zeros(0,0,0,0,0).to(scene_info),scene_info)[1]  # dummy input for spatial tokens (only doing attention with registers/scenegraph here) 
        #    scene_info = scene_info[:,1:] # remove emb from list of scene info

        #if use_cond:
        #    scene_emb = self.scene_info_embed(th.cat((scene_info, org_scene_info),-1)).sum(1,keepdim=True)
        #    emb = emb + scene_emb

        ## Return bottleneck prediction if used
        #if "noised_cameras" in model_input:
        #    cameras_pred = self.cameras_proj_out( scene_info[:,-model_input["noised_cameras"].size(1):] )
        #    #scene_graph_pred = self.scene_graph_proj_out( scene_info[:,:-model_input["noised_cameras"].size(1)] )
        #    scene_graph_pred = self.scene_graph_proj_out( scene_info )

        #    # Export scene graph-specific predictions
        #    out_dict |= {"eps_scene_graph":scene_graph_pred[...,:scene_graph_pred.size(-1)//2],"scene_graph":scene_graph_pred[...,scene_graph_pred.size(-1)//2:]}
        #    out_dict |= {"eps_cameras":cameras_pred[...,:cameras_pred.size(-1)//2],"cameras":cameras_pred[...,cameras_pred.size(-1)//2:]}
        #    #return out_dict

        #    # NOTE testing with explicit scene graph bottleneck here
        #    #scene_graph_pred = self.scene_graph_proj_in(out_dict["scene_graph"])
        #    #camera_info_pred = self.cameras_proj_in(out_dict["cameras"])
        #    #scene_info = th.cat((scene_graph_pred,camera_info_pred),1)
        #else: scene_info = scene_info

        # Redundant re-assignment (same value as set above).
        scene_info=org_scene_info=scene_latent

        # The projected scene latent is added to the time embedding (broadcast over tokens).
        emb = emb + self.scene_emb_proj(scene_latent)

        if use_decoder:

            for module in self.output_blocks:
                hp = hs.pop()
                # Align axis 2 with the skip activation when their sizes disagree.
                if h.size(2)!=hp.size(2): h=h[:,:,:hp.size(2)]
                if not use_skip: hp=th.zeros_like(hp)
                h = th.cat([h, hp], dim=1)
                if scene_info.size(-1)==0:scene_info=org_scene_info
                h,scene_info = module(h, emb,th.cat((scene_info, org_scene_info),-1))
            h = h.type(x.dtype)
            out = self.out(h)

            # format (todo factorize)
            # NOTE(review): indexes channels 0-7, so assumes self.out emits >= 8 channels — confirm.
            out_dict |= {
                    "eps_rgb": out[:,[0,1,2]],
                    "rgb": out[:,[3,4,5]].tanh(),
                    "seg": out[:,[6]].sigmoid(),
                    "invdepth": out[:,[7]].sigmoid(),
            }
        return {k:(v if len(v.shape)!=5 else v.permute(0,2,1,3,4)) for k,v in out_dict.items()}

    def forward_(self, model_input, timesteps, y=None, autodecoder=False, use_skip=False, clip_global_latent=False, sample_rand_global=False):
        """
        Legacy global-latent autoencoder pass (no scene-graph conditioning).

        Flow:
        1. Encode the noised frames.
        2. Mean-pool the bottleneck into ``global_latent`` and add it to the time + frame
           embedding.
        3. Decode. ``global_latent`` is also decoded into camera predictions.

        This definition rebinds ``forward_``, replacing the earlier one in the same class body.

        :param model_input: dict of tensors. Reads "noised_rgb" and, if present, "cameras"
            (4x4 pose matrices; identity poses are used when absent). Side effects: writes
            "cameras" when it was missing, and always writes "cams" (translation +
            axis-angle rotation).
        :param timesteps: a 1-D batch of diffusion timesteps.
        :param y: class labels; must be given iff ``self.num_classes`` is not None.
        :param autodecoder: if True, the image input is replaced by zeros.
        :param use_skip: if False, skip connections and the bottleneck activation are zeroed, so
            the decoder sees the input only through ``global_latent``.
        :param clip_global_latent: clamp ``global_latent`` to [-0.1, 0.1].
        :param sample_rand_global: replace ``global_latent`` with uniform noise in [-0.1, 0.1).
        :return: dict with "eps", "rgb" (tanh), "seg" and "invdepth" (sigmoid),
            "global_latent" and "cams". 5-D tensors are permuted with (0, 2, 1, 3, 4).
        """
        # Swap axes 1 and 2; presumably (B, T, C, H, W) -> (B, C, T, H, W) — confirm.
        x = model_input["noised_rgb"].permute(0, 2, 1, 3, 4)
        if autodecoder:
            x = th.zeros_like(x)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        if "cameras" not in model_input:
            # Identity poses on the input's device (was hard-coded to the default CUDA device).
            model_input["cameras"] = th.eye(4, device=model_input["noised_rgb"].device)[None, None].expand(
                *model_input["noised_rgb"].shape[:2], -1, -1
            )

        hs = []
        emb_time = self.time_embed(timestep_embedding(timesteps, self.model_channels))[:, None].expand(-1, x.size(2), -1)

        # Camera parameters (translation column + axis-angle rotation). They are stored on
        # model_input["cams"] and are not fed to the network here.
        cameras = model_input["cameras"]
        model_input["cams"] = th.cat(
            (
                cameras[..., :3, -1],
                kornia.geometry.conversions.rotation_matrix_to_axis_angle(cameras[..., :3, :3]),
            ),
            -1,
        )

        # Frame-index embedding, reusing the sinusoidal timestep embedding on 0..T-1.
        frame_idx = th.arange(x.size(2), device=x.device)
        emb_frame = self.frame_embed(timestep_embedding(frame_idx, self.model_channels))[None]
        emb = emb_time + emb_frame

        # NOTE(review): the other forward variants call these blocks with a third scene-token
        # argument and unpack a tuple; confirm this two-argument form still matches.
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)

        # Global bottleneck: average of the middle activation over axes 2 onward.
        global_latent = h.flatten(2, -1).mean(dim=-1)[:, None]
        if not use_skip:
            h = th.zeros_like(h)

        if sample_rand_global:
            global_latent = (th.rand_like(global_latent) * 2 - 1) / 10
        if clip_global_latent:
            global_latent = global_latent.clip(-.1, .1)

        emb = emb + global_latent

        for module in self.output_blocks:
            hp = hs.pop()
            # Align axis 2 with the skip activation when their sizes disagree.
            if h.size(2) != hp.size(2):
                h = h[:, :, :hp.size(2)]
            if not use_skip:
                hp = th.zeros_like(hp)
            h = th.cat([h, hp], dim=1)
            h = module(h, emb)
        h = h.type(x.dtype)
        out = self.out(h)

        # Camera decoding from the global latent plus the per-frame embedding.
        cam_pred = self.cam_decode(global_latent + emb_frame)

        out_dict = {
            "eps": out[:, [0, 1, 2]],
            "rgb": out[:, [3, 4, 5]].tanh(),
            "seg": out[:, [6]].sigmoid(),
            "invdepth": out[:, [7]].sigmoid(),
            "global_latent": global_latent,
            "cams": cam_pred,
        }
        return {k: (v if len(v.shape) != 5 else v.permute(0, 2, 1, 3, 4)) for k, v in out_dict.items()}

    def forward_full(self, model_input, timesteps, y=None):
        """
        Legacy plain video-UNet pass conditioned on timestep and camera pose.

        :param model_input: dict of tensors. Reads "noised_rgb" and, if present, "cameras"
            (4x4 pose matrices). When "cameras" is missing, identity poses are written into
            ``model_input`` (side effect kept for compatibility).
        :param timesteps: a 1-D batch of diffusion timesteps.
        :param y: class labels; must be given iff ``self.num_classes`` is not None.
        :return: dict with "eps", "rgb" (tanh), "seg" and "invdepth" (sigmoid). Each is
            permuted with (0, 2, 1, 3, 4).
        """
        # Swap axes 1 and 2; presumably (B, T, C, H, W) -> (B, C, T, H, W) — confirm.
        x = model_input["noised_rgb"].permute(0, 2, 1, 3, 4)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        hs = []
        # Time embedding repeated for every position along axis 2 of x.
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))[:, None].expand(-1, x.size(2), -1)
        if "cameras" not in model_input:
            # Identity poses on the input's device (was hard-coded to the default CUDA device).
            model_input["cameras"] = th.eye(4, device=model_input["noised_rgb"].device)[None, None].expand(
                *model_input["noised_rgb"].shape[:2], -1, -1
            )
        cameras = model_input["cameras"]
        # Pose embedding from translation (last column) and axis-angle rotation.
        pose = th.cat(
            (
                cameras[..., :3, -1],
                kornia.geometry.conversions.rotation_matrix_to_axis_angle(cameras[..., :3, :3]),
            ),
            -1,
        )
        emb = emb + self.cam_emb(pose)

        # NOTE(review): the other forward variants call these blocks with a third scene-token
        # argument and unpack a tuple; confirm this two-argument form still matches.
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            hp = hs.pop()
            # Align axis 2 with the skip activation when their sizes disagree.
            if h.size(2) != hp.size(2):
                h = h[:, :, :hp.size(2)]
            h = th.cat([h, hp], dim=1)
            h = module(h, emb)
        h = h.type(x.dtype)
        out = self.out(h)

        out_dict = {
            "eps": out[:, [0, 1, 2]],
            "rgb": out[:, [3, 4, 5]].tanh(),
            "seg": out[:, [6]].sigmoid(),
            "invdepth": out[:, [7]].sigmoid(),
        }
        return {k: v.permute(0, 2, 1, 3, 4) for k, v in out_dict.items()}

def UNetBig(
    image_size,
    in_channels=3,
    out_channels=3,
    base_width=192,
    num_classes=None,
):
    """
    Build the large UNet preset: 3 residual blocks per level, base width 192.

    :param image_size: input resolution; one of 128, 64, 32 or 28.
    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param base_width: ``model_channels`` of the first level.
    :param num_classes: optional class count for class conditioning.
    :raises ValueError: if ``image_size`` is not a supported resolution.
    """
    mults_by_size = (
        (128, (1, 1, 2, 3, 4)),
        (64, (1, 2, 3, 4)),
        (32, (1, 2, 2, 2)),
        (28, (1, 2, 2, 2)),
    )
    channel_mult = next((mult for size, mult in mults_by_size if image_size == size), None)
    if channel_mult is None:
        raise ValueError(f"unsupported image size: {image_size}")

    # Attention runs at the downsampling factors image_size // res for each listed res.
    res_spec = "28,14,7" if image_size == 28 else "32,16,8"
    attention_ds = tuple(image_size // int(res) for res in res_spec.split(","))

    return UNetModel(
        image_size=image_size,
        in_channels=in_channels,
        out_channels=out_channels,
        num_res_blocks=3,
        model_channels=base_width,
        attention_resolutions=attention_ds,
        dropout=0.1,
        channel_mult=channel_mult,
        num_classes=num_classes,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=4,
        num_head_channels=64,
        num_heads_upsample=-1,
        use_scale_shift_norm=True,
        resblock_updown=True,
        use_new_attention_order=True,
    )


def UNet(
    image_size,
    in_channels=3,
    out_channels=3,
    base_width=64,
    num_classes=None,
):
    """
    Build the default UNet preset: 2 residual blocks per level, base width 64.

    :param image_size: input resolution; one of 128, 64, 32 or 28.
    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param base_width: ``model_channels`` of the first level.
    :param num_classes: optional class count for class conditioning.
    :raises ValueError: if ``image_size`` is not a supported resolution.
    """
    mults_by_size = (
        (128, (1, 1, 2, 3, 4)),
        (64, (1, 2, 3, 4)),
        (32, (1, 2, 2, 2)),
        (28, (1, 2, 2, 2)),
    )
    channel_mult = next((mult for size, mult in mults_by_size if image_size == size), None)
    if channel_mult is None:
        raise ValueError(f"unsupported image size: {image_size}")

    # Attention runs at the downsampling factors image_size // res for each listed res.
    res_spec = "28,14,7" if image_size == 28 else "32,16,8"
    attention_ds = tuple(image_size // int(res) for res in res_spec.split(","))

    return UNetModel(
        image_size=image_size,
        in_channels=in_channels,
        model_channels=base_width,
        out_channels=out_channels,
        num_res_blocks=2,
        attention_resolutions=attention_ds,
        dropout=0.1,
        channel_mult=channel_mult,
        num_classes=num_classes,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=4,
        num_head_channels=64,
        num_heads_upsample=-1,
        use_scale_shift_norm=True,
        resblock_updown=True,
        use_new_attention_order=True,
    )


def UNetSmall(
    image_size,
    in_channels=3,
    out_channels=3,
    base_width=32,
    num_classes=None,
):
    """
    Build the small UNet preset: 2 residual blocks per level, base width 32.

    Compared with ``UNet``, it uses 32-channel attention heads and ``time_emb_factor=2``.

    :param image_size: input resolution; one of 128, 64, 32 or 28.
    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param base_width: ``model_channels`` of the first level.
    :param num_classes: optional class count for class conditioning.
    :raises ValueError: if ``image_size`` is not a supported resolution.
    """
    mults_by_size = (
        (128, (1, 1, 2, 3, 4)),
        (64, (1, 2, 3, 4)),
        (32, (1, 2, 2, 2)),
        (28, (1, 2, 2, 2)),
    )
    channel_mult = next((mult for size, mult in mults_by_size if image_size == size), None)
    if channel_mult is None:
        raise ValueError(f"unsupported image size: {image_size}")

    # Attention runs at the downsampling factors image_size // res for each listed res.
    res_spec = "28,14,7" if image_size == 28 else "32,16,8"
    attention_ds = tuple(image_size // int(res) for res in res_spec.split(","))

    return UNetModel(
        image_size=image_size,
        in_channels=in_channels,
        model_channels=base_width,
        out_channels=out_channels,
        num_res_blocks=2,
        attention_resolutions=attention_ds,
        time_emb_factor=2,
        dropout=0.1,
        channel_mult=channel_mult,
        num_classes=num_classes,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=4,
        num_head_channels=32,
        num_heads_upsample=-1,
        use_scale_shift_norm=True,
        resblock_updown=True,
        use_new_attention_order=True,
    )
