# -*- coding:utf-8 -*-
import os
import sys
import time
import glob
import math
import random
import logging
import numpy as np
import pandas as pd
import SimpleITK as sitk
import scipy.ndimage as nd
from scipy.ndimage import zoom
from keras.utils import to_categorical
from torch.utils.data import Dataset
from collections import Counter, defaultdict
from ImagePreprocess import *
from ReadData import load_data, load_single_data
from data_aug_helper import data_augmentation, random_add_val
from SplitAndCombine import *


def ReadCSV(path):
    if 'csv' in path:
        df = pd.read_csv(path, encoding='utf_8_sig')
        return df
    elif 'xlsx' in path:
        df = pd.read_excel(path)
        return df


def LoadDFs(paths):
    result = []
    for path in paths:
        if ('csv' not in path) and ('xlsx' not in path):
            continue
        df = ReadCSV(path)
        result.append(df)
    if len(result) > 0:
        return pd.concat(result, ignore_index=True)


class SegGenCOVID(Dataset):
    def __init__(self, df_paths, **kwargs):
        super(SegGenCOVID, self).__init__()
        self.aug = kwargs.get('aug', False)
        self.mode = kwargs.get('mode', 'training')
        self.clip_choice = kwargs.get('clip_choice', True)
        self.HU_min = kwargs.get('HU_min', -1024)
        self.HU_max = kwargs.get('HU_max', 1024)
        self.input_size = kwargs.get('input_size')
        self.estimate_patch_per_img = kwargs.get('estimate_patch_per_img', 12)
        self.infer_stride = kwargs.get('infer_stride', 2)
        self.use_npy = kwargs.get('use_npy', False)
        self.norm_func = kwargs.get('norm_func', 'HU')
        self.val_mode = kwargs.get('val_mode', 1)
        self.learn_feature = kwargs.get('learn_feature', False)
        self.bbox_margin = kwargs.get('bbox_margin', 0)
        self.patch_pos_ratio_range = kwargs.get('patch_pos_ratio_range', [-0.1, 1.1])
        self.cache_size = kwargs.get('cache_size', 500)
        self.logger = kwargs.get('logger')
        self.union_val = 0.1
        self.badcase_list = ['1.2.840.113619.2.55.3.279715092.193.1581059882.969.4',
                             '1.2.840.113619.2.55.3.279715092.193.1581059859.976.4']
        repeat_times = kwargs.get('repeat_times', 4)
        area_th_value = kwargs.get('area_th_val', 0)
        self.current_epoch = 0
        self.semi_udpate_weights_start = float('inf')
        self.cache = []
        self.df = LoadDFs(df_paths)
        if 'area' in self.df.columns:
            self.df = self.df[self.df['area'] >= area_th_value]
            self.df.reset_index(inplace=True)
        if 'badcase_flag' in self.df.columns:
            self.df = self.df[self.df['badcase_flag'] == 0]
            self.df.reset_index(inplace=True)
        self.indices = np.arange(len(self.df))
        if self.mode == 'training':
            self._repeatDF()
            self.indices = np.arange(len(self.df))
            self.indices = np.repeat(self.indices, repeats=repeat_times)
        elif self.mode == 'val' and self.val_mode == 1:
            self.indices = np.repeat(self.indices, repeats=repeat_times)
        random.shuffle(self.indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        if self.mode == 'test' or (self.val_mode != 1 and self.mode == 'val'):
            data_idx = index
            return self.GetInferCase(data_idx)
        else:
            data_idx = self.indices[index]
            if self.df.uid[data_idx] in self.badcase_list:
                data_idx = self.indices[(index + 1) % len(self.indices)]
            result = self.getSingleCase(data_idx)
            if self.cache_size > 0:
                self._updateCache(result)
            result = list(result)
            result[0] = result[0].astype(np.float64)
            result[1] = result[1].astype(np.float32)
            return result

    def _updateCache(self, result):
        # Keep at most cache_size recent samples (FIFO).
        if len(self.cache) == self.cache_size:
            self.cache = self.cache[1:]
        self.cache.append(result)

    def _repeatDF(self):
        # Repeat each record roughly in proportion to its lesion area so that
        # larger lesions contribute more patches per epoch.
        max_area = 10000000000000
        if self.estimate_patch_per_img > 0:
            columns = [u'uid', u'image_path', u'mask_path', u'lung_mask_path',
                       u'z_min', u'z_max', u'y_min', u'y_max', u'x_min', u'x_max', u'area']
            final_records = []
            for record in self.df[columns].values:
                area = record[-1]
                if area >= max_area:
                    repeat_times = 5
                else:
                    repeat_times = math.ceil(area / float(self.estimate_patch_per_img))
                for idx in range(repeat_times):
                    final_records.append(record)
            self.df = pd.DataFrame(final_records, columns=columns)
            num_neg = len(self.df[self.df['area'] == max_area])
            num_pos = len(self.df) - num_neg
            print('num pos and num neg is ', num_pos, num_neg)

    def _setEpoch(self, current_epoch, semi_udpate_weights_start):
        self.current_epoch = current_epoch
        self.semi_udpate_weights_start = semi_udpate_weights_start

    def _GeneratePatchIndex(self, num_patches):
        patch_idx = []
        for idx in range(len(num_patches)):
            current_idx = np.random.randint(low=0, high=num_patches[idx], size=1)[0]
            patch_idx.append(current_idx)
        return patch_idx

    def _GetBound(self, lungmask):
        # Bounding box (with margins) of the non-zero region of the lung mask,
        # returned as [z_min, z_max, y_min, y_max, x_min, x_max].
        margins = [10, 20, 20]
        points = np.where(lungmask > 0)
        bounds = []
        for idx in range(len(points)):
            bound_min = max(min(points[idx]) - margins[idx], 0)
            bound_max = min(max(points[idx]) + margins[idx], lungmask.shape[idx] - 1) + 1
            bounds += [bound_min, bound_max]
        return bounds

    def _GetValidRegion(self, img, bound):
        return img[bound[0]:bound[1], bound[2]:bound[3], bound[4]:bound[5]]

    def _loadNPYData(self, path):
        npy_path = path.replace('NII_1.0', '1.0_NPY').replace('.nii.gz', '.npy')
        # print ('npy_path is',npy_path)
        if os.path.exists(npy_path):
            data = np.load(npy_path)
        else:
            data = sitk.GetArrayFromImage(sitk.ReadImage(path))
        return data

    def GetInferCase(self, data_idx):
        # Full-volume loading and patch splitting used for validation/testing.
        HU_min, HU_max = self.HU_min, self.HU_max
        mask_path, image_path, lungmask_path = self.df.mask_path[data_idx], self.df.image_path[data_idx], self.df.lung_mask_path[data_idx]
        print('image_path is', image_path)
        new_mask_path, new_img_path, new_lungmask_path = self._replace_path(mask_path), self._replace_path(image_path), self._replace_path(lungmask_path)
        if not self.use_npy:
            mask = sitk.GetArrayFromImage(sitk.ReadImage(new_mask_path))
            lungmask = sitk.GetArrayFromImage(sitk.ReadImage(new_lungmask_path))
            image = sitk.GetArrayFromImage(sitk.ReadImage(new_img_path))
        else:
            mask = self._loadNPYData(new_mask_path)
            lungmask = self._loadNPYData(new_lungmask_path)
            image = self._loadNPYData(new_img_path)
        print('img shape is', image.shape, lungmask.shape, mask.shape)
        mask = np.array(mask > 0).astype(np.uint8)
        lungmask = np.array(lungmask > 0).astype(np.uint8)
        ori_image = image
        ori_image = np.clip(ori_image, HU_min, HU_max)
        bounds = self._GetBound(lungmask)
        mask = self._GetValidRegion(mask, bounds)
        image = self._GetValidRegion(image, bounds)
        mask_shape = mask.shape
        image = np.clip(image, HU_min, HU_max)
        image = (image - HU_min) / float(HU_max - HU_min)
        output_img_patches, output_mask_patches, pad_vals, pad_mask = SplitPatch(mask, image, self.input_size, self.infer_stride)
        return output_img_patches, output_mask_patches, pad_mask, pad_vals, bounds, ori_image, lungmask, mask_path

    def _CombinePrediction(self, predictions, pad_mask, pad_vals):
        # CombinePatch(predictions,pad_mask,pad_vals,stride=2)
        result = CombinePatch(predictions, pad_mask, pad_vals, self.infer_stride)
        mask_shape = []
        pad_mask_shape = pad_mask.shape
        for idx in range(3):
            val = pad_mask_shape[idx] - sum(pad_vals[idx])
            mask_shape.append(val)
        mask = pad_mask[pad_vals[0][0]:pad_vals[0][0] + mask_shape[0],
                        pad_vals[1][0]:pad_vals[1][0] + mask_shape[1],
                        pad_vals[2][0]:pad_vals[2][0] + mask_shape[2]]
        return result, mask

    def _replace_path(self, path):
        # Prefer the local SSD copy of the file when it exists.
        if os.path.exists(path.replace('fileser', 'ssd')):
            path = path.replace('fileser', 'ssd')
        return path

    def _getNoneZeroPoint(self, mask):
        # Randomly pick one voxel coordinate from the non-zero region of the mask.
        points = np.where(mask > 0)
        if len(points) > 0 and len(points[0]) > 0:
            length = len(points[0])
            point_idx = np.random.choice(np.arange(length), size=1)[0]
            result = []
            for idx in range(len(points)):
                result.append(points[idx][point_idx])
            return result
        return None

    def getSingleCase(self, data_idx):
        import time
        HU_min, HU_max = self.HU_min, self.HU_max
        mask_path, image_path, lungmask_path = self.df.mask_path[data_idx], self.df.image_path[data_idx], self.df.lung_mask_path[data_idx]
        '''
        Data under the combine_V2 path encodes the annotation intersection as 1 and the
        union-minus-intersection region as 0.1; this version is used for current training.
        '''
        if 'combine' in mask_path:
            mask_path = mask_path.replace('combine', 'combine_V2')
        '''
        Positive patch by default.
        '''
        cls_label = 1
        '''
        If mask_path in the table is set to the lung-mask path, this is a negative case.
        '''
        if mask_path == lungmask_path:
            neg_flag = True
        else:
            neg_flag = False
        '''
        Components whose area is set to inf are positive cases without a mask annotation.
        '''
        nonlabel_flag = False
        area = self.df.area[data_idx]
        if area == float('inf'):
            nonlabel_flag = True
            '''
            Such patches may only be used after training reaches semi_udpate_weights_start;
            before that, another patch has to be drawn instead.
            '''
            if self.current_epoch < self.semi_udpate_weights_start:
                # Assumption: fall back to a randomly chosen case from the index list.
                return self.getSingleCase(self.indices[np.random.randint(len(self.indices))])
        patch_start = time.time()
        new_mask_path, new_img_path, new_lungmask_path = self._replace_path(mask_path), self._replace_path(image_path), self._replace_path(lungmask_path)
        if not self.use_npy:
            mask = sitk.GetArrayFromImage(sitk.ReadImage(new_mask_path))
            lungmask = sitk.GetArrayFromImage(sitk.ReadImage(new_lungmask_path))
            image = sitk.GetArrayFromImage(sitk.ReadImage(new_img_path))
        else:
            mask = self._loadNPYData(new_mask_path)
            lungmask = self._loadNPYData(new_lungmask_path)
            image = self._loadNPYData(new_img_path)
        if self.norm_func != 'HU':
            # Assumption: mean/std normalization over the lung region whenever
            # norm_func is not the default 'HU' min-max scheme.
            new_img = np.clip(image, self.HU_min, self.HU_max)
            new_lungmask = lungmask
            mean_val = np.mean(new_img[np.where(new_lungmask > 0)])
            std_val = np.std(new_img[np.where(new_lungmask > 0)])
            image = np.clip(image, self.HU_min, self.HU_max)
            image = (image - mean_val) / std_val
        patch_middle = time.time()
        '''
        Read the bbox from the table; the layout used downstream is
        [z_min, y_min, x_min, z_max, y_max, x_max].
        '''
        bbox = [self.df.z_min[data_idx], self.df.y_min[data_idx], self.df.x_min[data_idx],
                self.df.z_max[data_idx], self.df.y_max[data_idx], self.df.x_max[data_idx]]
        if bbox[0] == 0 and bbox[3] == 0:
            bounds = self._GetBound(mask)
            bbox = [bounds[0], bounds[2], bounds[4], bounds[1], bounds[3], bounds[5]]
        ################## Compress bbox with margin
        bbox = [min(bbox[0] + self.bbox_margin, bbox[3] - self.bbox_margin - 1),
                min(bbox[1] + self.bbox_margin, bbox[4] - self.bbox_margin - 1),
                min(bbox[2] + self.bbox_margin, bbox[5] - self.bbox_margin - 1),
                max(bbox[0] + self.bbox_margin + 1, bbox[3] - self.bbox_margin),
                max(bbox[1] + self.bbox_margin + 1, bbox[4] - self.bbox_margin),
                max(bbox[2] + self.bbox_margin + 1, bbox[5] - self.bbox_margin)]
        patch_middle_2 = time.time()
        '''
        calculate center indices
        '''
        if neg_flag or nonlabel_flag:
            return self._getPatch(bbox, image, mask, lungmask, neg_flag, nonlabel_flag)[:3]
        else:
            pos_ratio = 0
            patch_loop_idx = 0
            '''
            Require the foreground ratio inside the patch to fall within patch_pos_ratio_range.
            '''
            while (pos_ratio <= self.patch_pos_ratio_range[0] or pos_ratio >= self.patch_pos_ratio_range[1]) or patch_loop_idx == 0:
                image_patch, mask_patch, cls_label, pos_ratio = self._getPatch(bbox, image, mask, lungmask, neg_flag, nonlabel_flag, patch_loop_idx)
                patch_loop_idx += 1
            patch_end = time.time()
            self.logger.debug('it took %.2f s to crop the patch of data %s' % (patch_end - patch_start, self.df.uid[data_idx]))
            return image_patch, mask_patch, cls_label

    def _getPatch(self, bbox, image, mask, lungmask=None, neg_flag=False, nonlabel_flag=False, patch_loop_idx=0):
        import time
        get_patch_start_time = time.time()
        cls_label = 1
        '''
        For data without a pneumonia annotation, take the center directly from the
        non-zero region of the lung mask. For annotated data that has failed the
        sampling condition several times, also take the center from the non-zero mask region.
        '''
        centers = None
        if neg_flag or nonlabel_flag or patch_loop_idx > 5:
            centers = self._getNoneZeroPoint(lungmask)
        if centers is None:
            '''
            Randomly pick a center point inside the bbox.
            '''
            centers = []
            while len(centers) == 0:
                center_z = np.random.randint(bbox[0], bbox[3], size=1)[0]
                center_y = np.random.randint(bbox[1], bbox[4], size=1)[0]
                center_x = np.random.randint(bbox[2], bbox[5], size=1)[0]
                centers = [center_z, center_y, center_x]
        center_z, center_y, center_x = centers
        # Clamp the center so the patch stays inside the volume.
        center_z = min(max(center_z, int(self.input_size[0] / 2)), mask.shape[0] - int(self.input_size[0] / 2) - 1)
        center_y = min(max(center_y, int(self.input_size[1] / 2)), mask.shape[1] - int(self.input_size[1] / 2) - 1)
        center_x = min(max(center_x, int(self.input_size[2] / 2)), mask.shape[2] - int(self.input_size[2] / 2) - 1)
        '''
        Crop the patch.
        '''
        patch_centers = [center_z, center_y, center_x]
        image_patch, mask_patch = self._GeneratePatchWithIndices(image, mask, patch_indices=None, patch_centers=patch_centers)
        if self.union_val not in np.unique(mask_patch):
            mask_patch = np.array(mask_patch > 0)
        '''
        Distinguish data without a pneumonia mask.
        '''
        if neg_flag:
            mask_patch = np.zeros_like(mask_patch)
            cls_label = 0
        if nonlabel_flag:
            mask_patch = np.ones_like(mask_patch) * (100)
        '''
        Train an autoencoder (AE): the target is the image patch itself.
        '''
        if self.learn_feature:
            mask_patch = image_patch
        if np.sum(mask_patch) == 0:
            cls_label = 0
        get_patch_middle_time = time.time()
        sum_of_whole_patch = np.sum(np.ones_like(mask_patch))
        sum_of_mask = np.sum(mask_patch)
        pos_ratio = float(sum_of_mask) / float(sum_of_whole_patch)
        get_patch_end_time = time.time()
        return image_patch, mask_patch, cls_label, pos_ratio

    def _GeneratePatchWithIndices(self, image, mask, patch_indices, patch_centers=None):
        HU_min, HU_max = self.HU_min, self.HU_max
        '''
        Crop the patch.
        '''
        # print ('patch_centers is ',patch_centers)
        margins = [int(val / 2) for val in self.input_size]
        image_patch = image[patch_centers[0] - margins[0]:patch_centers[0] + margins[0],
                            patch_centers[1] - margins[1]:patch_centers[1] + margins[1],
                            patch_centers[2] - margins[2]:patch_centers[2] + margins[2]]
        mask_patch = mask[patch_centers[0] - margins[0]:patch_centers[0] + margins[0],
                          patch_centers[1] - margins[1]:patch_centers[1] + margins[1],
                          patch_centers[2] - margins[2]:patch_centers[2] + margins[2]]
        if self.norm_func == 'HU':
            if self.clip_choice:
                image_patch = np.clip(image_patch, HU_min, HU_max)
            image_patch = (image_patch - HU_min) / float(HU_max - HU_min)
        '''
        Augment the patch if necessary.
        '''
        if self.aug:
            if np.random.choice([0, 1]):
                image_patch = np.swapaxes(image_patch, 1, 2)
                mask_patch = np.swapaxes(mask_patch, 1, 2)
            if np.random.choice([0, 1]):
                image_patch = random_add_val(image_patch)
            if np.random.choice([0, 1]):
                image_patch, mask_patch = data_augmentation(image_patch, mask_patch, expend_choice=False)
        image_patch, mask_patch = image_patch[np.newaxis, ...], mask_patch[np.newaxis, ...]
        new_image_patch = image_patch.copy()
        new_mask_patch = mask_patch.copy()
        new_mask_patch = np.array(new_mask_patch)
        # new_mask_patch = np.array(new_mask_patch>0).astype(np.uint8)
        return new_image_patch, new_mask_patch

    def _padBackToOriSize(self, img, mask, pred, pad_vals, bounds):
        img = np.squeeze(img)
        mask = np.squeeze(mask)
        pred = np.squeeze(pred)
        target_shape = img.shape
        mask_shape = mask.shape
        left_pad_vals = [bounds[0], bounds[2], bounds[4]]
        right_pad_vals = [target_shape[idx] - mask_shape[idx] - left_pad_vals[idx] for idx in range(3)]
        pad_vals = [[left_pad_vals[idx], right_pad_vals[idx]] for idx in range(3)]
        pred = np.pad(pred, pad_vals, mode='constant')
        mask = np.pad(mask, pad_vals, mode='constant')
        return mask, pred

    def _transImage(self, img, label):
        # final_image,final_mask,new_spacing = imagepari_target_shape_respacing(final_image,final_mask,ori_spacing,current_zoom_shape)
        img = img.numpy()
        label = label.numpy()
        trans_list = {}
        ########### flip
        flip_list = []
        cnt = 3
        prob = 0.5
        while random.random() < prob and cnt > 0:
            degree = random.choice([0, 1, 2])
            img = np.flip(img, axis=degree)
            label = np.flip(label, axis=degree)
            cnt = cnt - 1
            flip_list.append(degree)
        trans_list['flip_list'] = flip_list
        ########## swap axis
        if np.random.choice([0, 1]):
            img = np.swapaxes(img, 2, 3)
            label = np.swapaxes(label, 2, 3)
            trans_list['swap'] = True
        else:
            trans_list['swap'] = False
        ############## zoom
        '''
        todo: add this part
        '''
        copy_img = img.copy()
        copy_label = label.copy()
        return copy_img, copy_label, trans_list

    def _transWithFunc(self, imgs, func_list):
        # print ('length of imgs is ',len(imgs))
        outputs = []
        if func_list['swap']:
            for img in imgs:
                new_img = np.swapaxes(img, 1, 2)
                outputs.append(new_img)
        else:
            outputs = [val for val in imgs]
        for degree in func_list['flip_list']:
            for img_idx, img in enumerate(outputs):
                new_img = np.flip(img, axis=degree)
                outputs[img_idx] = new_img
        outputs = [val.copy() for val in outputs]
        # print ('length of outputs',outputs[0].shape)
        return outputs
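
# Minimal usage sketch: how SegGenCOVID is typically wrapped in a torch DataLoader
# for patch-based training. The CSV path, input_size and logger below are placeholder
# assumptions; the annotation table is expected to provide uid, image_path, mask_path,
# lung_mask_path, the z/y/x bbox columns and area, as used by getSingleCase above.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    logging.basicConfig(level=logging.DEBUG)
    demo_logger = logging.getLogger('SegGenCOVID-demo')  # getSingleCase calls logger.debug

    train_set = SegGenCOVID(['train_annotations.csv'],  # hypothetical annotation table
                            mode='training',
                            aug=True,
                            input_size=[32, 64, 64],    # patch size in (z, y, x)
                            logger=demo_logger)
    train_loader = DataLoader(train_set, batch_size=2, shuffle=True, num_workers=0)

    # Each item is [image_patch, mask_patch, cls_label]; the default collate stacks
    # them into tensors of shape (B, 1, 32, 64, 64) plus a label vector of length B.
    image_patch, mask_patch, cls_label = next(iter(train_loader))
    print(image_patch.shape, mask_patch.shape, cls_label)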