import torch
from torch import nn
import torch.nn.functional as F
import os

# Public API of this module: the HRNet class and its two width presets (builders).
__all__ = ['HRNet', 'hrnetv2_48', 'hrnetv2_32']

# Checkpoint path of the pre-trained backbone (edit it to match your setup). Download the HRNetV2-32 backbone weights from
# https://drive.google.com/file/d/1NxCK7Zgn5PmeS7W1jYLt5J9E0RRZ2oyF/view?usp=sharing . I placed the backbone
# weights in the ./checkpoints folder.

# Local checkpoint file per architecture. A value of None means no pretrained
# backbone is registered for that variant (it has to be trained from scratch).
model_urls = dict(
    hrnetv2_32='./checkpoints/model_best_epoch96_edit.pth',
    hrnetv2_48=None,
)


def check_pth(arch):
    """Return the local checkpoint path registered for ``arch`` in ``model_urls``.

    Prints whether the checkpoint file exists. A missing file is reported but
    not raised, so the caller decides how to handle it.

    HRNetv2-48 has no pretrained backbone yet (its entry is ``None``), but the
    whole model can still be trained from scratch.

    Args:
        arch: key into ``model_urls`` (e.g. ``'hrnetv2_32'``).

    Returns:
        The registered path, or ``None`` when no checkpoint is registered.

    Raises:
        KeyError: if ``arch`` is not a key of ``model_urls``.
    """
    CKPT_PATH = model_urls[arch]
    # os.path.exists(None) raises TypeError, so an unregistered architecture is
    # checked explicitly and reported like a missing file.
    if CKPT_PATH is not None and os.path.exists(CKPT_PATH):
        print(f"Backbone HRNet Pretrained weights at: {CKPT_PATH}, only usable for HRNetv2-32")
    else:
        print("No backbone checkpoint found for HRNetv2, please set pretrained=False when calling model")
    return CKPT_PATH


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 conv -> 3x3 conv (carries the stride) -> 1x1 conv.

    The final 1x1 conv widens the features to ``planes * expansion`` channels.
    The shortcut is the raw input, or ``downsample(input)`` when a downsample
    module is given; it must match the residual branch's shape for the in-place add.

    Args:
        inplanes: channels of the input tensor.
        planes: bottleneck width; the block outputs ``planes * expansion`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        expanded = planes * self.expansion
        # Attribute names form the state-dict keys that pretrained checkpoints use.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, expanded, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.relu(self.bn2(self.conv2(residual)))
        residual = self.bn3(self.conv3(residual))

        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)


class BasicBlock(nn.Module):
    """Residual block of two 3x3 conv-BN layers (the first one carries the stride).

    Args:
        inplanes: channels of the input tensor.
        planes: channels produced by both convolutions (``expansion == 1``).
        stride: stride of the first 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut;
            needed whenever ``stride != 1`` or ``inplanes != planes`` so the
            shortcut matches the residual branch's shape.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # conv2 consumes conv1's output, which has `planes` channels (not `inplanes`).
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class StageModule(nn.Module):
    """One HRNet exchange unit: parallel per-resolution branches followed by cross-branch fusion.

    Branch ``i`` carries ``c * 2**i`` channels and is processed by four
    ``BasicBlock``s (stride 1, so its spatial size is unchanged). Fused output
    ``o`` is the ReLU of the sum over every input branch ``i`` of:

    * ``i == o``: the branch itself (identity);
    * ``i > o``: 1x1 conv + BN to ``c * 2**o`` channels, then nearest upsampling by ``2**(i - o)``;
    * ``i < o``: ``o - i`` stride-2 3x3 convs (conv-BN-ReLU for all but the last, conv-BN for the last).

    Submodule names/indices must stay as they are: pretrained checkpoints are
    loaded by state-dict key.

    Args:
        stage: number of input branches (stage 2 -> 2 branches, ...).
        output_branches: number of fused outputs produced.
        c: channel count of the highest-resolution branch.
    """

    def __init__(self, stage, output_branches, c):
        super(StageModule, self).__init__()

        self.number_of_branches = stage  # number of branches is equivalent to the stage configuration.
        self.output_branches = output_branches

        # Resolution and number of channels stay the same throughout each branch.
        self.branches = nn.ModuleList()
        for i in range(self.number_of_branches):
            channels = c * (2 ** i)  # each lower-resolution branch has 2x the channels of the previous one
            # Paper applies 4 basic blocks per branch (together considered one block).
            self.branches.append(nn.Sequential(*[BasicBlock(channels, channels) for _ in range(4)]))

        # fuse_layers[o][i] maps branch i to the resolution and channel count of output o.
        self.fuse_layers = nn.ModuleList()
        for branch_output_number in range(self.output_branches):
            out_channels = c * (2 ** branch_output_number)
            fuse_row = nn.ModuleList()

            for branch_number in range(self.number_of_branches):
                in_channels = c * (2 ** branch_number)
                if branch_number == branch_output_number:
                    fuse_row.append(nn.Sequential())  # identity; used in place of None because it is callable
                elif branch_number > branch_output_number:
                    # Lower resolution -> higher resolution: match channels, then upsample.
                    fuse_row.append(nn.Sequential(
                        nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                        nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True,
                                       track_running_stats=True),
                        nn.Upsample(scale_factor=(2.0 ** (branch_number - branch_output_number)), mode='nearest'),
                    ))
                else:
                    # Higher resolution -> lower resolution: chain of stride-2 convs; only the
                    # last one changes the channel count and it has no ReLU (applied after the sum).
                    downsampling_fusion = []
                    for _ in range(branch_output_number - branch_number - 1):
                        downsampling_fusion.append(nn.Sequential(
                            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1, bias=False),
                            nn.BatchNorm2d(in_channels, eps=1e-05, momentum=0.1, affine=True,
                                           track_running_stats=True),
                            nn.ReLU(inplace=True),
                        ))
                    downsampling_fusion.append(nn.Sequential(
                        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False),
                        nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True,
                                       track_running_stats=True),
                    ))
                    fuse_row.append(nn.Sequential(*downsampling_fusion))

            self.fuse_layers.append(fuse_row)

        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _match_size(t, size):
        """Crop ``t`` spatially to ``size`` (``(h, w)``), or resize it if it is smaller."""
        h, w = size
        if t.shape[-2] >= h and t.shape[-1] >= w:
            # Stride-2 convs round odd sizes up, so nearest 2**k upsampling of a lower-resolution
            # branch can overshoot by up to 2**k - 1 pixels; cropping keeps pixel alignment.
            return t[..., :h, :w]
        return F.interpolate(t, size=(h, w), mode='nearest')

    def forward(self, x):
        """Run every branch, then fuse them.

        Args:
            x: list with one tensor per input branch, highest resolution first.

        Returns:
            List of ``output_branches`` fused tensors.
        """
        x = [branch(branch_input) for branch, branch_input in zip(self.branches, x)]

        x_fused = []
        for branch_output_index in range(self.output_branches):
            fuse_row = self.fuse_layers[branch_output_index]
            # Spatial size all contributions must share; known when an input branch
            # exists at this output's resolution.
            target_size = x[branch_output_index].shape[-2:] if branch_output_index < len(x) else None

            fused = None
            for input_index in range(self.number_of_branches):
                contribution = fuse_row[input_index](x[input_index])
                if target_size is not None and contribution.shape[-2:] != target_size:
                    contribution = self._match_size(contribution, target_size)
                fused = contribution if fused is None else fused + contribution
            x_fused.append(fused)

        # ReLU only after every output is assembled: with a single input branch the
        # identity path aliases x[0], so an earlier in-place ReLU would alter later fusions.
        for i in range(self.output_branches):
            x_fused[i] = self.relu(x_fused[i])

        return x_fused  # list of fused outputs


class HRNet(nn.Module):
    """HRNetV2 backbone followed by a small classification head.

    Layout (spatial sizes relative to the input):

    * stem: two stride-2 3x3 conv-BN-ReLU layers (64 channels, 1/4 resolution);
    * stage 1 (``layer1``): four Bottlenecks, 64 -> 256 channels;
    * transitions + stages 2-4: grow to four parallel branches with
      ``c, 2c, 4c, 8c`` channels at 1/4, 1/8, 1/16 and 1/32 resolution,
      exchanging information through ``StageModule`` fusions;
    * head: the lower-resolution branches are bilinearly upsampled to the size
      of the highest-resolution branch, concatenated (``c + 2c + 4c + 8c``
      channels) and passed to ``bn_classifier`` (1x1 conv-BN-ReLU, 8x8 adaptive
      average pool, flatten, linear).

    Submodule names must stay stable: pretrained checkpoints are loaded by
    state-dict key.

    Args:
        c: channel width of the highest-resolution branch (32 in ``hrnetv2_32``,
            48 in ``hrnetv2_48``).
        num_blocks: number of ``StageModule``s in stages 2, 3 and 4. Must have three
            entries: the classifier's input width is derived from ``len(num_blocks) + 1``
            branches while ``forward`` always concatenates four.
        num_classes: output size of the final linear layer.
    """

    def __init__(self, c=48, num_blocks=(1, 4, 3), num_classes=1000):
        # num_blocks defaults to a tuple so no mutable default is shared between calls.
        super(HRNet, self).__init__()

        # Stem: two stride-2 convs -> 1/4 of the input resolution.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, eps=1e-05, affine=True, track_running_stats=True)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, eps=1e-05, affine=True, track_running_stats=True)
        self.relu = nn.ReLU(inplace=True)

        # Stage 1: the first Bottleneck expands 64 -> 256 channels, so its shortcut needs a projection.
        downsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(256, eps=1e-05, affine=True, track_running_stats=True),
        )
        bn_expansion = Bottleneck.expansion  # output channels of a Bottleneck = planes * expansion
        self.layer1 = nn.Sequential(
            Bottleneck(64, 64, downsample=downsample),  # Input is 64 for first module connection
            Bottleneck(bn_expansion * 64, 64),
            Bottleneck(bn_expansion * 64, 64),
            Bottleneck(bn_expansion * 64, 64),
        )

        # Transition 1: create the first two branches (full and half of the stage-1 resolution).
        self.transition1 = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(256, c, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(c, eps=1e-05, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            ),
            nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
                nn.Conv2d(256, c * 2, kernel_size=3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c * 2, eps=1e-05, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            )),
        ])

        # Stage 2:
        number_blocks_stage2 = num_blocks[0]
        self.stage2 = nn.Sequential(
            *[StageModule(stage=2, output_branches=2, c=c) for _ in range(number_blocks_stage2)])

        # Transition 2: create the third branch (from the current lowest-resolution branch).
        self.transition2 = self._make_transition_layers(c, transition_number=2)

        # Stage 3:
        number_blocks_stage3 = num_blocks[1]
        self.stage3 = nn.Sequential(
            *[StageModule(stage=3, output_branches=3, c=c) for _ in range(number_blocks_stage3)])

        # Transition 3: create the fourth branch.
        self.transition3 = self._make_transition_layers(c, transition_number=3)

        # Stage 4:
        number_blocks_stage4 = num_blocks[2]
        self.stage4 = nn.Sequential(
            *[StageModule(stage=4, output_branches=4, c=c) for _ in range(number_blocks_stage4)])

        # Classifier head: reduce channels, pool to a fixed 8x8 grid, flatten, linear layer.
        out_channels = sum([c * 2 ** i for i in range(len(num_blocks) + 1)])  # total output channels of HRNetV2
        pool_feature_map = 8
        self.bn_classifier = nn.Sequential(
            nn.Conv2d(out_channels, out_channels // 4, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels // 4, eps=1e-05, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(pool_feature_map),
            nn.Flatten(),
            nn.Linear(pool_feature_map * pool_feature_map * (out_channels // 4), num_classes),
        )

    @staticmethod
    def _make_transition_layers(c, transition_number):
        """Stride-2 conv-BN-ReLU mapping ``c * 2**(n-1)`` to ``c * 2**n`` channels (n = transition_number)."""
        return nn.Sequential(
            nn.Conv2d(c * (2 ** (transition_number - 1)), c * (2 ** transition_number), kernel_size=3, stride=2,
                      padding=1, bias=False),
            nn.BatchNorm2d(c * (2 ** transition_number), eps=1e-05, affine=True,
                           track_running_stats=True),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Stem:
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        # Stage 1
        x = self.layer1(x)
        x = [trans(x) for trans in self.transition1]  # split into 2 branches, held in a list

        # Stage 2; the new branch is derived from the lowest-resolution one.
        x = self.stage2(x)
        x.append(self.transition2(x[-1]))

        # Stage 3
        x = self.stage3(x)
        x.append(self.transition3(x[-1]))

        # Stage 4
        x = self.stage4(x)

        # HRNetV2 head: bilinearly upsample every lower-resolution stream to the size of the
        # highest-resolution stream and concatenate (rather than adding/fusing like HRNetV1).
        output_size = x[0].shape[-2:]
        upsampled = [F.interpolate(stream, size=output_size, mode='bilinear', align_corners=False)
                     for stream in x[1:]]
        x = torch.cat([x[0], *upsampled], dim=1)
        return self.bn_classifier(x)


def _hrnet(arch, channels, num_blocks, pretrained, progress, **kwargs):
    """Build an ``HRNet`` and optionally load weights from the local checkpoint for ``arch``.

    Args:
        arch: key into ``model_urls`` selecting the checkpoint file.
        channels: width ``c`` of the highest-resolution branch.
        num_blocks: number of StageModules in stages 2, 3 and 4.
        pretrained: if True, load the checkpoint's ``'state_dict'`` entry (or the
            checkpoint itself when it is a bare state dict) into the model.
        progress: accepted for API compatibility; not used.
        **kwargs: forwarded to ``HRNet`` (e.g. ``num_classes``).

    Returns:
        The constructed model (on CPU).

    Raises:
        FileNotFoundError: ``pretrained`` is True but no checkpoint file is available for ``arch``.
    """
    model = HRNet(channels, num_blocks, **kwargs)
    if pretrained:
        CKPT_PATH = check_pth(arch)
        if CKPT_PATH is None or not os.path.exists(CKPT_PATH):
            raise FileNotFoundError(
                f"No pretrained checkpoint available for '{arch}' (path: {CKPT_PATH}); "
                "set pretrained=False to train from scratch.")
        # map_location='cpu' lets checkpoints saved from GPU load on CPU-only machines;
        # the freshly built model is on CPU anyway.
        checkpoint = torch.load(CKPT_PATH, map_location='cpu')
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint  # bare state dict saved with torch.save(model.state_dict())
        model.load_state_dict(state_dict)
    return model


def hrnetv2_48(pretrained=False, progress=True, number_blocks=(1, 4, 3), **kwargs):
    """Build HRNetV2-W48 (highest-resolution branch width ``c = 48``).

    Args:
        pretrained: load backbone weights from ``model_urls['hrnetv2_48']``. No
            checkpoint is currently registered for this variant, so keep this False
            and train from scratch.
        progress: passed through to ``_hrnet``, which does not use it.
        number_blocks: StageModules in stages 2, 3 and 4 (a tuple default avoids a
            shared mutable default).
        **kwargs: forwarded to ``HRNet`` (e.g. ``num_classes``).
    """
    w_channels = 48
    return _hrnet('hrnetv2_48', w_channels, number_blocks, pretrained, progress, **kwargs)


def hrnetv2_32(pretrained=False, progress=True, number_blocks=(1, 4, 3), **kwargs):
    """Build HRNetV2-W32 (highest-resolution branch width ``c = 32``).

    Args:
        pretrained: load backbone weights from the checkpoint registered in
            ``model_urls['hrnetv2_32']``.
        progress: passed through to ``_hrnet``, which does not use it.
        number_blocks: StageModules in stages 2, 3 and 4 (a tuple default avoids a
            shared mutable default).
        **kwargs: forwarded to ``HRNet`` (e.g. ``num_classes``).
    """
    w_channels = 32
    return _hrnet('hrnetv2_32', w_channels, number_blocks, pretrained, progress, **kwargs)


if __name__ == '__main__':
    print("--- Running file as MAIN ---")

    # Load the pretrained backbone only when the checkpoint registered in model_urls
    # exists; otherwise build the model with random initialisation instead of crashing.
    ckpt_path = model_urls['hrnetv2_32']
    use_pretrained = ckpt_path is not None and os.path.exists(ckpt_path)
    if not use_pretrained:
        print(f"No backbone checkpoint found at {ckpt_path}; running with randomly initialised weights")

    # Models (HRNetv2-48 has no registered checkpoint: hrnetv2_48(pretrained=False))
    model = hrnetv2_32(pretrained=use_pretrained)

    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    model.to(device)
    model.eval()  # smoke test: use BatchNorm running statistics, do not update them

    in_ = torch.ones(1, 3, 768, 768).to(device)
    with torch.no_grad():  # inference only; avoids storing activations for backprop
        y = model(in_)
    print(y.shape)

    # Total number of trainable parameters:
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(pytorch_total_params)






