# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import pdb

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
from timm.models.vision_transformer import Attention, Mlp


class ResidualBlock(nn.Module):
    """Residual block of two 3x3 convolutions, each followed by a norm and ReLU.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        norm_fn: normalization type, one of ``"group"``, ``"batch"``,
            ``"instance"`` or ``"none"`` (identity).
        stride: stride of the first 3x3 convolution.

    When ``stride != 1`` or ``in_planes != planes`` the skip connection is
    projected by a 1x1 convolution (same stride) followed by a norm layer so
    that it can be added to the residual branch.

    Raises:
        ValueError: if ``norm_fn`` is not one of the supported values.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()
        if norm_fn not in ("group", "batch", "instance", "none"):
            raise ValueError(
                f"Unsupported norm_fn {norm_fn!r}; "
                "expected 'group', 'batch', 'instance' or 'none'"
            )

        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            padding=1,
            stride=stride,
            padding_mode="zeros",
        )
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
        )
        self.relu = nn.ReLU(inplace=True)

        # The skip path must match the residual branch in both spatial size
        # (stride) and channel count; otherwise `x + y` in forward() fails.
        needs_projection = stride != 1 or in_planes != planes

        self.norm1 = self._make_norm(norm_fn, planes)
        self.norm2 = self._make_norm(norm_fn, planes)
        if needs_projection:
            self.norm3 = self._make_norm(norm_fn, planes)
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )
        else:
            self.downsample = None

    @staticmethod
    def _make_norm(norm_fn, planes):
        """Build one normalization layer of type ``norm_fn`` over ``planes`` channels."""
        if norm_fn == "group":
            # One group per 8 channels, but never zero groups for narrow layers.
            return nn.GroupNorm(num_groups=max(planes // 8, 1), num_channels=planes)
        if norm_fn == "batch":
            return nn.BatchNorm2d(planes)
        if norm_fn == "instance":
            return nn.InstanceNorm2d(planes)
        return nn.Sequential()  # "none": identity

    def forward(self, x):
        """Apply conv-norm-ReLU twice and add the (optionally projected) input."""
        y = self.relu(self.norm1(self.conv1(x)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)


class BasicEncoder(nn.Module):
    """Convolutional encoder that fuses multi-stage features into one map.

    A 7x7 stride-2 stem is followed by ``ResidualBlock`` stages. Every stage
    output is bilinearly resized to ``(H // stride, W // stride)`` (H, W being
    the input size), concatenated along channels and fused by convolutions
    into ``output_dim`` channels.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels.
        stride: output downsampling factor relative to the input size.
        norm_fn: ``"group"``, ``"batch"``, ``"instance"`` or ``"none"``.
        dropout: if > 0, ``nn.Dropout2d`` applied to the output in training mode.

    Raises:
        ValueError: if ``norm_fn`` is not one of the supported values.
    """

    def __init__(
        self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0
    ):
        super(BasicEncoder, self).__init__()
        self.stride = stride
        self.norm_fn = norm_fn
        self.in_planes = 64

        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
            self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)

        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(self.in_planes)
            self.norm2 = nn.BatchNorm2d(output_dim * 2)

        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(self.in_planes)
            self.norm2 = nn.InstanceNorm2d(output_dim * 2)

        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()
            # forward() always applies norm2 in the fusion head, so it must
            # exist (as an identity) for "none" as well.
            self.norm2 = nn.Sequential()

        else:
            raise ValueError(
                f"Unsupported norm_fn {norm_fn!r}; "
                "expected 'group', 'batch', 'instance' or 'none'"
            )

        self.conv1 = nn.Conv2d(
            input_dim,
            self.in_planes,
            kernel_size=7,
            stride=2,
            padding=3,
            padding_mode="zeros",
        )
        self.relu1 = nn.ReLU(inplace=True)

        # Hard-coded switch; the shallow (3-stage) variant is kept for reference.
        self.shallow = False
        if self.shallow:
            self.layer1 = self._make_layer(64, stride=1)
            self.layer2 = self._make_layer(96, stride=2)
            self.layer3 = self._make_layer(128, stride=2)
            self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1)
        else:
            self.layer1 = self._make_layer(64, stride=1)
            self.layer2 = self._make_layer(96, stride=2)
            self.layer3 = self._make_layer(128, stride=2)
            self.layer4 = self._make_layer(128, stride=2)

            # Input channels = concatenation of all four stage outputs.
            self.conv2 = nn.Conv2d(
                128 + 128 + 96 + 64,
                output_dim * 2,
                kernel_size=3,
                padding=1,
                padding_mode="zeros",
            )
            self.relu2 = nn.ReLU(inplace=True)
            self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Two ResidualBlocks (the first may downsample); updates ``self.in_planes``."""
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def _resize_to_output(self, feats, H, W):
        """Bilinearly resize every map in ``feats`` to ``(H // stride, W // stride)``."""
        size = (H // self.stride, W // self.stride)
        return [
            F.interpolate(f, size, mode="bilinear", align_corners=True) for f in feats
        ]

    def forward(self, x):
        """Encode ``x`` (B, input_dim, H, W) into (B, output_dim, H // stride, W // stride)."""
        _, _, H, W = x.shape

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        a = self.layer1(x)
        b = self.layer2(a)
        c = self.layer3(b)
        if self.shallow:
            x = self.conv2(torch.cat(self._resize_to_output([a, b, c], H, W), dim=1))
        else:
            d = self.layer4(c)
            x = self.conv2(
                torch.cat(self._resize_to_output([a, b, c, d], H, W), dim=1)
            )
            x = self.norm2(x)
            x = self.relu2(x)
            x = self.conv3(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)
        return x


class AttnBlock(nn.Module):
    """Pre-norm transformer block: self-attention, then an MLP, both residual.

    Each sub-layer receives a LayerNorm'ed input (no learnable affine
    parameters, eps=1e-6) and its output is added back to the stream. The
    block takes no conditioning input.

    Args:
        hidden_size: token embedding width.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
        **block_kwargs: forwarded unchanged to timm ``Attention``.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()

        def tanh_gelu():
            # Passed as a factory (callable returning a module), not a module.
            return nn.GELU(approximate="tanh")

        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(
            hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
        )

        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=tanh_gelu,
            drop=0,
        )

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attn_update = self.attn(self.norm1(x))
        x = x + attn_update
        mlp_update = self.mlp(self.norm2(x))
        return x + mlp_update


def bilinear_sampler(img, coords, mode="bilinear", mask=False):
    """Wrapper for ``F.grid_sample`` that takes pixel coordinates.

    Args:
        img: image tensor; its last two dims are (H, W).
        coords: sampling locations whose last dim holds pixel coordinates
            ordered (x, y), with x in [0, W - 1] and y in [0, H - 1].
        mode: interpolation mode forwarded to ``F.grid_sample``
            (e.g. ``"bilinear"`` or ``"nearest"``).
        mask: if True, also return a float mask that is 1 where the point lies
            strictly inside the image extent and 0 otherwise.

    Returns:
        The sampled tensor, or ``(sampled, mask)`` when ``mask`` is True.
    """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # Map pixel coordinates [0, size - 1] to [-1, 1] (align_corners=True).
    # For a 1-pixel axis the divisor would be 0 (NaN/inf grid); every finite
    # normalized value addresses that single pixel, so divide by 1 instead.
    xgrid = 2 * xgrid / max(W - 1, 1) - 1
    ygrid = 2 * ygrid / max(H - 1, 1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    img = F.grid_sample(img, grid, mode=mode, align_corners=True)

    if mask:
        valid = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, valid.float()

    return img


class CorrBlock:
    """Multi-scale correlation lookup between track features and feature maps.

    ``fmaps`` of shape (B, S, C, H, W) is average-pooled by a factor of 2,
    ``num_levels - 1`` times, to build a pyramid. :meth:`corr` correlates
    per-frame target features with every pyramid level; :meth:`sample` then
    reads a ``(2 * radius + 1) ** 2`` window of correlations around each
    coordinate at every level.
    """

    def __init__(self, fmaps, num_levels=4, radius=4):
        B, S, C, H, W = fmaps.shape
        self.S, self.C, self.H, self.W = S, C, H, W

        self.num_levels = num_levels
        self.radius = radius
        self.fmaps_pyramid = [fmaps]
        for _ in range(self.num_levels - 1):
            pooled = F.avg_pool2d(fmaps.reshape(B * S, C, H, W), 2, stride=2)
            _, _, H, W = pooled.shape
            fmaps = pooled.reshape(B, S, C, H, W)
            self.fmaps_pyramid.append(fmaps)

    def sample(self, coords):
        """Read correlation windows around ``coords`` at every pyramid level.

        Args:
            coords: (B, S, N, 2) full-resolution pixel coordinates; they are
                divided by ``2 ** level`` for each coarser level.

        Returns:
            float32 tensor of shape
            (B, S, N, num_levels * (2 * radius + 1) ** 2).

        Raises:
            RuntimeError: if :meth:`corr` has not been called yet.
        """
        if not hasattr(self, "corrs_pyramid"):
            raise RuntimeError(
                "CorrBlock.corr(targets) must be called before sample(coords)"
            )
        r = self.radius
        B, S, N, D = coords.shape
        assert D == 2
        K = 2 * r + 1

        # The window offsets are the same for every level: build them once.
        dx = torch.linspace(-r, r, K)
        dy = torch.linspace(-r, r, K)
        delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1)
        delta_lvl = delta.to(coords.device).view(1, K, K, 2)

        out_pyramid = []
        for i in range(self.num_levels):
            corrs = self.corrs_pyramid[i]  # B, S, N, H_i, W_i
            _, _, _, H, W = corrs.shape

            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
            coords_lvl = centroid_lvl + delta_lvl

            corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)
            out_pyramid.append(corrs.view(B, S, N, -1))

        out = torch.cat(out_pyramid, dim=-1)  # B, S, N, LRR*2
        return out.contiguous().float()

    def corr(self, targets):
        """Correlate ``targets`` (B, S, N, C) with every pyramid level.

        C and S must match the feature maps given to the constructor. Stores
        ``self.corrs_pyramid``: one (B, S, N, H_i, W_i) tensor per level with
        dot products scaled by ``1 / sqrt(C)``.
        """
        B, S, N, C = targets.shape
        assert C == self.C
        assert S == self.S

        scale = torch.sqrt(torch.tensor(C).float())
        self.corrs_pyramid = []
        for fmaps in self.fmaps_pyramid:
            _, _, _, H, W = fmaps.shape
            # reshape, not view: the level-0 maps are the caller's tensor and
            # may be non-contiguous.
            fmap2s = fmaps.reshape(B, S, C, H * W)
            corrs = torch.matmul(targets, fmap2s).view(B, S, N, H, W)
            self.corrs_pyramid.append(corrs / scale)


class FeatBlock:
    """Multi-scale feature lookup around track coordinates.

    ``fmaps`` of shape (B, S, C, H, W) is average-pooled by a factor of 2,
    ``num_levels - 1`` times, to build a pyramid. :meth:`sample` returns the
    raw feature vectors in a ``(2 * radius + 1) ** 2`` window around each
    coordinate at every level.
    """

    def __init__(self, fmaps, num_levels=4, radius=4):
        B, S, C, H, W = fmaps.shape
        self.S, self.C, self.H, self.W = S, C, H, W

        self.num_levels = num_levels
        self.radius = radius
        self.fmaps_pyramid = [fmaps]
        for _ in range(self.num_levels - 1):
            pooled = F.avg_pool2d(fmaps.reshape(B * S, C, H, W), 2, stride=2)
            _, _, H, W = pooled.shape
            fmaps = pooled.reshape(B, S, C, H, W)
            self.fmaps_pyramid.append(fmaps)

    def sample(self, coords):
        """Gather feature windows around ``coords`` at every pyramid level.

        Args:
            coords: (B, S, N, 2) full-resolution pixel coordinates; they are
                divided by ``2 ** level`` for each coarser level.

        Returns:
            float32 tensor of shape
            (B, S, N, num_levels * C * (2 * radius + 1) ** 2), each level laid
            out as (C, window_row, window_col).
        """
        r = self.radius
        B, S, N, D = coords.shape
        assert D == 2
        K = 2 * r + 1

        # The window offsets are the same for every level: build them once.
        dx = torch.linspace(-r, r, K)
        dy = torch.linspace(-r, r, K)
        delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1)
        delta_lvl = delta.to(coords.device).view(1, K, K, 2)

        out_pyramid = []
        for i in range(self.num_levels):
            fmaps = self.fmaps_pyramid[i]
            _, _, C, H, W = fmaps.shape

            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
            # Stack the N windows of a frame along the grid's row axis so each
            # frame's feature map is sampled once, instead of copying the map N
            # times (one per track) as a per-track batch would require.
            coords_lvl = (centroid_lvl + delta_lvl).reshape(B * S, N * K, K, 2)

            sampled = bilinear_sampler(fmaps.reshape(B * S, C, H, W), coords_lvl)
            # (B*S, C, N*K, K) -> (B*S, N, C, K, K) -> (B, S, N, C*K*K)
            sampled = sampled.reshape(B * S, C, N, K, K).permute(0, 2, 1, 3, 4)
            out_pyramid.append(sampled.reshape(B, S, N, -1))

        out = torch.cat(out_pyramid, dim=-1)  # B, S, N, C*LRR*2
        return out.contiguous().float()


class UpdateFormer(nn.Module):
    """
    Transformer model that updates track estimates.

    ``forward`` takes a tensor of shape (B, N, T, input_dim), projects it to
    ``hidden_size``, runs ``time_depth`` attention blocks along the T axis and,
    if ``add_space_attn`` is set, interleaves up to ``space_depth`` attention
    blocks along the N axis. The result is projected to ``output_dim``
    channels, giving (B, N, T, output_dim).
    """

    def __init__(
        self,
        space_depth=12,
        time_depth=12,
        input_dim=320,
        hidden_size=384,
        num_heads=8,
        output_dim=130,
        mlp_ratio=4.0,
        add_space_attn=True,
    ):
        super().__init__()
        self.out_channels = 2
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.add_space_attn = add_space_attn
        self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
        self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)

        self.time_blocks = nn.ModuleList(
            [
                AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
                for _ in range(time_depth)
            ]
        )

        if add_space_attn:
            self.space_blocks = nn.ModuleList(
                [
                    AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
                    for _ in range(space_depth)
                ]
            )
            assert len(self.time_blocks) >= len(self.space_blocks)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-uniform init for every ``nn.Linear`` weight, zero biases."""

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def forward(self, input_tensor):
        """Map (B, N, T, input_dim) to (B, N, T, output_dim)."""
        x = self.input_transform(input_tensor)
        B, N, T, _ = x.shape

        num_space = len(self.space_blocks) if self.add_space_attn else 0
        # A space block follows every `space_every`-th time block. The
        # `j < num_space` guard keeps the index in range when time_depth is not
        # a multiple of space_depth; num_space == 0 disables space attention
        # (and avoids dividing by zero).
        space_every = max(len(self.time_blocks) // num_space, 1) if num_space else 1

        j = 0
        for i, time_block in enumerate(self.time_blocks):
            x_time = rearrange(x, "b n t c -> (b n) t c", b=B, t=T, n=N)
            x_time = time_block(x_time)
            x = rearrange(x_time, "(b n) t c -> b n t c ", b=B, t=T, n=N)

            if num_space and j < num_space and i % space_every == 0:
                x_space = rearrange(x, "b n t c -> (b t) n c ", b=B, t=T, n=N)
                x_space = self.space_blocks[j](x_space)
                x = rearrange(x_space, "(b t) n c -> b n t c  ", b=B, t=T, n=N)
                j += 1

        flow = self.flow_head(x)
        return flow


class MotionLabelMLP(nn.Module):
    """Per-track motion score: a shared MLP per frame, max-pooled over time.

    Args:
        in_dim: feature width C of the input.
        hidden_dim: hidden width of the MLP.
        S: kernel size of the temporal max pooling.
    """

    def __init__(self, in_dim=128, hidden_dim=128, S=8):
        super().__init__()

        self.in_dim, self.hidden_dim, self.S = in_dim, hidden_dim, S

        stages = [
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
            # (B, S, N, 1) -> (B, N, S) so pooling runs along the time axis.
            Rearrange("b s n c -> b n (s c)", c=1),
            nn.MaxPool1d(kernel_size=S),
        ]
        self.network = nn.Sequential(*stages)

    def forward(self, x, coords=None):
        """Score each track.

        Args:
            x: features of shape (B, S, N, C).
            coords: unused; accepted for interface compatibility.

        Returns:
            (B, N, T // S) scores for an input with T frames, i.e. (B, N, 1)
            when T equals the configured ``S``.
        """
        return self.network(x)


class MotionLabelMLPV1(nn.Module):
    """Per-track motion score: a shared MLP per frame, average-pooled over time.

    Args:
        in_dim: feature width C of the input.
        hidden_dim: hidden width of the MLP.
        S: kernel size of the temporal average pooling.
    """

    def __init__(self, in_dim=128, hidden_dim=128, S=8):
        super().__init__()

        def tanh_gelu():
            # Passed as a factory (callable returning a module), not a module.
            return nn.GELU(approximate="tanh")

        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.S = S

        self.mlp = Mlp(
            in_features=in_dim,
            hidden_features=hidden_dim,
            out_features=1,
            act_layer=tanh_gelu,
            drop=0,
        )
        self.pool = nn.AvgPool1d(kernel_size=S)

    def forward(self, x, coords=None):
        """Score each track.

        Args:
            x: features of shape (B, S, N, C).
            coords: unused; accepted for interface compatibility.

        Returns:
            (B, N, T // S) scores for an input with T frames, i.e. (B, N, 1)
            when T equals the configured ``S``.
        """
        frame_scores = rearrange(self.mlp(x), "b s n c -> b n (s c)", c=1)
        return self.pool(frame_scores)


class MotionLabelATTN(nn.Module):
    """Per-track motion score from temporal self-attention plus an MLP head.

    Each track's frame sequence is processed by multi-head self-attention,
    mapped to one score per frame and averaged over frames.

    Args:
        in_dim: feature width C (also the attention embedding width).
        num_heads: number of attention heads.
    """

    def __init__(self, in_dim=128, num_heads=4):
        super().__init__()

        def tanh_gelu():
            # Passed as a factory (callable returning a module), not a module.
            return nn.GELU(approximate="tanh")

        self.in_dim = in_dim
        self.num_heads = num_heads
        self.self_attn = nn.MultiheadAttention(
            embed_dim=in_dim, num_heads=num_heads, batch_first=True
        )
        self.mlp = Mlp(
            in_features=in_dim,
            hidden_features=in_dim,
            out_features=1,
            act_layer=tanh_gelu,
            drop=0,
        )

    def forward(self, x, coords=None):
        """Score each track.

        Args:
            x: features of shape (B, S, N, C).
            coords: unused; accepted for interface compatibility.

        Returns:
            (B, N, 1) scores.
        """
        batch, _, _, _ = x.shape
        # One sequence per track: attention runs across the S frames.
        tokens = rearrange(x, "b s n c -> (b n) s c")
        attended = self.self_attn(tokens, tokens, tokens)[0]
        per_track = self.mlp(attended).mean(dim=1)
        return rearrange(per_track, "(b n) c -> b n c", b=batch)


class MotionLabelATTNV1(nn.Module):
    """Per-track motion score from one temporal and one spatial attention block.

    Args:
        in_dim: feature width C of the input.
        hidden_dim: width the features are projected to before attention.
        num_heads: number of attention heads.
        mlp_ratio: MLP expansion ratio inside each attention block.
        add_coord: if True, ``forward`` concatenates the (detached) 2-D track
            coordinates to the features, so the input projection expects
            ``in_dim + 2`` channels.
    """

    def __init__(
        self, in_dim=128, hidden_dim=128, num_heads=4, mlp_ratio=2.0, add_coord=False
    ):
        super().__init__()
        self.add_coord = add_coord

        if add_coord:
            self.in_dim = in_dim + 2
        else:
            self.in_dim = in_dim

        self.num_heads = num_heads
        self.input_transform = torch.nn.Linear(self.in_dim, hidden_dim, bias=True)
        self.motion_head = torch.nn.Linear(hidden_dim, 1, bias=True)
        self.temporal_attn = AttnBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio)
        self.spatial_attn = AttnBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio)

    def forward(self, x, coords=None):
        """Score each track.

        Args:
            x: features of shape (B, S, N, C).
            coords: (B, S, N, 2) coordinates; required when built with
                ``add_coord=True`` (concatenated to ``x`` without gradient),
                ignored otherwise.

        Returns:
            (B, N, 1): per-frame scores averaged over the S frames.

        Raises:
            ValueError: if ``add_coord`` is set but ``coords`` is None.
        """
        B, S, N, _ = x.shape

        if self.add_coord:
            if coords is None:
                raise ValueError(
                    "MotionLabelATTNV1 was built with add_coord=True "
                    "but forward() received coords=None"
                )
            x = torch.cat([x, coords.detach()], dim=-1)
        x = self.input_transform(x)

        # Attention across frames, independently per track.
        x_time = rearrange(x, "b s n c -> (b n) s c", b=B, s=S, n=N)
        x_time = self.temporal_attn(x_time)

        x = rearrange(x_time, "(b n) s c -> b n s c", b=B, s=S, n=N)

        # Attention across tracks, independently per frame.
        x_space = rearrange(x, "b n s c -> (b s) n c", b=B, s=S, n=N)
        x_space = self.spatial_attn(x_space)
        x = rearrange(x_space, "(b s) n c -> b n s c", b=B, s=S, n=N)

        motion = self.motion_head(x)
        motion = torch.mean(motion, dim=2)

        return motion


