import torch
import torch.nn as nn
from collections import OrderedDict


def weight_init(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv3d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm3d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
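
# weight_init is intended for nn.Module.apply, which visits every submodule
# recursively, e.g. `model.apply(weight_init)` right after a model is built.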


class ConvStage(nn.Module):
    """Stack of `repeat_times` ConvBlocks; only the first block changes the
    channel count from num_in_features to num_out_features."""

    def __init__(self, **kwargs):
        super().__init__()
        self.negative_slope = kwargs.get('negative_slope', 0)
        self.norm_choice = kwargs.get('norm_choice', 'BN')
        self.num_in_features = kwargs.get('num_in_features')
        self.num_out_features = kwargs.get('num_out_features')
        self.kernel_size = kwargs.get('kernel_size', 3)
        self.repeat_times = kwargs.get('repeat_times', 3)
        blocks = OrderedDict()
        for idx in range(self.repeat_times):
            # Only the first block maps num_in_features -> num_out_features.
            num_in = self.num_in_features if idx == 0 else self.num_out_features
            blocks['convBlock_%d' % idx] = ConvBlock(
                negative_slope=self.negative_slope, norm_choice=self.norm_choice,
                num_in_features=num_in, num_out_features=self.num_out_features,
                kernel_size=self.kernel_size)
        self.convStage_func = nn.Sequential(blocks)

    def forward(self, x):
        return self.convStage_func(x)


class ConvBlock(nn.Module):
    """Conv3d (with 'same' padding for odd kernel sizes) followed by
    normalisation and activation."""

    def __init__(self, **kwargs):
        super().__init__()
        self.negative_slope = kwargs.get('negative_slope', 0)
        self.norm_choice = kwargs.get('norm_choice', 'BN')
        self.num_in_features = kwargs.get('num_in_features')
        self.num_out_features = kwargs.get('num_out_features')
        self.kernel_size = kwargs.get('kernel_size', 3)

        self.conv = nn.Conv3d(in_channels=self.num_in_features, out_channels=self.num_out_features,
                              kernel_size=self.kernel_size, padding=self.kernel_size // 2)
        self.post_func = BNReLU(negative_slope=self.negative_slope, norm_choice=self.norm_choice,
                                num_features=self.num_out_features)

    def forward(self, x):
        return self.post_func(self.conv(x))


class BNReLU(nn.Module):
    """Normalisation (BatchNorm3d by default, InstanceNorm3d for
    norm_choice='IN') followed by ReLU, or LeakyReLU when negative_slope > 0."""

    def __init__(self, **kwargs):
        super().__init__()
        self.negative_slope = kwargs.get('negative_slope', 0)
        self.norm_choice = kwargs.get('norm_choice', 'BN')
        self.num_features = kwargs.get('num_features')

        if self.negative_slope <= 0:
            self.relu = nn.ReLU()
        else:
            self.relu = nn.LeakyReLU(negative_slope=self.negative_slope)
        if self.norm_choice == 'IN':
            self.norm = nn.InstanceNorm3d(num_features=self.num_features)
        else:
            self.norm = nn.BatchNorm3d(num_features=self.num_features)

    def forward(self, x):
        x = self.norm(x)
        x = self.relu(x)
        return x


class ConvFunc3D(nn.Module):
    """3D convolution wrapper: a plain Conv3d for conv_choice='basic' (or any
    1x1x1 kernel), otherwise a depthwise-separable factorisation."""

    def __init__(self, **kwargs):
        super().__init__()
        self.num_in_features = kwargs.get('num_in_features')
        self.num_out_features = kwargs.get('num_out_features')
        self.stride = kwargs.get('stride', 1)
        self.conv_choice = kwargs.get('conv_choice', 'basic')
        self.kernel_size = kwargs.get('kernel_size', 1)
        padding = self.kernel_size // 2  # 'same' padding; 0 when kernel_size == 1
        if self.conv_choice == 'basic' or self.kernel_size == 1:
            self.conv = nn.Conv3d(in_channels=self.num_in_features, out_channels=self.num_out_features,
                                  kernel_size=self.kernel_size, stride=self.stride, padding=padding)
        else:
            # Depthwise-separable convolution: per-channel spatial filtering
            # (groups=num_in_features) followed by a 1x1x1 pointwise mix.
            self.conv = nn.Sequential(
                nn.Conv3d(in_channels=self.num_in_features, out_channels=self.num_in_features,
                          kernel_size=self.kernel_size, stride=self.stride, padding=padding,
                          groups=self.num_in_features),
                nn.Conv3d(in_channels=self.num_in_features, out_channels=self.num_out_features, kernel_size=1)
            )

    def forward(self, x):
        return self.conv(x)
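
# For a non-'basic' conv_choice with kernel size k, the separable factorisation
# above reduces the weight count of a dense 3D convolution from
# C_in * C_out * k^3 to C_in * k^3 + C_in * C_out (ignoring biases).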


class BasicBlock(nn.Module):
    """Residual block in the bottleneck layout (1x1x1, kxkxk, 1x1x1 convs)
    with expansion=1; an optional downsample module adapts the identity path."""

    expansion = 1

    def __init__(self, **kwargs):
        super().__init__()
        self.negative_slope = kwargs.get('negative_slope', 0)
        self.norm_choice = kwargs.get('norm_choice', 'BN')
        self.num_in_features = kwargs.get('num_in_features')
        self.num_out_features = kwargs.get('num_out_features')
        self.kernel_size = kwargs.get('kernel_size', 3)
        self.stride = kwargs.get('stride', 1)
        self.downsample = kwargs.get('downsample', None)
        self.conv_choice = kwargs.get('conv_choice', 'basic')

        self.conv1 = ConvFunc3D(num_in_features=self.num_in_features, num_out_features=self.num_out_features,
                                conv_choice=self.conv_choice)
        self.bn1 = BNReLU(negative_slope=self.negative_slope, norm_choice=self.norm_choice,
                          num_features=self.num_out_features)

        self.conv2 = ConvFunc3D(num_in_features=self.num_out_features, num_out_features=self.num_out_features,
                                stride=self.stride, kernel_size=self.kernel_size, conv_choice=self.conv_choice)
        self.bn2 = BNReLU(negative_slope=self.negative_slope, norm_choice=self.norm_choice,
                          num_features=self.num_out_features)

        self.conv3 = ConvFunc3D(num_in_features=self.num_out_features,
                                num_out_features=self.num_out_features * self.expansion, conv_choice=self.conv_choice)
        self.bn3 = BNReLU(negative_slope=self.negative_slope, norm_choice=self.norm_choice,
                          num_features=self.num_out_features * self.expansion)

        # LeakyReLU with negative_slope=0 is equivalent to ReLU, so this
        # matches BNReLU's default activation.
        self.relu = nn.LeakyReLU(negative_slope=self.negative_slope)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)

        return out


class DarkNetShortCut(nn.Module):
    """DarkNet-style shortcut: adds two feature maps elementwise, summing only
    the overlapping leading channels when the channel counts differ."""

    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, x, y):
        if x.shape[1] == y.shape[1]:
            return x + y
        elif x.shape[1] < y.shape[1]:
            # Add x onto the first x.shape[1] channels of y (out-of-place, to
            # avoid mutating the input under autograd).
            out = y.clone()
            out[:, :x.shape[1]] += x
            return out
        else:
            # Add y onto the first y.shape[1] channels of x.
            out = x.clone()
            out[:, :y.shape[1]] += y
            return out
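
# Example: with x of shape (B, 32, D, H, W) and y of shape (B, 64, D, H, W),
# DarkNetShortCut()(x, y) returns y with x added onto its first 32 channels.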


class SCIModule(nn.Module):
    """Channel-interaction module: reweights channels with a softmax attention
    map built from the negated channel Gram matrix, then refines the result
    with a conv block and a residual 1x1x1 conv block. The residual additions
    require n_input_channel == n_output_channel."""

    def __init__(self, **kwargs):
        super().__init__()
        n_input_channel = kwargs.get('n_input_channel')
        n_output_channel = kwargs.get('n_output_channel')
        self.conv_choice = kwargs.get('conv_choice')
        self.kernel_size = kwargs.get('kernel_size', 3)
        self.softmax = nn.Softmax(dim=2)

        self.conv_func = ConvBlock(num_in_features=n_input_channel, num_out_features=n_output_channel,
                                   negative_slope=0.05,
                                   kernel_size=self.kernel_size)
        self.conv_func1x1 = ConvBlock(num_in_features=n_input_channel, num_out_features=n_output_channel,
                                      negative_slope=0.05,
                                      kernel_size=1)

    def forward(self, x):
        '''
        1. Flatten the spatial dims and build a (B, C, C) attention map from
           the negated channel Gram matrix.
        2. Reweight the flattened features, reshape back, then refine with a
           conv block and a residual 1x1x1 conv block.
        '''
        batch_size, channel, depth, height, width = x.shape
        x0 = x.view((batch_size, channel, -1))
        x_t = torch.transpose(x0, 1, 2)
        # Batched matrix products; results inherit x's device and dtype.
        weights = -torch.bmm(x0, x_t)
        norm_weights = self.softmax(weights)
        weighted_feature = torch.bmm(norm_weights, x0)
        weighted_feature = weighted_feature.view((batch_size, channel, depth, height, width))
        output = self.conv_func(weighted_feature)

        final_output = self.conv_func1x1(output + x)
        return final_output
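

if __name__ == '__main__':
    # Minimal shape-check sketch; the sizes below are arbitrary illustrative
    # values, not configuration taken from the modules above.
    x = torch.randn(2, 8, 8, 8, 8)

    stage = ConvStage(num_in_features=8, num_out_features=16, repeat_times=2)
    stage.apply(weight_init)
    print(stage(x).shape)  # expected: torch.Size([2, 16, 8, 8, 8])

    block = BasicBlock(num_in_features=8, num_out_features=8)
    print(block(x).shape)  # expected: torch.Size([2, 8, 8, 8, 8])

    sci = SCIModule(n_input_channel=8, n_output_channel=8)
    print(sci(x).shape)  # expected: torch.Size([2, 8, 8, 8, 8])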