import os
import sys
import six
import json
import glob
import shutil
import logging
import datetime
import numpy as np
import scipy.ndimage as nd
import SimpleITK as sitk  # sitk is used throughout below but was missing from the original imports
from tqdm import tqdm
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from tensorboard_logger import Logger

sys.path.append('../../lungNoduleDensityClassification/nodulecls/utils')
from RWconfig import LoadJson, WriteJson
from OfflineEval import *

sys.path.append('utils')
sys.path.append('model')
from modelBuild import build_model
from baseLayers.BasicModules import weight_init
from SegmentationLoss import Tversky_loss, DiceLoss, SurfaceLoss, DiceLoss_TwoLabels
from SegGenCOVIDV2 import SegGenCOVID
from ImagePreprocess import *
from ResultEval import dice_coefficient_post
from respacing_func import *
from cls_loss import *


class Segmentation3D(object):
    def __init__(self, filename, mode, logger):
        self.filename = filename
        self.logger = logger
        self.mode = mode
        self.parseParams()
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    def __call__(self):
        self._makeTensorBoard()
        self.buildModel()
        self.compileModel()
        self.loadData()
        if self.mode == 'training':
            self.Train()
        else:
            self.MakeInference()

    def _makeTensorBoard(self):
        tensorboard_path = "%s/tensorboard_logs" % self.save_path
        self.tensorboard_logger = Logger(logdir=tensorboard_path, flush_secs=10)
        self.logger.info('tensorboard path is %s' % tensorboard_path)

    def parseParams(self):
        cfg = LoadJson(self.filename)
        self.cfg = cfg[self.mode]
        self.common_cfg = cfg['common']
        if self.mode == 'training':
            current_DT = datetime.datetime.now()
            self.current_day = '%02d%02d' % (current_DT.month, current_DT.day)
            model_order = '%02d' % self.cfg['model_order']
            self.save_path = self.common_cfg['base_path'] + '/' + self.current_day + '_' + \
                self.cfg['model_name'] + '/' + str(model_order) + '_' + self.cfg['prefix']
        else:
            self.pre_model_file = self.cfg['pre_model_file']
            self.save_path = '/'.join(self.cfg['pre_model_file'].split('/')[:-1])
        self.logger.debug('save path is %s' % self.save_path)
        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)
        # keep a copy of the config next to the run artifacts
        target_filename = os.path.join(self.save_path, self.filename.split('/')[-1])
        shutil.copy(self.filename, target_filename)
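    # parseParams() above assumes a JSON config with one section per mode plus a
    # shared 'common' section. A minimal sketch of the assumed layout follows;
    # key names are inferred from the reads in this file, values are hypothetical:
    _EXAMPLE_CONFIG = {
        'common': {
            'base_path': './runs',
            'normalizer': [-1200, 600],  # [HU_min, HU_max]
            'train_file': 'train.json', 'val_file': 'val.json', 'test_file': 'test.json',
            'model_parameters': 'model_cfg.json',
            'encoder': 'Encoder3D', 'decoder': 'Decoder3D',
            'initial_learning_rate': 1e-4,
            'learning_rate_patience': 5, 'learning_rate_drop': 0.5,
            'T_loss_alpha': 0.3,
        },
        'training': {
            'model_order': 1, 'model_name': 'seg', 'prefix': 'baseline',
            'num_slice': 32, 'voxel_size': 128,
            'train_batch_size': 2, 'val_batch_size': 2,
            'n_epochs': 50, 'save_cp': True,
        },
    }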
    def loadData(self):
        input_size = [self.cfg['num_slice'], self.cfg['voxel_size'], self.cfg['voxel_size']]
        generator_parameters = {
            'aug': self.cfg.get('aug', False),
            'HU_min': self.common_cfg['normalizer'][0],
            'HU_max': self.common_cfg['normalizer'][1],
            'input_size': input_size,
            'mode': self.mode,
            'repeat_times': self.cfg.get('repeat_times', 4),
            'estimate_patch_per_img': self.common_cfg.get('estimate_patch_per_img', 12),
            'area_th_val': self.cfg.get('area_th_val', 0),
            'use_npy': self.common_cfg.get('use_npy', False),
            'val_mode': self.cfg.get('val_mode', 1),
            'infer_stride': self.cfg.get('infer_stride', 1),
            'learn_feature': self.cfg.get('learn_feature', False),
            'bbox_margin': self.cfg.get('bbox_margin', 0),
            'norm_func': self.cfg.get('norm_func', 'HU'),
            'center_region_choice': self.cfg.get('center_region_choice', False),
            'patch_pos_ratio_range': eval(self.cfg.get('patch_pos_ratio_range', "[0,1]")),
            'cache_size': self.cfg.get('cache_size', 100),
            'logger': self.logger,
        }
        train_data_files = self.common_cfg["train_file"]
        val_data_files = self.common_cfg["val_file"]
        test_data_files = self.common_cfg["test_file"]

        train_paths = self._parseDataJson(train_data_files)
        val_paths = self._parseDataJson(val_data_files)
        test_paths = self._parseDataJson(test_data_files)
        print('train_paths is', train_paths)
        print('val_paths is', val_paths)
        print('test_paths is', test_paths)
        self.test_paths = test_paths
        if self.mode == 'training':
            self.train_generator = SegGenCOVID(df_paths=train_paths, **generator_parameters)
            generator_parameters['aug'] = False
            generator_parameters['mode'] = 'val'
            print('=' * 60)
            print('length of train generator is', len(self.train_generator))
            self.val_generator = SegGenCOVID(df_paths=val_paths, **generator_parameters)
            print('length of val generator is', len(self.val_generator))
        else:
            self.test_generator = SegGenCOVID(df_paths=test_paths, **generator_parameters)

    def _parseDataJson(self, filename):
        # copy the data-list file into the run folder, then collect the df paths
        df_files = []
        target_filename = os.path.join(self.save_path, filename.split('/')[-1])
        filename = 'files/%s' % filename
        files = LoadJson(filename)
        shutil.copy(filename, target_filename)
        for record in files:
            df_files.append(record["df_path"])
        return df_files

    def _findLast(self, paths):
        # checkpoints are saved as CP_epoch{n}_loss_{x.xx}.pth (see Train below),
        # so parse out the epoch number and return the newest checkpoint
        length = len('CP_epoch')
        vals = [int(x.split('/')[-1][length:].split('_')[0].split('.')[0]) for x in paths]
        max_val = max(vals)
        idx = [idx for idx in range(len(vals)) if vals[idx] == max_val][0]
        return paths[idx]

    def loadModel(self, path=None):
        if path is not None:
            model_path = path
            pretrain = model_path
        else:
            pretrain = self.cfg['pre_model_file']
            self.pretrain = pretrain
            model_path = self.pretrain
        print('path is %s' % path)
        print('model path is %s' % model_path)
        if os.path.exists(model_path):
            model_path = model_path
        elif len(glob.glob(pretrain + '/*pth')) > 0:
            # a directory was given: pick the latest checkpoint inside it
            paths = glob.glob(pretrain + '/*pth')
            model_path = self._findLast(paths)
        else:
            model_path = None
        print('=' * 60)
        print('model_path is %s' % model_path)
        if model_path is not None:
            try:
                self.model.to(device=self.device)
                print('put model to device')
                self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            except Exception:
                self.logger.error('Cannot load pretrain weight %s' % model_path)
    def Train(self):
        train_batch_size = self.cfg['train_batch_size']
        val_batch_size = self.cfg['val_batch_size']
        num_epoch = self.cfg['n_epochs']
        compare_choice = self.cfg.get('compare_choice', False)
        cls_choice = self.common_cfg.get('do_cls', False)
        flooding_val = self.cfg.get('flooding_val', 0)
        loss_names = {
            0: 'branch_1_dice_loss', 1: 'branch_2_dice_loss', 2: 'branch_3_dice_loss',
            3: 'branch_1_surface_loss', 4: 'branch_2_surface_loss', 5: 'branch_3_surface_loss',
            6: 'branch_1_trans_dice_loss', 7: 'branch_2_trans_dice_loss', 8: 'branch_3_trans_dice_loss',
            9: 'branch_1_trans_surface_loss', 10: 'branch_2_trans_surface_loss', 11: 'branch_3_trans_surface_loss',
            12: 'branch_1_sim_loss', 13: 'branch_2_sim_loss', 14: 'branch_3_sim_loss',
            15: 'cls_loss',
        }

        # dataloaders
        train_Loader = DataLoader(self.train_generator, batch_size=train_batch_size, shuffle=True)
        val_Loader = DataLoader(self.val_generator, batch_size=val_batch_size, shuffle=True)
        n_train = len(self.train_generator)
        label_dtype = torch.long

        # load pretrained weights if configured
        self.loadModel()
        for epoch in range(num_epoch):
            self.model.train()
            self.model.to(self.device)
            epoch_loss = 0
            self.train_losses = [[] for _ in range(len(loss_names))]
            self.val_losses = [[] for _ in range(len(loss_names))]
            with tqdm(total=int(len(self.train_generator) / train_batch_size),
                      desc=f'Epoch{epoch+1}/{num_epoch}', unit='img') as pbar:
                batch_count = 0
                # zip() wraps each batch in a 1-tuple; downstream code indexes batch[0]
                for batch in zip(train_Loader):
                    batch_count += 1
                    if not cls_choice:
                        if not compare_choice:
                            loss = self.TrainSingle(batch)[0]
                        else:
                            loss = self.TrainSingleWithConstrain(batch)[0]
                    else:
                        loss = self.TrainSingleWithCls(batch, compare_choice, epoch, batch_count)
                    self.optimizer.zero_grad()
                    # "flooding" regularizer: gradients are taken on |loss - b| + b,
                    # which keeps the training loss from sinking below flooding_val
                    flood = abs(loss - flooding_val) + flooding_val
                    flood.backward()
                    epoch_loss += loss.item()
                    self.optimizer.step()
                    pbar.update()
            train_loss = epoch_loss / float(n_train)
            self.model.eval()
            val_loss = self.Val(val_Loader, epoch, num_epoch)
            self.model.train()
            self.scheduler.step(val_loss)
            self.logger.info("Train loss:{}".format(train_loss))
            self.logger.info("Validation loss:{}".format(val_loss))
            self.tensorboard_logger.log_value('train_loss', train_loss, epoch)
            self.tensorboard_logger.log_value('val_loss', val_loss, epoch)
            self._AddTensorData(self.train_losses, loss_names, 'train', epoch)
            self._AddTensorData(self.val_losses, loss_names, 'val', epoch)
            if self.cfg['save_cp']:
                model_save_path_current_epoch = self.save_path + f'/CP_epoch{epoch+1}_loss_%.2f.pth' % val_loss
                torch.save(self.model.state_dict(), model_save_path_current_epoch)
                self.logger.info(f'Checkpoint{epoch+1} saved to %s!' % model_save_path_current_epoch)

    def Val(self, val_Loader, epoch, num_epoch):
        compare_choice = self.cfg.get('compare_choice', False)
        cls_choice = self.common_cfg.get('do_cls', False)
        self.model.eval()
        total_loss = 0
        n_val = len(self.val_generator)
        with tqdm(total=int(len(self.val_generator) / self.cfg['val_batch_size']),
                  desc=f'Epoch{epoch+1}/{num_epoch}', unit='img') as pbar:
            batch_count = 0
            for batch in zip(val_Loader):
                batch_count += 1
                if not cls_choice:
                    if not compare_choice:
                        loss = self.TrainSingle(batch, mode='val')[0]
                    else:
                        loss = self.TrainSingleWithConstrain(batch, mode='val')[0]
                else:
                    loss = self.TrainSingleWithCls(batch, compare_choice, epoch, batch_count, mode='val')
                total_loss += loss.item()
                pbar.update()
        self.model.train()
        return total_loss / float(n_val)

    def _getBatchData(self, mode='training'):
        # draw a random comparison batch directly from the generator
        batch_size = self.cfg.get('train_batch_size')
        current_generator = self.train_generator if mode == 'training' else self.val_generator
        indices = np.arange(len(current_generator))
        chosen_indices = np.random.choice(indices, size=batch_size, replace=False)
        imgs, masks, labels = [], [], []
        for idx in chosen_indices:
            patch, mask, label = current_generator[idx]
            imgs.append(patch)
            masks.append(mask)
            labels.append(label)
        imgs = torch.from_numpy(np.array(imgs))
        masks = torch.from_numpy(np.array(masks))
        labels = torch.from_numpy(np.array(labels))
        return [[imgs, masks, labels]]

    def _getBatchDataV2(self, mode='training'):
        # draw a random comparison batch from the generator's sample cache
        batch_size = self.cfg.get('train_batch_size')
        current_generator = self.train_generator if mode == 'training' else self.val_generator
        indices = np.arange(len(current_generator.cache))
        chosen_indices = np.random.choice(indices, size=batch_size, replace=False)
        samples = np.take(current_generator.cache, chosen_indices, axis=0)
        imgs = torch.from_numpy(np.array([sample[0] for sample in samples]))
        masks = torch.from_numpy(np.array([sample[1] for sample in samples]))
        labels = torch.from_numpy(np.array([sample[2] for sample in samples]))
        return [[imgs, masks, labels]]
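    # Train() uses the "flooding" trick (Ishida et al., 2020): the backward pass
    # runs on |loss - b| + b, so the gradient reverses once the loss dips below b.
    # A minimal standalone sketch of the objective; this hypothetical helper is
    # for illustration only and is not called anywhere in this file:
    @staticmethod
    def _flooding_sketch(loss, flooding_val=0.05):
        # identical gradient to `loss` while loss > b, reversed below b
        return (loss - flooding_val).abs() + flooding_val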
    def TrainSingleWithCls(self, batch, compare_choice, epoch, batch_count, mode='training'):
        cache_size = self.cfg.get('cache_size', 100)
        batch_size = self.cfg.get('train_batch_size')
        target_batch_count = int(cache_size / batch_size * 0.8)
        # Classification is run over two batches: (1) positive/negative
        # classification and (2) a similarity loss. batchA is the current batch;
        # batchB is drawn from history (a rolling cache) once the cache is warm.
        if epoch > 0 or batch_count >= target_batch_count:
            batch_compare = self._getBatchDataV2(mode)
        else:
            batch_compare = self._getBatchData(mode)
        if compare_choice:
            loss0, cls_pred0 = self.TrainSingleWithConstrain(batch, mode)
            loss_compare, cls_pred_compare = self.TrainSingleWithConstrain(batch_compare, mode)
        else:
            loss0, cls_pred0 = self.TrainSingle(batch, mode)
            loss_compare, cls_pred_compare = self.TrainSingle(batch_compare, mode)
        cls_label0 = batch[0][-1].to(device=self.device, dtype=torch.float32)
        cls_label_compare = batch_compare[0][-1].to(device=self.device, dtype=torch.float32)
        total_loss = loss0 + loss_compare
        # pair up the two batches' predictions and labels for the cls criterion
        loss_input = [[cls_pred0, cls_pred_compare]] + [[cls_label0, cls_label_compare]]
        cls_loss_val = self.cls_criterion(loss_input)
        cls_loss_val = self.common_cfg.get('cls_loss_ratio', 0.1) * cls_loss_val
        if mode == 'training':
            self.train_losses[-1].append(cls_loss_val.item())
        else:
            self.val_losses[-1].append(cls_loss_val.item())
        total_loss += cls_loss_val
        return total_loss
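    # _getBatchDataV2() above assumes SegGenCOVID maintains a rolling cache of
    # recent (img, mask, label) samples. A minimal sketch of such a cache update,
    # assuming a plain list of fixed capacity (hypothetical helper, not the
    # generator's actual implementation):
    @staticmethod
    def _cache_update_sketch(cache, sample, cache_size=100):
        cache.append(sample)        # newest sample in
        if len(cache) > cache_size:
            cache.pop(0)            # oldest sample out
        return cache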
    def TrainSingleWithConstrain(self, batch, mode='training'):
        loss_weights = eval(self.cfg.get('loss_weights', '[1.0,1.0,1.0]'))
        calculate_surface_loss = self.common_cfg.get('surface_loss', False)
        surface_loss_ratio = self.common_cfg.get('e_value', 0.1)
        compare_loss_ratio = self.common_cfg.get('comapre_loss_ratio', 0.1)
        cls_choice = self.common_cfg.get('do_cls', False)
        surface_criterion = SurfaceLoss()
        compare_loss = torch.nn.SmoothL1Loss()
        data, label, cls_label = batch[0]

        # apply a random spatial transform to each image/label pair
        trans_imgs, trans_labels, trans_dicts = [], [], []
        for inner_batch_idx in range(data.shape[0]):
            trans_img, trans_label, trans_dict = self.train_generator._transImage(
                data[inner_batch_idx], label[inner_batch_idx])
            trans_imgs.append(trans_img)
            trans_labels.append(trans_label)
            trans_dicts.append(trans_dict)

        # surface loss needs signed distance maps of the labels
        if calculate_surface_loss:
            if label.shape[0] > 1:
                distance_map_label = calc_dist_map_batch(np.squeeze(label.cpu().numpy()))
                distance_map_label_trans = calc_dist_map_batch(np.squeeze(trans_labels))
            else:
                distance_map_label = calc_dist_map(np.squeeze(label.cpu().numpy()))
                distance_map_label_trans = calc_dist_map(np.squeeze(trans_labels))
            if len(distance_map_label.shape) == 4:
                distance_map_label = distance_map_label[:, np.newaxis, ...]
                distance_map_label_trans = distance_map_label_trans[:, np.newaxis, ...]
            distance_map_label = torch.from_numpy(distance_map_label).to(device=self.device, dtype=torch.float32)
            distance_map_label_trans = torch.from_numpy(distance_map_label_trans).to(device=self.device, dtype=torch.float32)

        # move tensors to the target device
        data = data.to(device=self.device, dtype=torch.float32)
        label = label.to(device=self.device, dtype=torch.float32)
        trans_imgs = torch.tensor(np.array(trans_imgs)).to(device=self.device, dtype=torch.float32)
        trans_labels = torch.tensor(np.array(trans_labels)).to(device=self.device, dtype=torch.float32)

        if mode == 'val':
            with torch.no_grad():
                pred = self.model(data)
                if cls_choice:
                    cls_pred = pred[-1]
                    pred = pred[:-1]
                else:
                    cls_pred = None
                pred_trans = self.model(trans_imgs)
        else:
            pred = self.model(data)
            if cls_choice:
                cls_pred = pred[-1]
                pred = pred[:-1]
            else:
                cls_pred = None
            pred_trans = self.model(trans_imgs)

        cpu_pred = [val.data.cpu().numpy() for val in pred]
        cpu_pred_trans = []
        # apply the same spatial transform to the predictions on the original
        # images so they can be compared against pred_trans
        for branch_idx in range(len(pred)):
            current_branch_result = []
            for inner_batch_idx in range(data.shape[0]):
                current_cpu_pred_trans = self.train_generator._transWithFunc(
                    cpu_pred[branch_idx][inner_batch_idx], trans_dicts[inner_batch_idx])
                current_branch_result.append(np.array(current_cpu_pred_trans).astype(np.double))
            current_branch_result = np.array(current_branch_result)
            cpu_pred_trans.append(torch.tensor(current_branch_result).to(device=self.device, dtype=torch.float32))
        gpu_pred_trans = cpu_pred_trans

        for idx in range(len(pred)):
            loss_val = self.criterion(pred[idx], label)
            loss_val_trans = self.criterion(pred_trans[idx], trans_labels)
            # the original logged loss_val twice in val and loss_val_trans twice
            # in training; corrected so the *2 slots hold the transformed loss
            if mode == 'val':
                self.val_losses[idx].append(loss_val.item())
                self.val_losses[idx + len(pred) * 2].append(loss_val_trans.item())
            else:
                self.train_losses[idx].append(loss_val.item())
                self.train_losses[idx + len(pred) * 2].append(loss_val_trans.item())
            if idx == 0:
                total_loss = loss_val * loss_weights[idx] + loss_val_trans * loss_weights[idx]
            else:
                total_loss += loss_val * loss_weights[idx] + loss_val_trans * loss_weights[idx]
            if calculate_surface_loss:
                surface_loss = surface_criterion(pred[idx], distance_map_label)
                # original compared pred[idx] against the transformed distance
                # map; pred_trans[idx] is the matching prediction
                surface_loss_trans = surface_criterion(pred_trans[idx], distance_map_label_trans)
                total_loss += surface_loss_ratio * (surface_loss + surface_loss_trans)
                if mode == 'val':
                    self.val_losses[idx + len(pred)].append(surface_loss.item() * surface_loss_ratio)
                    self.val_losses[idx + len(pred) * 3].append(surface_loss_trans.item() * surface_loss_ratio)
                else:
                    self.train_losses[idx + len(pred)].append(surface_loss.item() * surface_loss_ratio)
                    self.train_losses[idx + len(pred) * 3].append(surface_loss_trans.item() * surface_loss_ratio)
            # consistency ("similarity") loss between transformed predictions
            trans_loss = compare_loss(gpu_pred_trans[idx], pred_trans[idx])
            if idx == len(pred) - 1:
                total_loss += trans_loss * compare_loss_ratio
                if mode == 'val':
                    self.val_losses[idx + len(pred) * 4].append(trans_loss.item() * compare_loss_ratio)
                else:
                    self.train_losses[idx + len(pred) * 4].append(trans_loss.item() * compare_loss_ratio)
        total_loss /= (2 * len(pred))
        return total_loss, cls_pred
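    # calc_dist_map / calc_dist_map_batch come from ImagePreprocess (star import).
    # A sketch of the boundary-loss formulation they are assumed to implement
    # (signed distance maps in the style of Kervadec et al.); this is an
    # assumption about the helper, not the repo's actual code:
    @staticmethod
    def _calc_dist_map_sketch(seg):
        posmask = np.asarray(seg).astype(bool)
        if not posmask.any():
            return np.zeros(posmask.shape, dtype=np.float64)
        negmask = ~posmask
        # positive distances outside the object, negative inside
        return (nd.distance_transform_edt(negmask) * negmask
                - (nd.distance_transform_edt(posmask) - 1) * posmask)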
    def TrainSingle(self, batch, mode='training'):
        loss_weights = eval(self.cfg.get('loss_weights', '[1.0,1.0,1.0]'))
        val_mode = self.cfg.get('val_mode', 1)
        cls_choice = self.common_cfg.get('do_cls', False)
        calculate_surface_loss = self.common_cfg.get('surface_loss', False)
        surface_loss_ratio = self.common_cfg.get('e_value', 0.1)
        data, label, cls_label = batch[0][:3]
        surface_criterion = SurfaceLoss()
        if calculate_surface_loss:
            if label.shape[0] > 1:
                distance_map_label = calc_dist_map_batch(np.squeeze(label.cpu().numpy()))
            else:
                distance_map_label = calc_dist_map(np.squeeze(label.cpu().numpy()))
            if len(distance_map_label.shape) == 4:
                distance_map_label = distance_map_label[:, np.newaxis, ...]
            distance_map_label = torch.from_numpy(distance_map_label).to(device=self.device, dtype=torch.float32)
        if mode == 'val':
            with torch.no_grad():
                if val_mode == 1:
                    data = data.to(device=self.device, dtype=torch.float32)
                    label = label.to(device=self.device, dtype=torch.float32)
                    pred = self.model(data)
                else:
                    # NOTE: cp_label is undefined in the original source; it
                    # presumably holds a numpy copy of the label for this mode
                    label = torch.from_numpy(cp_label).to(device=self.device, dtype=torch.float32)
                    # run the patches one by one, then concatenate per branch
                    new_data = [single_data.unsqueeze(0).to(device=self.device, dtype=torch.float32)
                                for single_data in data]
                    pred_list = [[] for _ in range(3)]
                    for single_data in new_data:
                        result = self.model(single_data)
                        for inner_idx in range(len(pred_list)):
                            pred_list[inner_idx].append(result[inner_idx])
                    pred = []
                    for inner_idx in range(len(pred_list)):
                        pred.append(torch.cat(pred_list[inner_idx], 0))
        else:
            # the original never moved the batch to the device in training mode
            data = data.to(device=self.device, dtype=torch.float32)
            label = label.to(device=self.device, dtype=torch.float32)
            pred = self.model(data)
        total_loss = 0
        if cls_choice:
            cls_pred = pred[-1]
            pred = pred[:-1]
        else:
            cls_pred = None
        for idx in range(len(pred)):
            loss_val = self.criterion(pred[idx], label)
            if mode == 'val':
                self.val_losses[idx].append(loss_val.item())
            else:
                self.train_losses[idx].append(loss_val.item())
            if idx == 0:
                total_loss = loss_val * loss_weights[idx]
            else:
                total_loss += loss_val * loss_weights[idx]
            if calculate_surface_loss:
                surface_loss = surface_criterion(pred[idx], distance_map_label)
                total_loss += surface_loss_ratio * surface_loss
                if mode == 'val':
                    self.val_losses[idx + len(pred)].append(surface_loss.item() * surface_loss_ratio)
                else:
                    self.train_losses[idx + len(pred)].append(surface_loss.item() * surface_loss_ratio)
        total_loss /= len(pred)
        return total_loss, cls_pred

    def compileModel(self):
        union_choice = self.cfg.get('union_choice', False)
        lr_schedule = self.cfg.get('lr_schedule', 'basic')
        self.learn_feature = self.cfg.get('learn_feature', False)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.common_cfg['initial_learning_rate'])
        if lr_schedule == 'basic':
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, patience=self.common_cfg['learning_rate_patience'],
                factor=self.common_cfg['learning_rate_drop'], verbose=True)
        else:
            self.scheduler = optim.lr_scheduler.CyclicLR(
                self.optimizer,
                base_lr=self.common_cfg.get('base_lr', 1e-5),
                max_lr=self.common_cfg.get('max_lr', 1e-4),
                step_size_up=self.common_cfg.get('step_size_up', 2000),
                cycle_momentum=False)
        if self.common_cfg.get('do_cls', False):
            self.cls_criterion = SeLossWithBCEFocalLoss()
        if union_choice:
            # self.criterion = DiceLoss_TwoLabels(alpha=self.common_cfg['T_loss_alpha'], beta=1-self.common_cfg['T_loss_alpha'])
            self.criterion = DiceLoss_TwoLabels(alpha=0.1, beta=1.0)
        else:
            self.criterion = Tversky_loss(alpha=self.common_cfg['T_loss_alpha'],
                                          beta=1 - self.common_cfg['T_loss_alpha'])

    def buildModel(self):
        filename = self.common_cfg['model_parameters']
        target_model_cfg_name = os.path.join(self.save_path, filename.split('/')[-1])
        filename = 'files/%s' % filename
        model_cfg = LoadJson(filename)
        shutil.copy(filename, target_model_cfg_name)
        encoder_model_name = self.common_cfg['encoder']
        decoder_model_name = self.common_cfg['decoder']
        encoder_model_cfg = model_cfg[encoder_model_name]
        decoder_model_cfg = model_cfg[decoder_model_name]
        encoder_model_cfg['do_cls'] = False
        encoder_model = build_model(encoder_model_name, encoder_model_cfg)
        # the decoder wraps the encoder and owns the optional cls head
        decoder_model_cfg['encoder_model'] = encoder_model
        decoder_model_cfg['do_cls'] = self.common_cfg.get('do_cls', False)
        self.model = build_model(decoder_model_name, decoder_model_cfg)
        self.model.apply(weight_init)
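    # compileModel() selects Tversky_loss(alpha, 1-alpha) from SegmentationLoss.
    # A sketch of a standard Tversky loss for reference; the repo's actual
    # implementation may differ (this assumes sigmoid logits, binary targets):
    @staticmethod
    def _tversky_sketch(logits, target, alpha=0.3, beta=0.7, eps=1e-6):
        prob = torch.sigmoid(logits)
        tp = (prob * target).sum()
        fp = (prob * (1 - target)).sum()
        fn = ((1 - prob) * target).sum()
        # Tversky index: TP / (TP + alpha*FP + beta*FN); the loss is 1 - index
        return 1 - (tp + eps) / (tp + alpha * fp + beta * fn + eps)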
    def MakeSinglePred(self, idx, model_order=0, model_path=None):
        infer_idx = self.cfg.get('infer_idx', -1)
        th = self.cfg.get('th', 0.05)
        respacing_back = self.cfg.get('respacing_back', False)
        print('=' * 60)
        # load one test case
        (output_img_patches, output_mask_patches, pad_mask, pad_vals,
         bounds, image, lungmask, mask_path) = self.test_generator[idx]
        uid = mask_path.split('/')[-1][:-len('.nii.gz')]
        predictions = []
        # run inference patch by patch
        for patch_idx in range(len(output_img_patches)):
            data = torch.from_numpy(output_img_patches[patch_idx][np.newaxis, np.newaxis, ...]).to(
                device=self.device, dtype=torch.float32)
            pred = self.model(data)
            result_before_sigmoid = [val.data.cpu().numpy() for val in pred][infer_idx]
            pred = [torch.sigmoid(val) for val in pred]
            pred = [val.data.cpu().numpy() for val in pred]
            pred = pred[infer_idx]
            predictions.append(pred)
        # stitch the patches back to the original volume size
        result, mask = self.test_generator._CombinePrediction(predictions, pad_mask, pad_vals)
        result, mask = self.test_generator._padBackToOriSize(image, result, mask, pad_vals, bounds)
        # binarize
        raw_result = np.array(result > th).astype(np.uint8)
        # post-process: drop small connected components (not used by the online
        # model so far; the dice change is small)
        raw_result_post = self._PostProcess(raw_result)
        # remove regions outside the lungs
        result = lungmask * result
        th_result = np.array(result >= th).astype(np.uint8)
        # compute dice
        dice_val = dice_coefficient_post(mask, th_result)
        dice_raw = dice_coefficient_post(mask, raw_result)
        dice_post = dice_coefficient_post(mask, raw_result_post)
        self.post_dice.append(dice_post)
        # if requested, respace the mask back to the original image grid
        if respacing_back:
            img_path = sorted(glob.glob('/fileser/CT_COVID/IMAGE/RAW/%s*' % uid))[0]
            lungmask_path = sorted(glob.glob('/fileser/CT_COVID/LUNGMASK/RAW/%s*' % uid))[0]
            img_ori = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
            ori_spac_mask_itk = sitk.ReadImage(lungmask_path)
            ori_spac_mask = sitk.GetArrayFromImage(ori_spac_mask_itk)
            lungmask_ori = ori_spac_mask
            spac = ori_spac_mask_itk.GetSpacing()
            result_itk = sitk.GetImageFromArray(result)
            result_itk.SetSpacing((1, 1, 1))
            ori_spac_result_itk = NearestResample(result_itk, spac, img_ori.shape[::-1])
            ori_spac_result = sitk.GetArrayFromImage(ori_spac_result_itk)
            ori_spac_result = np.array(ori_spac_result > th).astype(np.uint8)
            try:
                # compute dice on the original spacing and save the result
                dice_val_ori = dice_coefficient_post(ori_spac_mask, ori_spac_result)
                self.ori_spac_dice.append(dice_val_ori)
                base_save_path = self.cfg.get('base_save_path', '')
                current_case_save_path = '%s/%s_dice%.2f.nii.gz' % (base_save_path, uid, dice_val_ori * 100)
                itkWriter = sitk.ImageFileWriter()
                itkWriter.SetUseCompression(True)
                itkWriter.SetFileName(current_case_save_path)
                itkWriter.Execute(ori_spac_result_itk)
            except Exception as err:
                print('err ', err)
                print('no valid orispacing result')
        else:
            # save the inference result
            self._WriteInferResult(uid, th_result, dice_val, model_order, model_path)
        self.no_lungmask_dice.append(dice_raw)
        print('dice_val and dice_raw and dice_post is', dice_val, dice_raw, dice_post)
        # bad-case visualization (the dice cutoff below was lost in the original
        # source, likely eaten as an HTML tag; 0.2 is an assumed placeholder)
        if dice_val < 0.2:
            for slice_idx in range(image.shape[0]):
                current_result_slice = np.array(result[slice_idx] > 0.05).astype(np.uint8)
                if np.sum(current_result_slice) == 0 and np.sum(mask[slice_idx]) == 0:
                    continue
                slice_dice_val = dice_coefficient_post(mask[slice_idx], current_result_slice)
                print('slice_dice_val is', slice_dice_val)
                figsize = 15
                _, axs = plt.subplots(2, 2, figsize=(figsize, figsize))
                axs[0][0].imshow(image[slice_idx], cmap='bone')
                axs[0][1].imshow(lungmask[slice_idx], cmap='bone')
                axs[1][0].imshow(current_result_slice, cmap='bone')
                axs[1][1].imshow(mask[slice_idx], cmap='bone')
                plt.show()
        return image, mask, result
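    # NearestResample comes from respacing_func (star import). A sketch of the
    # nearest-neighbour resampling it is assumed to perform with SimpleITK;
    # this is an assumption about the helper, not the repo's actual code:
    @staticmethod
    def _nearest_resample_sketch(itk_img, out_spacing, out_size):
        resampler = sitk.ResampleImageFilter()
        # nearest neighbour keeps label values intact when resampling masks
        resampler.SetInterpolator(sitk.sitkNearestNeighbor)
        resampler.SetOutputSpacing(out_spacing)
        resampler.SetSize([int(s) for s in out_size])
        resampler.SetOutputOrigin(itk_img.GetOrigin())
        resampler.SetOutputDirection(itk_img.GetDirection())
        return resampler.Execute(itk_img)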
    def _WriteInferResult(self, uid, result, dice, model_order=0, model_path=None):
        base_save_path = self.cfg.get('base_save_path', '')
        base_save_path = '%s/%02d/' % (base_save_path, model_order)
        if not os.path.exists(base_save_path):
            os.makedirs(base_save_path)
        if model_path is not None:
            # keep a copy of the checkpoint next to its predictions
            model_name = model_path.split('/')[-1]
            target_model_path = '%s/%s' % (base_save_path, model_name)
            if not os.path.exists(target_model_path):
                shutil.copy(model_path, target_model_path)
        if os.path.exists(base_save_path):
            current_case_save_path = '%s/%s_dice%.2f.nii.gz' % (base_save_path, uid, dice * 100)
            itkWriter = sitk.ImageFileWriter()
            itkWriter.SetUseCompression(True)
            itkWriter.SetFileName(current_case_save_path)
            itk_result = sitk.GetImageFromArray(result)
            itk_result.SetSpacing((1, 1, 1))
            itkWriter.Execute(itk_result)

    def _AddTensorData(self, losses, name_dict, mode, epoch):
        for idx in range(len(losses)):
            if len(losses[idx]) > 0:
                mean_val = np.mean(losses[idx])
                self.tensorboard_logger.log_value('%s_%s' % (mode, name_dict[idx]), mean_val, epoch)

    def _PostProcess(self, mask):
        # drop connected components smaller than region_th voxels; the exact
        # comparison was lost in the original source and "area < region_th" is
        # the assumed reading
        region_th = 100
        from skimage.measure import label, regionprops
        final_mask = np.zeros_like(mask)
        label_mask = label(mask)
        for region in regionprops(label_mask):
            if region.area < region_th:
                label_mask[label_mask == region.label] = 0
        final_mask = np.array(label_mask > 0).astype(np.uint8)
        return np.array(final_mask > 0).astype(np.uint8)

    def _featureMapVis(self, x, layer_idx):
        print('#' * 60)
        print('Feature Map Vis')
        print('#' * 60)
        # truncate the model after layer_idx and run a forward pass
        net = nn.Sequential(*list(self.model.children())[:layer_idx])
        net.eval()
        out = net(x)
        out_cpu = [val.data.cpu().numpy() for val in out]
        return out_cpu

    def _normFeatureMap(self, FeatureMap):
        # weight each channel by its share of the total activation mass
        FeatureMap = np.squeeze(np.array(FeatureMap))
        ChannelSums = [np.sum(val) for val in FeatureMap]
        totalSum = np.sum(ChannelSums)
        normSums = [val / float(totalSum) for val in ChannelSums]
        result = [val1 * val2 for val1, val2 in zip(normSums, FeatureMap)]
        return result

    def MakeInference(self):
        th = 0.05
        pretrain = self.cfg['pre_model_file']
        if '*' in pretrain:
            cand_paths = sorted(glob.glob(pretrain + '*pth'))
        else:
            cand_paths = [pretrain]
        model_dice_vals = []
        for model_order, model_path in enumerate(cand_paths):
            print('#' * 60)
            self.loadModel(model_path)
            save_result_choice = self.cfg.get('save_result_choice', False)
            self.model.eval()
            indices = np.arange(len(self.test_generator))
            dice_vals = []
            image_list, mask_list, pred_list = [], [], []
            self.ori_spac_dice = []
            self.no_lungmask_dice = []
            self.post_dice = []
            label_sums = []
            for idx in indices:
                image, mask, result = self.MakeSinglePred(idx, model_order, model_path)
                image_list.append(image)
                mask_list.append(mask)
                pred_list.append(result)
                result = np.array(result >= th).astype(np.uint8)
                mask = np.array(mask > 0).astype(np.uint8)
                label_sums.append(np.sum(mask))
                print('sum of mask is %d and sum of pred is %d' % (np.sum(mask), np.sum(result)))
                dice_val = dice_coefficient_post(mask, result)
                dice_vals.append(dice_val)
            self.logger.info("mean val of dice vals is %.4f" % (np.mean(dice_vals)))
            model_dice_vals.append(np.mean(dice_vals))
            if len(self.ori_spac_dice) > 0:
                self.logger.info("mean val of orispacing dice vals is %.4f" % (np.mean(self.ori_spac_dice)))
            self.logger.info("mean val of (without lung mask) dice vals is %.4f" % (np.mean(self.no_lungmask_dice)))
            self.logger.info("mean val of (postprocessing) dice vals is %.4f" % (np.mean(self.post_dice)))
            self.logger.info('model path is %s' % model_path)
        if len(model_dice_vals) > 0:
            max_dice_idx = np.argmax(model_dice_vals)
            max_mean_dice = max(model_dice_vals)
            max_dice_model_path = cand_paths[max_dice_idx]
            self.logger.info('Max dice is %.4f and its model path is %s' % (max_mean_dice, max_dice_model_path))
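
# Hedged usage sketch. The entry point and config filename are assumptions;
# the repo may drive Segmentation3D from a separate launcher script instead.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    run_logger = logging.getLogger('Segmentation3D')
    seg = Segmentation3D('files/seg_config.json', 'training', run_logger)
    seg()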