import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import sys
# sys.path.append('../')
from ..baseLayers.BasicModules import *

class NoduleSegConvBlock_proxima(nn.Module):
    r'''One encoder stage made of ``repeat_time`` layers.

    Each layer is an optional ``ASPP`` module followed by a ``ConvBlock``.
    Only the first layer uses the configured stride; every later layer uses
    stride 1. The first layer is held in ``first_layer_of_current_stage``, the
    remaining layers in ``stage_layers``. The output of ``stage_layers`` is
    added back to the first layer's output as a residual connection.

    Keyword Args:
        basic_parameters (dict): options shared by every layer. They are merged
            into the kwargs passed to ``ASPP`` / ``ConvBlock``.
        stage_parameters (dict): per-stage settings. Keys read here:
            ``stage_idx``, ``num_in_features``, ``num_filters``,
            ``atrous_rate``, ``ASPP_mode``, ``repeat_time``, plus optional
            ``stride`` (default 1) and ``ASPP_choice`` (default -1; ASPP is
            used only when it is > 0).

    Raises:
        ValueError: if ``repeat_time`` is smaller than 1. Without at least one
            layer the stage has no stride and ``GetEnlarge`` cannot work.
    '''
    def __init__(self,**kwargs):
        super().__init__()
        basic_parameters = kwargs.get('basic_parameters')
        model_parameters = kwargs.get('stage_parameters')
        self.model_parameters = model_parameters
        stage_idx = model_parameters['stage_idx']
        self.stage_idx = stage_idx
        layer_map_dict = {}

        num_in_features = model_parameters['num_in_features']
        num_out_features = model_parameters['num_filters']

        layer_parameters = {'num_out_features':num_out_features,
                            'atrous_rate':model_parameters['atrous_rate'],
                            'ASPP_mode':model_parameters['ASPP_mode']}
        layer_parameters.update(basic_parameters)
        # Re-assert the stage width in case basic_parameters carried its own value.
        layer_parameters['num_out_features'] = num_out_features

        repeat_time = model_parameters['repeat_time']
        if repeat_time < 1:
            raise ValueError('stage %r: repeat_time must be at least 1, got %r'
                             % (stage_idx, repeat_time))

        # ASPP usage is a per-stage setting, identical for every layer.
        use_aspp = model_parameters.get('ASPP_choice',-1)>0

        # Modules of the first (possibly strided) layer; kept as an attribute
        # for callers that inspect it.
        self.first_layer_of_current_stage_dict = {}

        for layer_idx in range(repeat_time):
            is_first_layer = layer_idx==0
            # Only the first layer of the stage applies the configured stride.
            conv_stride = int(model_parameters.get('stride',1)) if is_first_layer else 1
            if is_first_layer:
                self.conv_stride = conv_stride
            target_dict = self.first_layer_of_current_stage_dict if is_first_layer else layer_map_dict
            name_prefix = 'NoduleSegConvBlock_proxima_stage%02d_layer%02d'%(stage_idx,layer_idx+1)

            layer_parameters['stride'] = conv_stride
            layer_parameters['num_in_features'] = num_in_features

            if use_aspp:
                target_dict[name_prefix+'_ASPP'] = ASPP(**layer_parameters)
                # The ASPP module consumed the stride, so the following conv must not stride again.
                layer_parameters['stride'] = 1
                if model_parameters['ASPP_mode']==1:
                    # In mode 1 the conv after ASPP takes num_in + num_out channels
                    # (presumably the ASPP output includes its input -- confirm in ASPP).
                    layer_parameters['num_in_features'] = num_in_features+num_out_features
                else:
                    layer_parameters['num_in_features'] = num_out_features

            target_dict[name_prefix+'_conv'] = ConvBlock(**layer_parameters)
            num_in_features = num_out_features

        # Plain dicts preserve insertion order, so module order (and state_dict keys) is kept.
        self.first_layer_of_current_stage = nn.Sequential(OrderedDict(self.first_layer_of_current_stage_dict))
        self.stage_layers = nn.Sequential(OrderedDict(layer_map_dict))

    def forward(self,x):
        '''Run the first layer, then add the remaining layers' output to it as a residual.'''
        x1 = self.first_layer_of_current_stage(x)
        # NOTE(review): when repeat_time == 1, stage_layers is an empty nn.Sequential,
        # which returns its input unchanged, so the result is 2 * x1 -- confirm intended.
        output = self.stage_layers(x1)
        return output+x1

    def GetEnlarge(self):
        '''Return the stride applied by the first layer of this stage.'''
        return self.conv_stride




