# -*- coding:utf-8 -*-
import os
import sys
import six
import glob
import json
import time
import random
import shutil
import datetime
import logging

import numpy as np
import pandas as pd
import scipy.ndimage as nd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
# from matplotlib import pyplot as plt
from tensorboardX import SummaryWriter

sys.path.append('./utils')
sys.path.append('./model')

from .utils.RWconfig import LoadJson
from .utils.ReadData import load_data, load_single_data
from .model.modelBuild import build_model
from .model.BasicModules import weight_init
from .utils.DataGen import SurDataSet
from .utils.OnlineEval import eval_net
from .utils.OfflineEval import CalculateClsScore, CalculateAuc, CalculateClsScoreByTh
from .utils.loss_func import BCEFocalLoss, PeerLoss, MIFocalLoss, FocalLoss
from .utils.gradcam import GradCam, GuidedBackpropReLUModel

logger = logging.getLogger()
fh = logging.FileHandler('Train.log', encoding='utf-8')
sh = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s')
fh.setFormatter(formatter)
sh.setFormatter(formatter)
logger.addHandler(fh)
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)


class Classification3D(object):
    def __init__(self, filename, mode):
        self.mode = mode
        self.filename = filename
        self.cfg = LoadJson(self.filename)[mode]
        self.basic_file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        current_DT = datetime.datetime.now()
        self.current_day = '%02d%02d' % (current_DT.month, current_DT.day)
        self.convert_dict = self.cfg.get('convert_dict', {})
        self._clsInfoDecode()

        writer_path = os.path.join(self.save_path, 'runs')
        train_writer_path = os.path.join(writer_path, 'train')
        val_writer_path = os.path.join(writer_path, 'val')
        self.writer = SummaryWriter(writer_path) if self.mode == 'training' else None
        self.train_writer = SummaryWriter(train_writer_path) if self.mode == 'training' else None
        self.val_writer = SummaryWriter(val_writer_path) if self.mode == 'training' else None

        if mode == 'training':
            self.trainfile_conf_path = os.path.join(self.basic_file_path, self.cfg['train_file'])
            self.valfile_conf_path = os.path.join(self.basic_file_path, self.cfg['val_file'])
            self.trainfiles = self._DecodeDataParam(self.trainfile_conf_path)
            self.valfiles = self._DecodeDataParam(self.valfile_conf_path)
            self.testfiles = []
        else:
            self.trainfiles = []
            self.valfiles = []
            self.test_conf_path = os.path.join(self.basic_file_path, self.cfg['test_file'])
            self.testfiles = self._DecodeDataParam(self.test_conf_path)
        self.ParametersDecode()
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    def __call__(self):
        self._buildModel()
        self._compileModel()
        self._LoadData(self.mode)
        if self.mode == 'training':
            self._trainModel()
        else:
            start_time = time.time()
            self._testModel()
            dur = time.time() - start_time
            print('=' * 42 + ' It takes %.2f s to make inference' % dur)
        if self.writer:
            self.writer.close()
        if self.train_writer:
            self.train_writer.close()
        if self.val_writer:
            self.val_writer.close()
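    # A minimal sketch of the JSON config this class expects (illustrative
    # values only -- every key below is read somewhere in this file, but the
    # authoritative schema is the project's own config files):
    #
    #   {
    #     "training": {
    #       "model_name": "...", "model_params": "config/model_params.json",
    #       "num_class": 2, "num_epoch": 50, "batch_size": 8,
    #       "learning_rate": 1e-4, "patience": 5,
    #       "patch_height": 48, "patch_depth": 48,
    #       "cls_label_index": 0,
    #       "cls_map_dict": "{0: 0, 1: 1}",
    #       "cls_name_map_dict": "{0: 'benign', 1: 'malignant'}",
    #       "train_file": "...", "val_file": "...",
    #       ...
    #     }
    #   }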
    def _clsInfoDecode(self):
        self.cls_map_dict = eval(self.cfg['cls_map_dict'])
        self.ori_cls_types = set(self.cls_map_dict.values())
        self.input_shape = [self.cfg['patch_height'], self.cfg['patch_height'], self.cfg['patch_depth']]
        self.cls_name_map_dict = eval(self.cfg['cls_name_map_dict'])
        model_order = self.cfg.get('model_order', 1)
        if isinstance(self.cfg['cls_label_index'], int):
            cls_related_info = self.cls_name_map_dict[self.cfg['cls_label_index']]
        else:
            cls_related_info = '_'.join([self.cls_name_map_dict[val] for val in self.cfg['cls_label_index']])
        self.save_path = '%s/%s_%s_%dCls_%s_Patch%03d_model%s/%02d_%s/' % (
            self.cfg['base_path'], self.cfg['weights_pre'], cls_related_info,
            self.cfg['num_class'], self.current_day, self.input_shape[0],
            self.cfg['model_name'], model_order, self.cfg['weight_memo'])
        if (not os.path.exists(self.save_path)) and self.mode == 'training':
            os.makedirs(self.save_path)
        if isinstance(self.cfg['cls_label_index'], list):
            self.cfg['num_task'] = len(self.cfg['cls_label_index'])
        else:
            self.cfg['num_task'] = 1

    def _buildModel(self):
        cfg = self.model_params
        cfg['num_task'] = self.cfg['num_task']
        cfg['num_class'] = self.cfg['num_class']
        # Binary problems use a single logit + sigmoid head.
        if cfg['num_class'] == 2:
            cfg['num_class'] = 1
        self.net = build_model(self.cfg['model_name'], cfg)
        logger.info("After building model")
        self.net.apply(weight_init)
        logger.info("After Weight Init")

    def _compileModel(self):
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.cfg['learning_rate'])
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=self.cfg['patience'])
        base_loss = nn.CrossEntropyLoss() if self.cfg['num_class'] > 2 else nn.BCEWithLogitsLoss()
        # The custom losses in utils/loss_func are invoked with a single
        # [prediction, target] pair everywhere in this file, so wrap the stock
        # losses to share that call convention.
        self.criterion = lambda pair: base_loss(pair[0], pair[1])
        conf_paths = [self.filename, self.model_param_name]
        if self.mode == 'training':
            conf_paths += [self.trainfile_conf_path, self.valfile_conf_path]
            # Archive the configs next to the checkpoints for reproducibility.
            for conf_path in conf_paths:
                filename = os.path.basename(conf_path)
                shutil.copy(conf_path, os.path.join(self.save_path, filename))
        else:
            conf_paths += [self.test_conf_path]

    def _DecodeDataParam(self, filename):
        paths = LoadJson(filename)
        data_paths = [record['data_path'] for record in paths]
        info_paths = [record['info_path'] for record in paths]
        return {'data_paths': data_paths, 'info_paths': info_paths}

    def _LoadData(self, mode):
        if mode == 'training':
            self.train_data, self.train_info = load_data([self.trainfiles['data_paths'], self.trainfiles['info_paths']])
            self.val_data, self.val_info = load_data([self.valfiles['data_paths'], self.valfiles['info_paths']])
        else:
            self.test_data, self.test_info = load_data([self.testfiles['data_paths'], self.testfiles['info_paths']])
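    # The criteria built below are all called with one [prediction, target]
    # pair. Assuming the repo's BCEFocalLoss follows the standard binary
    # focal-loss formulation, a minimal sketch of that contract (illustrative
    # only -- the real implementation lives in utils/loss_func.py):
    #
    #   class BCEFocalLossSketch(nn.Module):
    #       def __init__(self, alpha=0.25, gamma=2.0):
    #           super().__init__()
    #           self.alpha, self.gamma = alpha, gamma
    #
    #       def forward(self, pair):
    #           logits, target = pair
    #           p = torch.sigmoid(logits)
    #           pt = p * target + (1 - p) * (1 - target)  # prob of the true class
    #           at = self.alpha * target + (1 - self.alpha) * (1 - target)
    #           return (-at * (1 - pt) ** self.gamma * torch.log(pt.clamp(min=1e-7))).mean()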
    def _trainModel(self):
        gamma = self.cfg.get('gamma', 1)
        device = self.device
        nonfix_crop = self.cfg.get('nonfix_crop', False)
        train_aug_choice = self.cfg.get('train_aug_choice', False)
        train_dataset = SurDataSet(self.train_data, self.train_info, cls_index=self.cfg['cls_label_index'],
                                   cls_map_dict=self.cls_map_dict, diameter_index=self.cfg['diameter_index'],
                                   num_class=self.cfg['num_class'], HU_min=self.cfg['HU_min'],
                                   HU_max=self.cfg['HU_max'], input_shape=self.input_shape,
                                   cls_types=self.ori_cls_types, aug=train_aug_choice, nonfix_crop=nonfix_crop,
                                   diameter_max_th=self.cfg['diameter_max_th'],
                                   diameter_min_th=self.cfg['diameter_min_th'])
        val_dataset = SurDataSet(self.val_data, self.val_info, cls_index=self.cfg['cls_label_index'],
                                 cls_map_dict=self.cls_map_dict, diameter_index=self.cfg['diameter_index'],
                                 num_class=self.cfg['num_class'], HU_min=self.cfg['HU_min'],
                                 HU_max=self.cfg['HU_max'], input_shape=self.input_shape,
                                 cls_types=self.ori_cls_types, nonfix_crop=nonfix_crop,
                                 diameter_max_th=self.cfg['diameter_max_th'],
                                 diameter_min_th=self.cfg['diameter_min_th'])
        alpha = None
        if self.cfg['focal_loss'] and self.cfg['num_task'] == 1:
            ratios = train_dataset._GenerateRatios()
            if self.cfg['num_class'] <= 2:
                alpha = ratios[0]
                print('=' * 60)
                print('alpha is ', alpha)
                train_focal_loss = self.cfg.get('train_focal_loss', self.cfg['focal_loss'])
                loss_choice = self.cfg.get('loss_func', 'focal')
                print('loss choice is ', loss_choice)
                if loss_choice != 'focal':
                    train_focal_loss = False
                if train_focal_loss:
                    self.criterion = BCEFocalLoss(alpha=alpha, gamma=gamma)
                elif loss_choice == 'peer':
                    self.criterion = PeerLoss(alpha=alpha, gamma=gamma)
                elif loss_choice == 'MIFocalLoss':
                    self.criterion = MIFocalLoss(alpha=alpha, gamma=gamma)
            else:
                alpha = ratios
                sigmoid_choice = self.cfg.get('sigmoid_choice', False)  # sigmoid + BCE variant
                self.criterion = FocalLoss(self.cfg['num_class'], alpha=ratios, gamma=gamma)

        balance_data_kws = self.cfg.get('balance_data', [0, 0])
        modes = ['val' if val == 0 else 'train' for val in balance_data_kws]
        train_dataset._GenerateTrainData(mode=modes[0])
        val_dataset._GenerateTrainData(mode=modes[1])

        batch_size = self.cfg['batch_size']
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
        epochs = self.cfg['num_epoch']
        global_step = 0
        self.net.to(device)
        for epoch in range(epochs):
            self.net.train()
            epoch_loss = 0
            print('size of train dataset is ', len(train_dataset))
            with tqdm(total=int(len(train_dataset) / batch_size) + 1, desc=f'Epoch{epoch + 1}/{epochs}', unit='img') as pbar:
                batch_count = 0
                for batch in train_loader:
                    batch_count += 1
                    imgs = batch[0][0].to(device=device, dtype=torch.float32)
                    label = batch[1][0]
                    if self.cfg['num_class'] > 2:
                        label = label.to(device=device, dtype=torch.long)
                    elif self.cfg['num_task'] > 1:
                        label = [val.to(device=device, dtype=torch.float32) for val in label]
                    else:
                        label = label.to(device=device, dtype=torch.float32)
                    prediction = self.net(imgs)
                    if self.cfg['num_task'] > 1:
                        # Multi-task: sum the per-task losses so every head
                        # contributes to the gradient.
                        task_losses = []
                        for x, y in zip(label, prediction):
                            y = y.squeeze(-1)
                            task_losses.append(self.criterion([y, x]))
                        loss = sum(task_losses)
                    else:
                        if self.cfg['num_class'] <= 2:
                            if isinstance(prediction, list):
                                prediction = prediction[0]
                            prediction = prediction.squeeze(-1)
                        loss = self.criterion([prediction, label])
                    epoch_loss += loss.item()
                    self.optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_value_(self.net.parameters(), 1e-1)
                    self.optimizer.step()
                    global_step += 1
                    pbar.update(1)

            if self.cfg['select_samples'] == 'partial':
                # Strategy 1: rank samples by their current loss and keep the
                # hardest ones for the next epoch.
                loss_list = self._CalculateLossTrain(train_dataset, device)
                train_dataset._updateIndices(loss_list, ratio_pos=self.cfg['ratio_pos'], num_hard=100)
            elif self.cfg['select_samples'] == 'boost':
                # Strategy 2: boosting-style reweighting by prediction error.
                loss_list = self._CalculateLossTrain(train_dataset, device)
                initial_epoch = self.cfg.get('initial_epoch', 0)
                epoch_interval = self.cfg.get('epoch_interval', 1)
                if epoch >= initial_epoch and (epoch - epoch_interval) % epoch_interval == 0:
                    diff_list = self._CalculateDiffTrain(train_dataset, device)
                    increase_choice = self.cfg.get('increase_choice', False)
                    train_dataset._updateWeights(diff_list, increase=increase_choice)
                    train_dataset._SelectSamplesBasedOnWeights()
                    new_ratios = train_dataset._GenerateRatios()
                    if self.cfg['num_class'] <= 2:
                        self.criterion = BCEFocalLoss(alpha=new_ratios[0], gamma=gamma)

            print('batch_count is', batch_count)
            random.shuffle(train_dataset.indices)
            val_score = eval_net(self.net, val_dataset, val_loader, device, alpha=alpha, cfg=self.cfg) * 100
            # Adjust the learning rate when the validation result plateaus.
            self.scheduler.step(val_score)
            print('epoch_loss is', epoch_loss, len(train_dataset))
            logger.info("Train loss: {}".format(epoch_loss / float(len(train_dataset)) * 100))
            logger.info("Validation loss: {}".format(val_score))
            if self.cfg['save_cp']:
                model_save_path_current_epoch = os.path.join(self.save_path, 'CP_epoch%d_loss_%.2f.pth' % (epoch + 1, val_score))
                torch.save(self.net.state_dict(), model_save_path_current_epoch)
                logger.info('Checkpoint %d saved to %s!' % (epoch + 1, model_save_path_current_epoch))
            if self.train_writer:
                self.train_writer.add_scalar('loss', epoch_loss / float(len(train_dataset)) * 100, epoch)
            if self.val_writer:
                self.val_writer.add_scalar('loss', val_score, epoch)
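    # The train/val writers created in __init__ log under
    # <save_path>/runs/train and <save_path>/runs/val; pointing TensorBoard at
    # the parent directory overlays both curves, e.g.:
    #
    #   tensorboard --logdir <save_path>/runs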
    def _testModel(self):
        if '*' in self.cfg['weight_path']:
            model_paths = sorted(glob.glob(self.cfg['weight_path']))
            model_paths = [path for path in model_paths if '.pth' in path]
        else:
            model_paths = [self.cfg['weight_path']]
        results = []
        device = self.device
        nonfix_crop = self.cfg.get('nonfix_crop', False)
        for model_path in model_paths:
            print('=' * 60)
            print('model_path is ', model_path)
            self.net.to(device=device)
            self.net.load_state_dict(torch.load(model_path, map_location=device))
            logger.info('Model loaded !')
            # Switch to eval mode so BN/dropout use their inference behaviour.
            self.net.eval()
            test_dataset = SurDataSet(self.test_data, self.test_info, cls_index=self.cfg['cls_label_index'],
                                      cls_map_dict=self.cls_map_dict, diameter_index=self.cfg['diameter_index'],
                                      num_class=self.cfg['num_class'], HU_min=self.cfg['HU_min'],
                                      HU_max=self.cfg['HU_max'], input_shape=self.input_shape,
                                      cls_types=self.ori_cls_types, nonfix_crop=nonfix_crop,
                                      aug=False, mode='test',
                                      diameter_max_th=self.cfg['diameter_max_th'],
                                      diameter_min_th=self.cfg['diameter_min_th'])
            test_dataset._GenerateTrainData()
            self.label_list = []
            self.result_list = []
            self.badcase_lists = []
            for idx in range(len(test_dataset)):
                current_data = test_dataset[idx]
                img = current_data[0][0]
                label = current_data[1][0]
                img = img[np.newaxis, ...]
                img = torch.from_numpy(img)
                img = img.to(device=device, dtype=torch.float32)
                with torch.no_grad():
                    output = self.net(img)
                if self.cfg['num_task'] == 1:
                    # Single task: a list output means auxiliary heads; keep the first.
                    if isinstance(output, list):
                        output = output[0]
                    if self.net.num_class > 2:
                        probs = torch.softmax(output, dim=1)
                    else:
                        probs = torch.sigmoid(output)
                    probs = probs.squeeze(0)
                    probs = probs.data.cpu().numpy()
                else:
                    current_probs = []
                    for task_idx in range(self.cfg['num_task']):
                        if self.net.num_class > 1:
                            task_probs = F.softmax(output[task_idx], dim=1)
                        else:
                            task_probs = torch.sigmoid(output[task_idx])
                        task_probs = task_probs.squeeze(0)
                        current_probs.append(task_probs.cpu().numpy()[0])
                    probs = current_probs
                self.label_list.append(label)
                self.result_list.append(probs)
            print('length of label and result', len(self.label_list), len(self.result_list))
            print('shape of result list is ', np.array(self.result_list).shape)
            print('min/max of self.result_list', np.amin(self.result_list), np.amax(self.result_list))
            if self.net.num_class < 2:
                if self.cfg['num_task'] == 1:
                    th_result = [val > 0.5 for val in self.result_list]
                    CalculateClsScore(th_result, self.label_list)
                    auc_val = CalculateAuc(self.result_list, self.label_list)
                    results.append(auc_val)
                    if len(model_paths) == 1:
                        CalculateClsScoreByTh(self.result_list, self.label_list, acc_flag=True)
                else:
                    for task_idx in range(self.cfg['num_task']):
                        print('=' * 60)
                        print('Current task is %s' % self.cls_name_map_dict[self.cfg['cls_label_index'][task_idx]])
                        current_label_list = np.squeeze(np.array([val[task_idx] for val in self.label_list]))
                        current_result_list = np.squeeze(np.array([val[task_idx] for val in self.result_list]))
                        th_result = np.squeeze(np.array([val > 0.5 for val in current_result_list]))
                        print(current_label_list.shape, current_result_list.shape, th_result.shape)
                        CalculateClsScore(th_result, current_label_list)
                        auc_val = CalculateAuc(current_result_list, current_label_list)
            else:
                result = np.squeeze(np.array(self.result_list))
                result = np.argmax(result, axis=1)
                kappa_val = CalculateClsScore(result, self.label_list)
                results.append(kappa_val)
            # self.badcase_lists = [case_id for case_id in range(len(test_dataset))
            #                       if self.result_list[case_id] != self.label_list[case_id]]
        if len(results) > 0:
            print('results is ', results)
            pos = np.argmax(results)
            print('best weights is ', model_paths[pos])
        self.test_dataset = test_dataset
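    # _testModel cuts binary probabilities at a fixed 0.5. With skewed classes
    # a validation-tuned threshold is often better; a sketch using
    # scikit-learn (an extra dependency, not imported by this module):
    #
    #   from sklearn.metrics import roc_curve
    #   fpr, tpr, ths = roc_curve(label_list, result_list)
    #   best_th = ths[np.argmax(tpr - fpr)]   # maximises Youden's J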
    # def BadCaseVis(self):
    #     for idx in self.badcase_lists:
    #         current_data = self.test_dataset[idx]
    #         img = np.squeeze(current_data[0])
    #         label = current_data[1]
    #         center_slice_idx = int(img.shape[0] / 2)
    #         slice_range = 2
    #         print('=' * 60)
    #         print('Current label is ', label)
    #         for slice_idx in range(center_slice_idx - slice_range, center_slice_idx + slice_range):
    #             plt.imshow(img[slice_idx], cmap='bone')
    #             plt.show()

    def ParametersDecode(self):
        '''Load the model-specific parameter block from the model-params JSON.'''
        self.model_param_name = os.path.join(self.basic_file_path, self.cfg['model_params'])
        self.model_params = LoadJson(self.model_param_name)[self.cfg['model_name']]

    def _CalculateLossTrain(self, train_dataset, device):
        '''Run the whole training set in test mode and collect a per-sample loss.'''
        loss_list = []
        print('number of data inside train dataset is ', len(train_dataset.data))
        train_dataset.mode = 'test'
        self.net.eval()
        for idx in range(len(train_dataset.data)):
            data, label, _ = train_dataset[idx]
            data = data[0]
            data = data[np.newaxis, ...]
            data = torch.from_numpy(data)
            label = np.array([label[0]])
            label = torch.from_numpy(label)
            data = data.to(device=device, dtype=torch.float32)
            if self.cfg['num_task'] == 1:
                label = label.to(device=device, dtype=torch.float32)
            else:
                label = [val.to(device=device, dtype=torch.float32) for val in label]
            with torch.no_grad():
                prediction = self.net(data)
            if self.cfg['num_task'] == 1:
                prediction = prediction.squeeze(-1)
                if self.cfg['num_class'] > 2:
                    label = label.to(device=device, dtype=torch.long)
                loss_list.append([idx, self.criterion([prediction, label]).item()])
            else:
                current_loss = []
                for x, y in zip(prediction, label):
                    x = x.squeeze(-1)
                    current_loss.append(self.criterion([x, y]).item())
                loss_list.append([idx, current_loss])
        train_dataset.mode = 'train'
        self.net.train()
        return loss_list
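    # _CalculateLossTrain returns [index, loss] pairs; for the single-task
    # case the hardest samples can be ranked directly, e.g. (a sketch --
    # _updateIndices in utils/DataGen.py holds the real selection logic):
    #
    #   hard_first = sorted(loss_list, key=lambda pair: pair[1], reverse=True)
    #   hard_indices = [idx for idx, _ in hard_first[:num_hard]]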
    def _CalculateDiffTrain(self, train_dataset, device):
        '''Score each training sample by |sigmoid(prediction) - label| for boosting-style reweighting.'''
        diff_list = []
        print('number of data inside train dataset is ', len(train_dataset.data))
        train_dataset.mode = 'test'
        self.net.eval()
        for idx in range(len(train_dataset.data)):
            data, label, _ = train_dataset[idx]
            data = data[0]
            data = data[np.newaxis, ...]
            data = torch.from_numpy(data)
            label = np.array([label[0]])
            label = torch.from_numpy(label)
            data = data.to(device=device, dtype=torch.float32)
            if self.cfg['num_task'] == 1:
                label = label.to(device=device, dtype=torch.float32)
            else:
                label = [val.to(device=device, dtype=torch.float32) for val in label]
            with torch.no_grad():
                prediction = self.net(data)
            prediction = prediction.squeeze(-1)
            prediction = torch.sigmoid(prediction)
            diff = abs(prediction - label)
            diff_list.append([idx, diff])
        train_dataset.mode = 'train'
        self.net.train()
        return diff_list
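
if __name__ == '__main__':
    # Minimal command-line entry point (a sketch; the default paths are
    # placeholders). Because this module uses package-relative imports, run it
    # with `python -m <package>.<module> <config.json> <mode>` from the repo root.
    conf = sys.argv[1] if len(sys.argv) > 1 else 'config/train_config.json'
    run_mode = sys.argv[2] if len(sys.argv) > 2 else 'training'
    Classification3D(conf, run_mode)()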