from collections import OrderedDict
import torch.nn.functional as F
from .BasicModules import *


class DarknetLayer(nn.Module):
    '''Darknet residual unit: 1x1 ConvBlock (C -> C//2), k x k ConvBlock (C//2 -> C), shortcut.

    Keyword Args:
        negative_slope (float): slope forwarded to both ConvBlocks. The legacy,
            misspelled key ``negative_slop`` is still accepted. Default 0.
        norm_choice (str): normalisation choice forwarded to ConvBlock. Default 'BN'.
        kernel_size (int): kernel size of the second ConvBlock. Default 3.
        conv_choice (str): stored on the module; not forwarded to ConvBlock here.
            Default 'basic'.
        num_in_features (int): number of input channels; the unit outputs the
            same number of channels. Default 1.
    '''

    def __init__(self, **kwargs):
        super().__init__()
        # Callers such as Darknet._make_layer pass ``negative_slope``;
        # ``negative_slop`` is the legacy misspelled key, kept for backward
        # compatibility.
        self.negative_slope = kwargs.get('negative_slope', kwargs.get('negative_slop', 0))
        self.norm_choice = kwargs.get('norm_choice', 'BN')
        self.kernel_size = kwargs.get('kernel_size', 3)
        # NOTE(review): stored but never passed to ConvBlock below -- confirm intent.
        self.conv_choice = kwargs.get('conv_choice', 'basic')
        num_in_features = kwargs.get('num_in_features', 1)
        self.num_in_features = num_in_features
        # Bottleneck width: half the input channels (floor).
        hidden_features = int(num_in_features) // 2

        self.conv0 = ConvBlock(negative_slope=self.negative_slope, norm_choice=self.norm_choice,
                               num_in_features=num_in_features, num_out_features=hidden_features,
                               kernel_size=1)
        self.conv1 = ConvBlock(negative_slope=self.negative_slope, norm_choice=self.norm_choice,
                               num_in_features=hidden_features, num_out_features=num_in_features,
                               kernel_size=self.kernel_size)
        # Combines the conv-branch output with the unit's input (presumably an
        # element-wise residual add -- see DarkNetShortCut in BasicModules).
        self.shortcut = DarkNetShortCut()

    def forward(self, x):
        '''Apply conv0 -> conv1 to ``x`` and merge the result with ``x`` via the shortcut.'''
        branch = self.conv0(x)
        branch = self.conv1(branch)
        return self.shortcut(branch, x)
    
    
