# -*- coding:utf-8 -*-
import os
import sys
import six
import glob
import json
import random
import shutil
import datetime
import logging

import numpy as np
import pandas as pd
import scipy.ndimage as nd
from tqdm import tqdm
import torch
import torch.nn as nn
from torch import optim
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from tensorboard_logger import Logger

sys.path.append('./utils')
sys.path.append('./model')
from RWconfig import LoadJson
from ReadData import load_data, load_single_data
from modelBuild import build_model
from BasicModules import weight_init
from PneuGen import PneuDataSet
from OnlineEval import eval_net, eval_netSE, eval_netFSE
from OfflineEval import CalculateClsScore, CalculateAuc, CalculateClsScoreByTh, PlotRoc
from loss_func import *

# Root logger: every record goes both to SimpleCls.log and to the console.
logger = logging.getLogger()
fh = logging.FileHandler('SimpleCls.log', encoding='utf-8')
sh = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s')
fh.setFormatter(formatter)
sh.setFormatter(formatter)
logger.addHandler(fh)
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)  # == 10


class Classification3D(object):
    """Config-driven trainer / evaluator for 3D classification models.

    The JSON config at ``filename`` is loaded and its ``mode`` section is used.
    Calling the instance runs the whole pipeline (see ``__call__``):
    ``mode == 'training'`` trains and checkpoints, any other mode runs
    inference over the test split.
    """

    def __init__(self, filename, mode):
        """
        Args:
            filename: path of the experiment JSON config.
            mode: key of the config section to use; 'training' selects the
                training pipeline, anything else selects inference.
        """
        self.mode = mode
        self.filename = filename
        self.cfg = LoadJson(self.filename)[mode]
        self.basic_file_path = './files'
        current_DT = datetime.datetime.now()
        # MMDD stamp embedded in the output directory name.
        self.current_day = '%02d%02d' % (current_DT.month, current_DT.day)
        self.n_input_channel = 1
        self._clsInfoDecode()
        self.ParametersDecode()
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    def __call__(self):
        """Run the training / inference pipeline.

        1. build and compile the model
        2. load the data lists
        3. train or run inference depending on the selected mode
        """
        self._makeTensorBoard()
        self._buildModel()
        self._compileModel()
        self._LoadData(self.mode)
        if self.mode == 'training':
            self._trainModel()
        else:
            self._testModel()

    def _makeTensorBoard(self):
        """Create the tensorboard_logger writer under ``<save_path>/tensorboard_logs``."""
        tensorboard_path = "%s/tensorboard_logs" % self.save_path
        self.tensorboard_logger = Logger(logdir=tensorboard_path, flush_secs=10)
        logger.info('tensorboard path is %s' % tensorboard_path)

    def _clsInfoDecode(self):
        """Decode class metadata, create ``self.save_path`` and set ``cfg['num_task']``."""
        # NOTE(review): eval() executes arbitrary code from the config file; if the
        # value is always a plain dict literal, ast.literal_eval would be safer.
        self.cls_name_map_dict = eval(self.cfg['cls_name_map_dict'])
        model_order = self.cfg.get('model_order', 1)
        cls_label_index = self.cfg['cls_label_index']
        if isinstance(cls_label_index, int):
            cls_related_info = self.cls_name_map_dict[cls_label_index]
        else:
            cls_related_info = '_'.join([self.cls_name_map_dict[val] for val in cls_label_index])
        self.save_path = '%s/%s_%s_%dCls_%s_model%s/%02d_%s/' % (
            self.cfg['base_path'], self.cfg['weights_pre'], cls_related_info,
            self.cfg['num_class'], self.current_day, self.cfg['model_name'],
            model_order, self.cfg['weight_memo'])
        os.makedirs(self.save_path, exist_ok=True)
        # A list of label indices means multi-task: one task per index.
        if isinstance(cls_label_index, list):
            self.cfg['num_task'] = len(cls_label_index)
        else:
            self.cfg['num_task'] = 1

    def _buildModel(self):
        """Build ``self.net`` from the model-parameter JSON.

        The parameters loaded by ``ParametersDecode`` are passed to ``build_model``
        after overriding:
          * num_task: number of classification tasks;
          * num_class: number of classes (all tasks must currently share the same
            class count); a 2-class problem is built as a single-logit model;
          * freeze_blocks: blocks that are not trained.
        Other keys pass through unchanged (the original notes mention
        ``arcface_norm_choice`` and ``easy_margin``: with easy_margin=True the
        multiplicative angular margin is used whenever cos_theta > 0, otherwise
        it is used for cos_theta > th and a subtractive margin below th).
        Weights are initialised with ``weight_init`` unless ``weight_path`` is set.
        """
        cfg = self.model_params
        cfg['num_task'] = self.cfg['num_task']
        cfg['num_class'] = self.cfg['num_class']
        # NOTE(review): eval() on a config string; see _clsInfoDecode.
        cfg['freeze_blocks'] = eval(self.cfg.get('freeze_blocks', "[]"))
        if cfg['num_class'] == 2:
            cfg['num_class'] = 1
        self.n_input_channel = cfg.get('n_input_channel', self.n_input_channel)
        self.net = build_model(self.cfg['model_name'], cfg)
        self.net_config = cfg
        self.pretrain_weight_path = self.cfg.get('weight_path', '')
        if self.pretrain_weight_path == '':
            self.net.apply(weight_init)

    def _compileModel(self):
        """Create optimizer, LR scheduler and default criterion; archive config files."""
        self.optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.net.parameters()),
                                    lr=self.cfg['learning_rate'])
        lr_schedule = self.cfg.get('lr_schedule', 'basic')
        if lr_schedule == 'basic':
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, patience=self.cfg['patience'], factor=0.5, verbose=True)
        else:
            # CyclicLR counts its cycle in optimizer steps; _trainModel steps it per batch.
            self.scheduler = optim.lr_scheduler.CyclicLR(
                self.optimizer,
                base_lr=self.cfg.get('base_lr', 1e-5),
                max_lr=self.cfg.get('max_lr', 1e-4),
                step_size_up=self.cfg.get('step_size_up', 2000),
                cycle_momentum=False)
        if self.cfg['num_class'] > 2:
            self.criterion = nn.CrossEntropyLoss()
        else:
            self.criterion = nn.BCEWithLogitsLoss()
        # Keep a copy of the configs next to the weights for reproducibility.
        for conf_path in [self.filename, self.model_param_name]:
            shutil.copy(conf_path, os.path.join(self.save_path, os.path.basename(conf_path)))

    def _DecodeGeneratorParam(self):
        """Collect the keyword arguments passed to ``PneuDataSet`` into ``self.generator_parameters``."""
        patch_shape = [self.cfg['patch_depth'], self.cfg['patch_height'], self.cfg['patch_width']]
        self.generator_parameters = {
            'cls_index': self.cfg['cls_label_index'],
            'aug': self.mode == 'training',
            'num_class': self.cfg['num_class'],
            'HU_min': self.cfg['HU_min'],
            'HU_max': self.cfg['HU_max'],
            'n_input_channel': self.n_input_channel,
            "model_mode": self.cfg.get('model_mode', 'basic'),
            "patch_shape": patch_shape,
            "training_mode": self.cfg.get('training_mode', 2),
            "mode": self.mode,
        }

    def _DecodeDataParam(self, filename):
        """Return the ``data_path`` / ``info_path`` lists from a JSON list of records."""
        paths = LoadJson(filename)
        data_paths = [record['data_path'] for record in paths]
        info_paths = [record['info_path'] for record in paths]
        return {'data_paths': data_paths, 'info_paths': info_paths}

    def _parseDataJosn(self, filename):
        """Return the ``df_path`` of every record in ``<basic_file_path>/<filename>``."""
        files = LoadJson(os.path.join(self.basic_file_path, filename))
        return [record["df_path"] for record in files]

    def _LoadData(self, mode):
        """Resolve the train/val/test record lists (the original note says the data is .npy).

        ``mode`` is currently unused: all three splits are always parsed.
        """
        train_data_files = self.cfg["train_file"]
        val_data_files = self.cfg["val_file"]
        test_data_files = self.cfg["test_file"]
        self.train_paths = self._parseDataJosn(train_data_files)
        self.val_paths = self._parseDataJosn(val_data_files)
        self.test_paths = self._parseDataJosn(test_data_files)
        self.train_data_files = train_data_files
        self.val_data_files = val_data_files
        self.test_data_files = test_data_files

    def _GetLoss(self, prob_diff, gamma):
        """Build the training criterion from the class ratios of the training set.

        Binary problems (num_class <= 2, built as single-logit models) use
        ``BCEFocalLoss``; 'deux' model mode additionally gets
        ``self.criterion_train``. Multi-class problems use ``FocalLoss``.
        ``alpha_multi_ratio`` (per the original notes): focal-loss base weights
        are 1 / class sample count, normalised; set it to modify one class's
        weight, leave it unset for None.
        ``prob_diff`` is accepted but currently unused.
        """
        # NOTE(review): eval() on a config string; see _clsInfoDecode.
        alpha_multi_ratio = eval(self.cfg['alpha_multi_ratio']) if 'alpha_multi_ratio' in self.cfg else None
        self.convert_softmax = True
        ratios = self.train_dataset._GenerateRatios()
        # Binary is "num_class <= 2" everywhere else in this class (_buildModel maps
        # 2 -> one output logit), so the BCE branch must cover both 1 and 2.
        if self.cfg['num_class'] <= 2:
            self.criterion = BCEFocalLoss(alpha=ratios[0], gamma=gamma)
            if self.cfg.get('model_mode', 'basic') == 'deux':
                self.criterion_train = SeLossWithBCEFocalLoss(lambda_val=self.cfg.get('lambda_val', 1))
        else:
            # NOTE(review): 'deux' mode is only wired up for binary tasks; with
            # num_class > 2 self.criterion_train is never created.
            self.criterion = FocalLoss(self.cfg['num_class'], alpha=ratios, gamma=gamma,
                                       alpha_multi_ratio=alpha_multi_ratio,
                                       convert_softmax=self.convert_softmax)
        logger.info('criterion is %s', self.criterion)

    def _trainOnBatch(self, batch, mode='train'):
        """Loss for one batch of paired inputs (training_mode 1).

        ``batch`` is a 1-tuple from ``zip(loader)``. The image tensor is assumed
        to be (B, 2, ...) with 6 dims (TODO confirm against PneuDataSet); it is
        permuted to (2, B, ...) and fed as ``x1``/``x2``. Each label group gets
        its own criterion term; per-head losses are recorded in
        ``self.train_loss`` / ``self.val_loss``.
        """
        self.loss_weights = self.cfg.get('loss_weights', [1.0 for _ in range(3)])
        imgs, labels, _ = batch[0]
        imgs = torch.as_tensor(imgs).permute(1, 0, 2, 3, 4, 5)
        imgs = imgs.to(device=self.device, dtype=torch.float32)
        target_list = []
        for single_label in labels:
            current_label = np.array([np.array(item) for item in single_label])[..., np.newaxis]
            target_list.append(torch.from_numpy(current_label).to(device=self.device, dtype=torch.float32))

        with torch.set_grad_enabled(mode != 'val'):
            prediction = self.net(x1=imgs[0], x2=imgs[1])
        if not isinstance(prediction, (list, tuple)):
            prediction = [prediction]

        loss_log = self.val_loss if mode == 'val' else self.train_loss
        loss = 0
        # zip() pairs heads with label groups up to the shorter of the two.
        for idx, (pred, target) in enumerate(zip(prediction, target_list)):
            current_loss = self.criterion([pred, target])
            loss += current_loss
            loss_log[idx].append(current_loss.item())
        return loss

    def _trainOnBatchV2(self, batch, mode='train'):
        """Loss for one batch with a single input tensor (training_mode != 1).

        The batch loss is also recorded in slot 0 of ``self.train_loss`` /
        ``self.val_loss``.
        """
        imgs, labels, _ = batch[0]
        imgs = imgs[0].to(device=self.device, dtype=torch.float32)
        labels = labels[0].to(device=self.device, dtype=torch.float32)
        with torch.set_grad_enabled(mode != 'val'):
            prediction = self.net(imgs)
        loss = self.criterion([prediction, labels])
        loss_log = self.val_loss if mode == 'val' else self.train_loss
        loss_log[0].append(loss.item())
        return loss

    def _trainOnBatchWithCons(self, batch, mode='train'):
        """Loss for 'deux' model mode: pairs each sample with its batch mirror.

        Predictions and labels are reversed along the batch dimension; the
        reversed predictions are detached, so no gradient flows through them.
        """
        imgs, labels, _ = batch[0]
        imgs = imgs[0].to(device=self.device, dtype=torch.float32)
        labels = labels[0].to(device=self.device, dtype=torch.float32)
        reversed_labels = torch.flip(labels, dims=[0])
        with torch.set_grad_enabled(mode != 'val'):
            prediction = self.net(imgs)
        reversed_prediction = torch.flip(prediction.detach(), dims=[0])
        loss_input = [[prediction, reversed_prediction], [labels, reversed_labels], []]
        return self.criterion_train(loss_input)

    def _trainModel(self):
        """Train for ``num_epoch`` epochs, validating and optionally checkpointing each epoch.

        Config keys read here: training_mode, gamma (focal-loss parameter),
        prob_diff, batch_size, num_epoch, model_mode, save_cp.
        Pretrained weights at ``weight_path`` are loaded if set.
        """
        training_mode = self.cfg.get('training_mode', 2)
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        if self.pretrain_weight_path != '':
            self.net.load_state_dict(torch.load(self.pretrain_weight_path, map_location=self.device))
            logger.info("load model from %s" % self.pretrain_weight_path)
        gamma = self.cfg.get('gamma', 1)
        prob_diff = self.cfg.get('prob_diff', False)

        # Datasets: augmentation only for the training split.
        self._DecodeGeneratorParam()
        train_dataset = PneuDataSet(self.train_paths, **self.generator_parameters)
        self.generator_parameters['aug'] = False
        val_dataset = PneuDataSet(self.val_paths, **self.generator_parameters)
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

        # Training loss depends on the training-set class ratios.
        self._GetLoss(prob_diff, gamma)

        batch_size = self.cfg['batch_size']
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        self.val_loader = val_loader
        epochs = self.cfg['num_epoch']
        n_train = len(train_loader)
        model_mode = self.cfg.get("model_mode", "basic")
        # ReduceLROnPlateau steps once per epoch on the val loss; other schedulers
        # (CyclicLR) step once per optimizer step.
        use_plateau = isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau)
        loss_names = {0: 'left_loss', 1: 'right_loss', 2: 'whole_loss'}
        global_step = 0
        self.net.to(self.device)

        for epoch in range(epochs):
            self.net.train()
            epoch_loss = 0
            n_steps = 0
            self.train_loss = [[] for _ in range(3)]
            self.val_loss = [[] for _ in range(3)]
            with tqdm(total=n_train, desc=f'Epoch{epoch+1}/{epochs}', unit='img') as pbar:
                for batch_idx, batch in enumerate(zip(train_loader)):
                    try:
                        if model_mode == "deux":
                            loss = self._trainOnBatchWithCons(batch) / 2.0
                        elif training_mode == 1:
                            loss = self._trainOnBatch(batch)
                        else:
                            loss = self._trainOnBatchV2(batch)
                        self.optimizer.zero_grad()
                        loss.backward()
                        nn.utils.clip_grad_value_(self.net.parameters(), 1e-1)
                        self.optimizer.step()
                    except Exception:
                        # Best effort: skip the failing batch but keep the traceback.
                        logger.exception('skipping batch %d of epoch %d', batch_idx, epoch + 1)
                        continue
                    epoch_loss += loss.item()
                    n_steps += 1
                    global_step += 1
                    if not use_plateau:
                        self.scheduler.step()
                    pbar.update()
            random.shuffle(train_dataset.indices)
            val_score = self.Val(val_loader, epoch, epochs)
            if use_plateau:
                self.scheduler.step(val_score)
            mean_train_loss = epoch_loss / max(n_steps, 1)
            logger.info("Train loss:{}".format(mean_train_loss))
            logger.info("Validation loss:{}".format(val_score))
            self.tensorboard_logger.log_value('train_loss', mean_train_loss, epoch)
            self.tensorboard_logger.log_value('val_loss', val_score, epoch)
            self._AddTensorData(self.train_loss, loss_names, 'train', epoch)
            self._AddTensorData(self.val_loss, loss_names, 'val', epoch)
            if self.cfg['save_cp']:
                model_save_path_current_epoch = os.path.join(
                    self.save_path, f'CP_epoch{epoch+1}_loss_%.2f.pth' % val_score)
                torch.save(self.net.state_dict(), model_save_path_current_epoch)
                logger.info(f'Checkpoint{epoch+1} saved to %s!' % model_save_path_current_epoch)

    def _AddTensorData(self, losses, name_dict, mode, epoch):
        """Log the mean of every non-empty per-head loss list as ``<mode>_<name>``."""
        for idx, values in enumerate(losses):
            if values:
                self.tensorboard_logger.log_value('%s_%s' % (mode, name_dict[idx]), np.mean(values), epoch)

    def Val(self, val_Loader, epoch, num_epoch):
        """Return the mean per-batch validation loss (0.0 for an empty loader).

        The model is put in eval mode for the pass and back in train mode afterwards.
        """
        training_mode = self.cfg.get('training_mode', 2)
        self.net.eval()
        total_loss = 0
        n_batches = 0
        with tqdm(total=len(val_Loader), desc=f'Epoch{epoch+1}/{num_epoch}', unit='img') as pbar:
            for batch in zip(val_Loader):
                if training_mode == 1:
                    loss = self._trainOnBatch(batch, mode='val')
                else:
                    loss = self._trainOnBatchV2(batch, mode='val')
                total_loss += loss.item()
                n_batches += 1
                pbar.update()
        self.net.train()
        return total_loss / float(max(n_batches, 1))

    def _testModel(self):
        """Run inference with each weight file matched by ``weight_path`` and print metrics.

        ``weight_path`` may contain a glob '*' (only '.pth' matches are kept).
        Per-case labels/probabilities of the last model are left in
        ``self.label_list`` / ``self.result_list``; thresholded/argmax
        predictions in ``self.result`` and mismatching case ids in
        ``self.badcase_lists``.

        Raises:
            FileNotFoundError: if no weight file matches ``weight_path``.
        """
        training_mode = self.cfg.get('training_mode', 2)
        weight_path = self.cfg['weight_path']
        if '*' in weight_path:
            model_paths = [path for path in sorted(glob.glob(weight_path)) if '.pth' in path]
        else:
            model_paths = [weight_path]
        if not model_paths:
            raise FileNotFoundError('no .pth weights match %s' % weight_path)
        results = []
        result = []
        self._DecodeGeneratorParam()
        infer_idx = -1  # training_mode 1: use the last head / last label
        self.convert_softmax = True
        test_dataset = PneuDataSet(self.test_paths, **self.generator_parameters)
        for model_path in model_paths:
            print('=' * 60)
            print('model_path is ', model_path)
            self.net.to(device=self.device)
            self.net.load_state_dict(torch.load(model_path, map_location=self.device))
            logger.info('Model loaded !')
            # eval mode fixes BatchNorm statistics and disables dropout.
            self.net.eval()
            self.label_list = []
            self.result_list = []
            for idx in tqdm(np.arange(len(test_dataset))):
                imgs, labels, _ = test_dataset[idx]
                if training_mode == 1:
                    imgs = [torch.from_numpy(img[np.newaxis, ...]).to(device=self.device, dtype=torch.float32)
                            for img in imgs]
                    label = labels[infer_idx]
                elif training_mode == 2:
                    imgs = torch.from_numpy(imgs[0][np.newaxis, ...]).to(device=self.device, dtype=torch.float32)
                    label = labels[0]
                else:
                    imgs = torch.from_numpy(imgs).to(device=self.device, dtype=torch.float32)
                    label = labels

                with torch.no_grad():
                    # Paired inputs only come from training_mode 1 (a list); a
                    # tensor whose first dim happens to be 2 must not go here.
                    if isinstance(imgs, list) and len(imgs) == 2:
                        output = self.net(imgs[0], imgs[1])
                        output = [single_output.squeeze(0) for single_output in output]
                    else:
                        output = self.net(imgs)

                if self.cfg['num_task'] == 1:
                    if self.convert_softmax:
                        if training_mode == 1:
                            head = output[infer_idx]
                        elif training_mode == 2:
                            head = output[0]
                        else:
                            head = output
                        if self.net.num_class > 1:
                            probs = torch.softmax(head, dim=1)
                        else:
                            probs = torch.sigmoid(head)
                    # NOTE(review): max() collapses all values to one scalar; this
                    # suits single-logit outputs but discards per-class
                    # probabilities that the multi-class argmax below expects.
                    probs = probs.max()
                    probs = probs.squeeze(0)
                    probs = probs.data.cpu().numpy()
                else:
                    current_probs = []
                    for task_idx in range(self.cfg['num_task']):
                        if self.net.num_class > 1:
                            probs = torch.softmax(output[task_idx], dim=1)
                        else:
                            probs = torch.sigmoid(output[task_idx])
                        probs = probs.squeeze(0)
                        current_probs.append(probs.cpu().numpy()[0])
                    probs = current_probs
                self.label_list.append(label)
                self.result_list.append(probs)

            if self.cfg.get('show_prob', False):
                target_cls_probs = self._GetTargetClassProb(self.label_list, self.result_list)
                self.ProbVis(self.label_list, target_cls_probs)
            print('length of label and result', len(self.label_list), len(self.result_list))
            print('shape of result list is ', np.array(self.result_list).shape)
            print('max val of self.result_list', np.amin(self.result_list), np.amax(self.result_list))
            if self.cfg['num_class'] <= 2:
                if self.cfg['num_task'] == 1:
                    th_result = [val > 0.5 for val in self.result_list]
                    result = th_result
                    CalculateClsScore(th_result, self.label_list)
                    auc_val = CalculateAuc(self.result_list, self.label_list)
                    results.append(auc_val)
                    if len(model_paths) == 1:
                        CalculateClsScoreByTh(self.result_list, self.label_list, acc_flag=True)
                else:
                    for task_idx in range(self.cfg['num_task']):
                        print('=' * 60)
                        print('Current task is %s' % self.cls_name_map_dict[self.cfg['cls_label_index'][task_idx]])
                        current_label_list = np.squeeze(np.array([val[task_idx] for val in self.label_list]))
                        current_result_list = np.squeeze(np.array([val[task_idx] for val in self.result_list]))
                        th_result = np.squeeze(np.array([val > 0.5 for val in current_result_list]))
                        print(np.array(current_label_list).shape, np.array(current_result_list).shape,
                              np.array(th_result).shape)
                        CalculateClsScore(th_result, current_label_list)
                        auc_val = CalculateAuc(current_result_list, current_label_list)
            else:
                result = np.squeeze(np.array(self.result_list))
                result = np.argmax(result, axis=1)
                print('shape of result is', result.shape)
                print('shape of self.label_list is', np.array(self.label_list).shape)
                kappa_val = CalculateClsScore(result, self.label_list)
                results.append(kappa_val)

        # One score per model (single-task binary AUC or multi-class score).
        if len(results) > 0:
            print('results is ', results)
            pos = np.argmax(results)
            print('best weights is ', model_paths[pos])
        self.test_dataset = test_dataset
        self.result = result
        # The multi-task path produces no per-case predictions to compare.
        if len(result) == len(self.label_list):
            self.badcase_lists = [case_id for case_id in range(len(test_dataset))
                                  if result[case_id] != self.label_list[case_id]]
        else:
            self.badcase_lists = []

    def _GetTargetClassProb(self, labels, probs):
        """Return ``prob[0][label]`` for every (prob, label) pair."""
        return [prob[0][label] for prob, label in zip(probs, labels)]

    def ProbVis(self, labels, probs):
        """Show one histogram of ``probs`` over [0, 1] for each distinct label."""
        for val in np.unique(labels):
            print('=' * 60)
            print('Current label is %d' % val)
            indices = [idx for idx, label in enumerate(labels) if label == val]
            current_probs = np.take(probs, indices, axis=0)
            print('Number of samples are %d' % len(current_probs))
            plt.hist(current_probs, range=(0, 1))
            plt.show()

    def BadCaseVis(self, num_incies, show_edges=False, gaussian_choice=False, sigma=1):
        """Plot the three central slices of the first ``num_incies`` misclassified cases.

        With ``show_edges`` a second row shows ``ShowEdges`` output per slice.
        Requires ``_testModel`` to have run (uses ``badcase_lists``, ``result``,
        ``result_list`` and ``test_dataset``).
        """
        slice_range = 1
        n_cols = 2 * slice_range + 1
        for idx in self.badcase_lists[:num_incies]:
            current_data = self.test_dataset[idx]
            img = np.squeeze(current_data[0])
            label = current_data[1]
            center_slice_idx = int(img.shape[0] / 2)
            print('=' * 60)
            print('Current label is ', label, ' result is ', self.result[idx])
            print('prob is', self.result_list[idx])
            if not show_edges:
                _, axs = plt.subplots(1, n_cols)
            else:
                _, axs = plt.subplots(2, n_cols)
            first_slice = center_slice_idx - slice_range
            for slice_idx in range(first_slice, center_slice_idx + slice_range + 1):
                col = slice_idx - first_slice
                if not show_edges:
                    axs[col].imshow(img[slice_idx], cmap='bone')
                else:
                    # NOTE(review): ShowEdges is not defined or explicitly imported in
                    # this file; it may come from a star import -- confirm.
                    edges = ShowEdges(img[slice_idx], gaussian_choice=gaussian_choice, sigma=sigma)
                    axs[0][col].imshow(img[slice_idx], cmap='bone')
                    axs[1][col].imshow(edges, cmap='bone')
            plt.show()

    def ParametersDecode(self):
        """Load the model-parameter section for ``cfg['model_name']`` from ``<basic_file_path>/<model_params>``."""
        self.model_param_name = os.path.join(self.basic_file_path, self.cfg['model_params'])
        self.model_params = LoadJson(self.model_param_name)[self.cfg['model_name']]