import torch
import torch.nn as nn
import torch.nn.functional as F

from ..baseLayers.BasicModules import *


class UnetPPUpSample(nn.Module):
    '''
    Single UNet++ decoder node: upsample the tensor coming from the coarser
    stage, concatenate it with all same-stage tensors from earlier columns,
    then refine the result with a ConvStage block.

    base_params: shared ConvStage settings, including negative_slope/kernel_size/......
    '''

    def __init__(self, **kwargs):
        super().__init__()
        self.base_params = kwargs.get('base_params')
        self.upsample_rate = kwargs.get('upsample_rate')
        self.num_in_filters = kwargs.get('num_in_filters')
        self.num_out_filters = kwargs.get('num_out_filters')
        self.num_layers = kwargs.get('num_layers')
        dropout_rate = kwargs.get('dropout_rate', 0.0)

        # Per-node channel counts and depth are written into the ConvStage
        # settings just before the block is built.
        self.base_params['num_in_features'] = self.num_in_filters
        self.base_params['num_out_features'] = self.num_out_filters
        self.base_params['repeat_times'] = self.num_layers
        self.convLayer = ConvStage(**self.base_params)
        if dropout_rate > 0:
            self.dropout_layer = nn.Dropout3d(p=dropout_rate)
        else:
            self.dropout_layer = None

    def forward(self, x, same_stage_tensors):
        if self.upsample_rate > 1:
            x = F.interpolate(x, scale_factor=self.upsample_rate)
        if len(same_stage_tensors) > 0:
            # Dense skip connection: concatenate along the channel axis.
            x = torch.cat([x] + same_stage_tensors, 1)
        output = self.convLayer(x)
        if self.dropout_layer is not None:
            output = self.dropout_layer(output)
        return output


class UnetPP(nn.Module):
    '''
    UNet++ decoder built on top of an encoder that exposes GetRealStrides()
    and GetOutputFeatures(). Two 1x1x1 heads on the top row are returned for
    deep supervision.
    '''

    def __init__(self, **kwargs):
        super().__init__()
        model_parameters = kwargs.get('model_params')
        self.model_parameters = model_parameters
        # Settings shared by every ConvStage block in the decoder.
        self.basic_model_params = {
            'conv_func': model_parameters.get('conv_func', 'basic'),
            'norm_choice': model_parameters.get('norm_choice'),
            'activation_func': model_parameters.get('activation_func', None),
            'kernel_size': model_parameters.get('kernel_size', 3),
            'norm_axis': model_parameters.get('norm_axis'),
            'negative_slop': model_parameters.get('negative_slop'),
        }
        self.layer_repeat_times = model_parameters.get('layer_repeat_times')
        encoder_model = model_parameters.get('encoder_model')
        self.dropout_rate = model_parameters.get('dropout_rate', 0.0)
        self.real_strides = encoder_model.GetRealStrides()[1:]
        output_features = encoder_model.GetOutputFeatures()
        self.output_features = output_features
        self.encoder_model = encoder_model

        num_stages = len(self.real_strides)
        # xnet_layers[layer_idx][col_idx] holds decoder node X(layer_idx, col_idx);
        # column 0 is the encoder backbone and therefore stays None.
        self.xnet_layers = [[None for _ in range(num_stages)] for _ in range(num_stages)]
        for col_idx in range(1, num_stages):
            for layer_idx in range(num_stages - col_idx - 1, -1, -1):
                current_layer = self._GetUnetPPInsideLayer(layer_idx, col_idx)
                self.xnet_layers[layer_idx][col_idx] = current_layer
                # Register the node explicitly so that .to()/.cuda()/state_dict()
                # see its parameters; a plain nested list hides them from PyTorch.
                self.add_module('x_{}_{}'.format(layer_idx, col_idx), current_layer)

        ################## output layers (deep supervision heads)
        final_kernel_size = 1
        self.output_layer0 = nn.Conv3d(in_channels=output_features[0], out_channels=model_parameters['seg_num_class'],
                                       kernel_size=final_kernel_size, padding=[int(final_kernel_size / 2) for _ in range(3)])
        self.output_layer1 = nn.Conv3d(in_channels=output_features[0], out_channels=model_parameters['seg_num_class'],
                                       kernel_size=final_kernel_size, padding=[int(final_kernel_size / 2) for _ in range(3)])

    def _GetUnetPPInsideLayer(self, layer_idx, col_idx):
        # Input channels = col_idx same-stage tensors plus one upsampled tensor
        # from the stage below.
        input_features = self.output_features[layer_idx] * col_idx + self.output_features[layer_idx + 1]
        output_features = self.output_features[layer_idx]
        current_layer = UnetPPUpSample(base_params=dict(self.basic_model_params),  # copy so nodes do not share mutable state
                                       upsample_rate=self.real_strides[layer_idx],
                                       num_in_filters=input_features, num_out_filters=output_features,
                                       num_layers=self.layer_repeat_times, dropout_rate=self.dropout_rate)
        return current_layer

    def forward(self, x):
        encoder_tensors = self.encoder_model(x)
        num_stages = len(self.xnet_layers)
        xnet_tensors = [[None for _ in range(num_stages)] for _ in range(num_stages)]
        # Column 0 is filled with the encoder outputs.
        for layer_idx in range(num_stages):
            xnet_tensors[layer_idx][0] = encoder_tensors[layer_idx]
        # Fill the nested grid column by column, from the deepest row upwards.
        for col_idx in range(1, num_stages):
            for layer_idx in range(num_stages - col_idx - 1, -1, -1):
                same_stage_tensors = xnet_tensors[layer_idx][:col_idx]
                current_tensor = xnet_tensors[layer_idx + 1][col_idx - 1]
                xnet_tensors[layer_idx][col_idx] = self.xnet_layers[layer_idx][col_idx](current_tensor, same_stage_tensors)
        # Deep supervision: heads on the last two nodes of the top row.
        output0 = self.output_layer0(xnet_tensors[0][-1])
        output1 = self.output_layer1(xnet_tensors[0][-2])
        return output0, output1
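

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). It shows how UnetPP expects to be wired to
# an encoder: the encoder must expose GetRealStrides() and GetOutputFeatures()
# and return one feature map per stage, ordered from highest resolution to
# lowest. DummyEncoder and every value in model_params below are hypothetical
# placeholders, and the forward pass assumes ConvStage (imported from
# BasicModules above) accepts the keys collected in basic_model_params. Because
# this module uses relative imports, the sketch would normally live in a
# separate script rather than be executed here directly.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    class DummyEncoder(nn.Module):
        '''Hypothetical stand-in for the project's real encoder.'''

        def __init__(self, in_channels=1, features=(16, 32, 64), strides=(1, 2, 2)):
            super().__init__()
            self.features = list(features)
            self.strides = list(strides)
            channels = [in_channels] + self.features[:-1]
            self.stems = nn.ModuleList([
                nn.Conv3d(channels[idx], self.features[idx], kernel_size=3,
                          stride=self.strides[idx], padding=1)
                for idx in range(len(self.features))
            ])

        def GetRealStrides(self):
            # UnetPP reads the strides from index 1 onwards.
            return self.strides

        def GetOutputFeatures(self):
            return self.features

        def forward(self, x):
            outputs = []
            for stem in self.stems:
                x = stem(x)
                outputs.append(x)
            return outputs

    model_params = {
        # ConvStage settings (placeholder values; the accepted choices are
        # defined in BasicModules).
        'conv_func': 'basic',
        'norm_choice': None,
        'activation_func': None,
        'kernel_size': 3,
        'norm_axis': 1,
        'negative_slop': 0.01,
        # Decoder settings.
        'layer_repeat_times': 2,
        'dropout_rate': 0.1,
        'seg_num_class': 2,
        'encoder_model': DummyEncoder(),
    }
    model = UnetPP(model_params=model_params)
    volume = torch.randn(1, 1, 16, 32, 32)  # (batch, channel, D, H, W)
    output0, output1 = model(volume)
    print(output0.shape, output1.shape)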