class Darknet(nn.Module):
    '''3D Darknet-style classifier built from ConvBlocks and DarknetLayer units.

    Layout:
        conv_0: ConvBlock (kernel 3), n_input_channel -> block_inplanes[0] // 2
        for each stage i in range(num_stages):
            StageXX_PreConv: ConvBlock (kernel 3) to block_inplanes[i] channels
            StageXX_InnerConvBlockYY: repeat_time_list[i] DarknetLayer units
            StageXX_Dropout: Dropout3d(dropout_rate), only if dropout_rate > 0
            StageXX_MP: MaxPool3d(pooling_ratios[i]), only if pooling_ratios[i] > 1
        head: AdaptiveAvgPool3d(1) -> flatten -> fc0 (F -> F // 4)
              -> Dropout(0.5) -> fc (F // 4 -> num_class)

    Keyword Args:
        block_inplanes (sequence of int): output channels of each stage. Required.
        pooling_ratios (sequence of int): max-pool kernel per stage; values <= 1
            disable pooling for that stage. Required.
        repeat_time_list (sequence of int): number of DarknetLayer units per stage. Required.
        num_stages (int): number of stages. Default 4.
        negative_slope (float): forwarded to every ConvBlock / DarknetLayer. The
            legacy, misspelled key ``negative_slop`` is still accepted. Default 0.
        norm_choice (str): forwarded to ConvBlock / DarknetLayer. Default 'BN'.
        kernel_size (int): kernel size of the second conv in each DarknetLayer. Default 3.
        dropout_rate (float): per-stage Dropout3d probability; 0 disables it. Default 0.
        num_class (int): number of output scores. Default 1.
        conv_choice (str): forwarded to DarknetLayer. Default 'basic'.
        n_input_channel (int): channels of the input volume. Default 1.

    Raises:
        ValueError: if a required per-stage list is missing or has fewer than
            ``num_stages`` entries.
    '''

    def __init__(self, **kwargs):
        super().__init__()
        self.block_inplanes = kwargs.get('block_inplanes')
        # ``negative_slop`` is the legacy misspelled key, kept for backward compatibility.
        self.negative_slope = kwargs.get('negative_slope', kwargs.get('negative_slop', 0))
        self.norm_choice = kwargs.get('norm_choice', 'BN')
        self.kernel_size = kwargs.get('kernel_size', 3)
        self.num_stages = kwargs.get('num_stages', 4)
        self.pooling_ratios = kwargs.get('pooling_ratios')
        self.repeat_time_list = kwargs.get('repeat_time_list')
        self.dropout_rate = kwargs.get('dropout_rate', 0)
        self.num_class = kwargs.get('num_class', 1)
        self.conv_choice = kwargs.get('conv_choice', 'basic')
        n_input_channel = kwargs.get('n_input_channel', 1)

        # Fail early with a clear message instead of a TypeError/IndexError mid-build.
        for arg_name in ('block_inplanes', 'pooling_ratios', 'repeat_time_list'):
            values = getattr(self, arg_name)
            if values is None or len(values) < self.num_stages:
                raise ValueError('%s must provide at least num_stages=%d entries'
                                 % (arg_name, self.num_stages))

        # The stem produces half the channels of the first stage.
        stem_features = int(self.block_inplanes[0]) // 2

        layers = OrderedDict()
        layers['conv_0'] = ConvBlock(negative_slope=self.negative_slope, norm_choice=self.norm_choice,
                                     num_in_features=n_input_channel, num_out_features=stem_features,
                                     kernel_size=3)

        num_in_features = stem_features
        for stage_idx in range(self.num_stages):
            stage = stage_idx + 1
            output_features = self.block_inplanes[stage_idx]
            layers['Stage%02d_PreConv' % stage] = ConvBlock(
                negative_slope=self.negative_slope, norm_choice=self.norm_choice,
                num_in_features=num_in_features, num_out_features=output_features,
                kernel_size=3)
            for repeat_idx in range(self.repeat_time_list[stage_idx]):
                layers['Stage%02d_InnerConvBlock%02d' % (stage, repeat_idx + 1)] = \
                    self._make_layer(output_features)
            if self.dropout_rate > 0:
                layers['Stage%02d_Dropout' % stage] = nn.Dropout3d(p=self.dropout_rate)
            if self.pooling_ratios[stage_idx] > 1:
                layers['Stage%02d_MP' % stage] = nn.MaxPool3d(kernel_size=self.pooling_ratios[stage_idx])
            num_in_features = output_features

        # Channels reaching the head: the last stage's width (or the stem width
        # when num_stages == 0).
        final_feature_num = num_in_features
        hidden_features = int(final_feature_num) // 4

        self.dark_layer_func = nn.Sequential(layers)
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc0 = nn.Linear(final_feature_num, hidden_features)
        # The head works on a flattened (N, F) tensor, so element-wise Dropout is
        # the right layer. nn.Dropout3d on 2-D input is deprecated and, in recent
        # PyTorch, treats the batch axis as channels and zeroes whole samples.
        self.dp0 = nn.Dropout(p=0.5)
        self.fc = nn.Linear(hidden_features, self.num_class)

    def _make_layer(self, num_in_features):
        '''Build one DarknetLayer (1x1 conv -> k x k conv -> shortcut) on num_in_features channels.'''
        return DarknetLayer(negative_slope=self.negative_slope, norm_choice=self.norm_choice,
                            kernel_size=self.kernel_size, conv_choice=self.conv_choice,
                            num_in_features=num_in_features)

    def forward(self, x):
        '''Run the backbone, global-average-pool, and apply the two-layer head.

        Args:
            x: batched volumetric input, presumably (N, n_input_channel, D, H, W).

        Returns:
            Tensor of shape (N, num_class); no output activation is applied here.
        '''
        features = self.avgpool(self.dark_layer_func(x))
        features = features.flatten(1)  # (N, F, 1, 1, 1) -> (N, F)
        hidden = self.dp0(self.fc0(features))
        return self.fc(hidden)