class MotionLabelMLPV2(nn.Module):
    """Per-frame, per-track motion score from a shared MLP (no temporal pooling).

    Args:
        in_dim: feature width C of the input.
        hidden_dim: hidden width of the MLP.
    """

    def __init__(self, in_dim=128, hidden_dim=128):
        super().__init__()

        def tanh_gelu():
            # Passed as a factory (callable returning a module), not a module.
            return nn.GELU(approximate="tanh")

        self.in_dim, self.hidden_dim = in_dim, hidden_dim
        self.mlp = Mlp(
            in_features=in_dim,
            hidden_features=hidden_dim,
            out_features=1,
            act_layer=tanh_gelu,
            drop=0,
        )

    def forward(self, x, coords=None):
        """Score every frame of every track.

        Args:
            x: features of shape (B, S, N, C).
            coords: unused; accepted for interface compatibility.

        Returns:
            (B, S, N, 1) scores.
        """
        return self.mlp(x)


class MotionLabelBlock(nn.Module):
    """Motion-label head selected by ``cfg.motion_label_block.mode``.

    Supported modes: ``"mlp"``, ``"mlp_v1"``, ``"mlp_v2"``, ``"attn"`` and
    ``"attn_v1"``. Only the sub-config fields used by the chosen head are read.

    Args:
        cfg: config object with a ``motion_label_block`` sub-config.
        S: passed as ``S`` to the ``"mlp"`` / ``"mlp_v1"`` heads (their
            temporal pooling kernel size); unused by the other heads.

    Raises:
        NotImplementedError: if the mode is not one of the supported values.
    """

    def __init__(self, cfg, S):
        super().__init__()
        self.cfg = cfg.motion_label_block

        mode = self.cfg.mode
        if mode == "mlp":
            self.network = MotionLabelMLP(
                in_dim=self.cfg.in_dim, hidden_dim=self.cfg.hidden_dim, S=S
            )
        elif mode == "mlp_v1":
            self.network = MotionLabelMLPV1(
                in_dim=self.cfg.in_dim, hidden_dim=self.cfg.hidden_dim, S=S
            )
        elif mode == "mlp_v2":
            self.network = MotionLabelMLPV2(
                in_dim=self.cfg.in_dim, hidden_dim=self.cfg.hidden_dim
            )
        elif mode == "attn":
            self.network = MotionLabelATTN(
                in_dim=self.cfg.in_dim, num_heads=self.cfg.num_heads
            )
        elif mode == "attn_v1":
            self.network = MotionLabelATTNV1(
                in_dim=self.cfg.in_dim,
                hidden_dim=self.cfg.hidden_dim,
                num_heads=self.cfg.num_heads,
                mlp_ratio=self.cfg.mlp_ratio,
                add_coord=self.cfg.add_coord,
            )
        else:
            raise NotImplementedError(
                f"Unknown motion_label_block mode {mode!r}; expected one of "
                "'mlp', 'mlp_v1', 'mlp_v2', 'attn', 'attn_v1'"
            )

    def forward(self, x, coords=None):
        """Delegate to the selected head."""
        return self.network(x, coords)