class NoduleSegEncoder_proxima(nn.Module):
    r'''Multi-stage encoder built from ``NoduleSegConvBlock_proxima`` stages.

    The first stage takes a single-channel input. Stage ``idx`` (0-based)
    produces ``base_filters * 2**idx`` channels and uses ``strides[idx]`` /
    ``atrous_rates[idx]``. Each stage is registered as ``encoder_layer<idx>``.
    For backward compatibility, ``encoder_layer3`` and ``encoder_layer4``
    always exist and are ``None`` when the encoder has fewer stages.

    Keyword Args:
        model_params (dict): configuration.
            Required keys: ``strides`` and ``atrous_rates``, one entry per
            stage.
            Optional keys (defaults):
                ``num_blocks`` (5), ``base_filters`` (32),
                ``repeat_times_list`` (3; an int is used for every stage),
                ``ASPP_choice`` (1; also passed to stages as ``ASPP_mode``),
                ``ASPP_block_list`` ([num_blocks - 1]),
                ``ASPP_pooling_choice`` (True), ``dropout_rate`` (0.0),
                ``do_cls`` (False).
            Per-layer options:
                ``conv_func``, ``norm_choice``, ``activation_func``,
                ``kernel_size``, ``norm_axis``, ``negative_slop``,
                ``padding_func``.

    Raises:
        ValueError: if ``num_blocks`` < 1, or if ``strides``,
            ``atrous_rates`` or ``repeat_times_list`` is missing or shorter
            than ``num_blocks``.
    '''
    def __init__(self,**kwargs):
        super().__init__()
        model_parameters = kwargs.get('model_params')
        self.model_parameters = model_parameters
        # Layer options shared by every stage (forwarded down to ASPP / ConvBlock).
        basic_model_params = {'conv_func':model_parameters.get('conv_func','basic'),
                              'norm_choice':model_parameters.get('norm_choice'),
                              'activation_func':model_parameters.get('activation_func',None),
                              'kernel_size':model_parameters.get('kernel_size',3),
                              'norm_axis':model_parameters.get('norm_axis'),
                              'negative_slop':model_parameters.get('negative_slop'),
                              'padding_func':model_parameters.get('padding_func','zero')}

        ASPP_choice = model_parameters.get('ASPP_choice',1)
        dropout_rate = model_parameters.get('dropout_rate',0.0)
        num_blocks = model_parameters.get('num_blocks',5)
        strides = model_parameters.get('strides')
        atrous_rates = model_parameters.get('atrous_rates')
        base_filters = model_parameters.get('base_filters',32)
        ASPP_block_list = model_parameters.get('ASPP_block_list',[num_blocks-1])
        repeat_times_list = model_parameters.get("repeat_times_list",3)
        ASPP_pooling_choice = model_parameters.get('ASPP_pooling_choice',True)
        # NOTE(review): a 'num_classes' key is not used; the classification head
        # (fc0) below always has a single output.
        do_cls = model_parameters.get('do_cls',False)

        if num_blocks < 1:
            raise ValueError('num_blocks must be at least 1, got %r' % (num_blocks,))
        if isinstance(repeat_times_list, int):
            repeat_times_list = [repeat_times_list] * num_blocks
        for key, values in (('strides', strides), ('atrous_rates', atrous_rates),
                            ('repeat_times_list', repeat_times_list)):
            if values is None or len(values) < num_blocks:
                raise ValueError("model_params['%s'] must provide one entry per block (num_blocks=%d)"
                                 % (key, num_blocks))

        self.real_stride_ratios = []
        self.output_features = []

        num_in_features = 1  # the first stage consumes a single-channel input
        layers = []
        for idx in range(num_blocks):
            layer_map_dict = {}
            current_ASPP_choice = True if (idx in ASPP_block_list) and ASPP_choice else False
            current_layer_parameters = {'atrous_rate':atrous_rates[idx],'stride':strides[idx],'num_filters':base_filters*(2**idx),
                                        'repeat_time':repeat_times_list[idx],'ASPP_choice':current_ASPP_choice,
                                        'stage_idx':idx+1,'num_in_features':num_in_features,
                                        'ASPP_pooling_choice':ASPP_pooling_choice,"ASPP_mode":ASPP_choice}

            current_layer = NoduleSegConvBlock_proxima(basic_parameters=basic_model_params,stage_parameters=current_layer_parameters)
            self.real_stride_ratios.append(current_layer.GetEnlarge())
            layer_map_dict['NoduleSegEncoder_proxima_stage%02d_func'%(idx+1)] = current_layer
            if dropout_rate>0:
                layer_map_dict['NoduleSegEncoder_proxima_stage%02d_dropout'%(idx+1)] = nn.Dropout3d(p=dropout_rate)
            num_in_features = base_filters*(2**idx)
            self.output_features.append(num_in_features)
            layers.append(nn.Sequential(OrderedDict(layer_map_dict)))

        # Register every stage so that all stages built above are trained and run.
        # encoder_layer0..4 keep their historical names; shallower encoders get
        # None for the missing ones.
        for idx in range(max(num_blocks, 5)):
            setattr(self, 'encoder_layer%d' % idx, layers[idx] if idx < num_blocks else None)

        if do_cls:
            self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
            self.fc0 = nn.Linear(self.output_features[-1],1)
        self.do_cls = do_cls

    def GetRealStrides(self):
        '''Return the first-layer stride of each stage, in stage order.'''
        return self.real_stride_ratios

    def GetOutputFeatures(self):
        '''Return the number of output channels of each stage, in stage order.'''
        return self.output_features

    def forward(self,x):
        '''Encode ``x``.

        Returns the list of per-stage feature maps (shallowest first). When
        ``do_cls`` is set, returns instead the ``fc0`` output computed from
        the globally average-pooled last stage.
        '''
        output = []
        features = x
        for idx in range(len(self.output_features)):
            features = getattr(self, 'encoder_layer%d' % idx)(features)
            output.append(features)
        if self.do_cls:
            gap_x = self.avgpool(output[-1])
            gap_x = gap_x.view(x.size(0), -1)
            output = self.fc0(gap_x)
        return output
    
if __name__ == "__main__":
    # Smoke test: build the encoder from the project's JSON model configuration.
    # NOTE: this module uses a package-relative import (``from ..baseLayers...``),
    # so it must be run as a module of its package (``python -m ...``), not as a
    # plain script.
    cfg_path = '../../files/modelParams/model_parameters.json'  # resolved against the current working directory
    model_name = "NoduleSegEncoder_proxima"
    # Machine-specific location of the RWconfig helper module; adjust for your checkout.
    sys.path.append('/hdd/disk6/zrx/projects/Pulmonary_Nodule/develop_code/grouplung/lungNoduleDensityClassification/nodulecls/utils/')
    from RWconfig import LoadJson
    cfg = LoadJson(cfg_path)[model_name]

    model = NoduleSegEncoder_proxima(model_params=cfg)
    # To print a layer summary (needs the optional third-party ``torchsummary``), e.g.:
    #     from torchsummary import summary
    #     summary(model.cuda(), (1, 24, 24, 24))