# -*- coding:utf-8 -*-
import os
import sys
import glob
import math
import random
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
import scipy.ndimage as nd
from scipy.ndimage import zoom
from keras.utils import to_categorical
from torch.utils.data import Dataset

# NOTE(review): `sitk` is used below but not imported explicitly; presumably it
# is re-exported by this star import (SimpleITK) -- verify.
from ImagePreprocess import *
from ReadData import load_data, load_single_data
from data_aug_helper import data_augmentation, random_add_val


class SegGenCOVID(Dataset):
    """Patch-sampling dataset for 3D segmentation.

    Each row of the CSV at ``df_path`` describes one case via its
    ``image_path``, ``mask_path`` and ``lung_mask_path`` columns. Each item
    loads one case, crops it to the padded lung-mask bounding box and returns
    a random patch of size ``input_size`` that contains some mask foreground.

    Keyword arguments:
        aug (bool): randomly augment patches. Default False.
        mode (str): 'training', 'val' or anything else (e.g. 'test').
            Default 'training'.
        clip_choice (bool): clip intensities to [HU_min, HU_max] before
            normalisation. Default True.
        HU_min, HU_max (int): intensity window. Defaults -1024 / 1024.
        stride_ratio (int): spacing, in voxels, of the patch-centre grid.
            Default 8.
        input_size (sequence of 3 ints): patch size. Required.
        estimate_patch_per_img (int): repetitions of each case per epoch in
            'val' mode (multiplied by ``repeat_times`` in 'training').
            Default 12.
        repeat_times (int): extra repetition factor in 'training'. Default 4.
        min_fg_ratio (float): minimum fraction of foreground voxels a patch
            must contain. Default 0.01.
        max_patch_trials (int): sampling attempts per case before falling
            back to the patch with the most foreground seen. Default 100.
    """

    def __init__(self, df_path, **kwargs):
        super(SegGenCOVID, self).__init__()
        self.aug = kwargs.get('aug', False)
        self.mode = kwargs.get('mode', 'training')
        self.clip_choice = kwargs.get('clip_choice', True)
        self.HU_min = kwargs.get('HU_min', -1024)
        self.HU_max = kwargs.get('HU_max', 1024)
        self.stride_ratio = kwargs.get('stride_ratio', 8)
        self.input_size = kwargs.get('input_size')
        self.estimate_patch_per_img = kwargs.get('estimate_patch_per_img', 12)
        # Bounds for the foreground-rejection loop in getSingleCase.
        self.min_fg_ratio = kwargs.get('min_fg_ratio', 0.01)
        self.max_patch_trials = kwargs.get('max_patch_trials', 100)
        print('self.estimate_patch_per_img is ', self.estimate_patch_per_img)
        repeat_times = kwargs.get('repeat_times', 4)

        self.df = pd.read_csv(df_path)
        print('columns', self.df.columns)
        self.indices = np.arange(len(self.df))
        print('length of self.indices before is ', len(self.indices))
        # Repeat every case index so one epoch draws several patches per case.
        if self.mode == 'training':
            print('=' * 60)
            print('repeat indices')
            self.indices = np.repeat(self.indices, repeats=repeat_times * self.estimate_patch_per_img)
        elif self.mode == 'val':
            self.indices = np.repeat(self.indices, repeats=self.estimate_patch_per_img)
        random.shuffle(self.indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        # In 'test' mode the dataset index is used directly as the CSV row.
        if self.mode == 'test':
            data_idx = index
        else:
            data_idx = self.indices[index]
        return self.getSingleCase(data_idx)

    def _CalculateNumPatch(self, mask_shape):
        """Return, per axis, how many grid positions a patch centre can take.

        Grid index ``k`` places the nominal centre at
        ``k * stride_ratio + input_size / 2``. At least one position is
        always returned, so a volume exactly ``input_size`` long on an axis
        still yields its single valid patch.
        """
        region_size = [dim - size for dim, size in zip(mask_shape, self.input_size)]
        return [max(1, int(math.ceil(val / float(self.stride_ratio)))) for val in region_size]

    def _GeneratePatchIndex(self, num_patches):
        """Draw one uniformly random grid index in ``[0, n)`` for every axis."""
        return [np.random.randint(low=0, high=n, size=1)[0] for n in num_patches]

    def _GetBound(self, lungmask, margins=(10, 20, 20)):
        """Return the padded bounding box of ``lungmask`` foreground.

        Returns ``[min0, max0, min1, max1, min2, max2]`` where the max values
        are exclusive slice ends, clipped to the array. An axis with no
        foreground (empty mask) gets the full extent ``[0, shape]``.
        """
        points = np.where(lungmask > 0)
        bounds = []
        for axis, (coords, margin) in enumerate(zip(points, margins)):
            if coords.size == 0:
                bounds += [0, lungmask.shape[axis]]
                continue
            bound_min = max(int(coords.min()) - margin, 0)
            # Largest foreground coordinate plus margin, clipped to the last
            # valid index, then +1 to form an exclusive slice end.
            bound_max = min(int(coords.max()) + margin, lungmask.shape[axis] - 1) + 1
            bounds += [bound_min, bound_max]
        return bounds

    def _ExpandBound(self, bounds, shape):
        """Grow each ``[lo, hi)`` interval to at least ``input_size`` where ``shape`` allows.

        Keeps a tight lung crop from becoming smaller than one patch, which
        would make patch sampling impossible.
        """
        expanded = []
        for axis, dim in enumerate(shape):
            lo, hi = bounds[2 * axis], bounds[2 * axis + 1]
            size = self.input_size[axis]
            missing = size - (hi - lo)
            if missing > 0:
                lo = max(0, lo - missing // 2)
                hi = min(dim, lo + size)
                lo = max(0, hi - size)
            expanded += [lo, hi]
        return expanded

    def _GetValidRegion(self, img, bound):
        """Slice ``img`` with the six values produced by ``_GetBound``."""
        return img[bound[0]:bound[1], bound[2]:bound[3], bound[4]:bound[5]]

    def getSingleCase(self, data_idx):
        """Load CSV row ``data_idx`` and sample one patch from it.

        Image, mask and lung mask are read with ``sitk``. Image and mask are
        cropped with the same (padded, patch-size-expanded) lung bounds so
        their coordinates stay aligned. Patches are then drawn until one has
        at least ``min_fg_ratio`` foreground voxels. If ``max_patch_trials``
        attempts are used up first, the patch with the most foreground seen
        is returned instead of looping forever.

        Returns:
            (image_patch, mask_patch), each with a leading axis of size 1.
        """
        patch_volume = float(np.prod(self.input_size))
        mask_path = self.df.mask_path[data_idx]
        image_path = self.df.image_path[data_idx]
        lungmask_path = self.df.lung_mask_path[data_idx]

        # Load and binarise mask & lungmask; load the raw image.
        mask = np.array(sitk.GetArrayFromImage(sitk.ReadImage(mask_path)) > 0).astype(np.uint8)
        lungmask = np.array(sitk.GetArrayFromImage(sitk.ReadImage(lungmask_path)) > 0).astype(np.uint8)
        image = sitk.GetArrayFromImage(sitk.ReadImage(image_path))

        # Crop image AND mask identically; patch centres are computed from
        # the cropped mask's shape and then applied to both arrays.
        bounds = self._ExpandBound(self._GetBound(lungmask), mask.shape)
        mask = self._GetValidRegion(mask, bounds)
        image = self._GetValidRegion(image, bounds)

        num_patches = self._CalculateNumPatch(mask.shape)
        best_patch, best_sum = None, -1
        for _ in range(max(1, self.max_patch_trials)):
            patch_indices = self._GeneratePatchIndex(num_patches)
            image_patch, mask_patch = self._GeneratePatchWithIndices(image, mask, patch_indices)
            sum_of_patch = np.sum(mask_patch)
            if sum_of_patch / patch_volume >= self.min_fg_ratio:
                return image_patch, mask_patch
            if sum_of_patch > best_sum:
                best_patch, best_sum = (image_patch, mask_patch), sum_of_patch
        return best_patch

    def _GeneratePatchWithIndices(self, image, mask, patch_indices):
        """Crop, normalise and optionally augment one patch.

        On each axis the grid index fixes a nominal centre
        ``index * stride_ratio + input_size / 2``. The actual centre is drawn
        uniformly within ``+/- stride_ratio / 2`` of it, restricted so the
        whole ``input_size`` window lies inside ``mask``/``image``.

        Returns:
            (image_patch, mask_patch) with a leading axis of size 1; the mask
            is binarised to uint8.

        Raises:
            ValueError: if the volume is smaller than ``input_size`` on an axis.
        """
        HU_min, HU_max = self.HU_min, self.HU_max

        # Patch start per axis. The window is [start, start + size), so odd
        # sizes are no longer cropped one voxel short.
        starts = []
        for axis, size in enumerate(self.input_size):
            half = size // 2
            nominal = int(patch_indices[axis] * self.stride_ratio + size / 2.0)
            lo = max(half, int(nominal - 0.5 * self.stride_ratio))
            hi = min(mask.shape[axis] - (size - half), int(nominal + 0.5 * self.stride_ratio))
            if hi < lo:
                raise ValueError(
                    'volume shape {} is smaller than input_size {} on axis {}'.format(
                        mask.shape, self.input_size, axis))
            # randint's `high` is exclusive; +1 makes the last valid centre reachable.
            centre = np.random.randint(low=lo, high=hi + 1, size=1)[0]
            starts.append(centre - half)

        region = tuple(slice(start, start + size) for start, size in zip(starts, self.input_size))
        image_patch = image[region]
        mask_patch = mask[region]

        if self.clip_choice:
            image_patch = np.clip(image_patch, HU_min, HU_max)
        # NOTE(review): the original flattened layout makes it ambiguous whether
        # this normalisation belonged under `if self.clip_choice`; it is applied
        # unconditionally here -- confirm intent.
        image_patch = (image_patch - HU_min) / float(HU_max - HU_min)

        if self.aug:
            # NOTE(review): swapping axes 1 and 2 changes the patch shape when
            # input_size[1] != input_size[2].
            if np.random.choice([0, 1]):
                image_patch = np.swapaxes(image_patch, 1, 2)
                mask_patch = np.swapaxes(mask_patch, 1, 2)
            if np.random.choice([0, 1]):
                image_patch = random_add_val(image_patch)
            if np.random.choice([0, 1]):
                image_patch, mask_patch = data_augmentation(image_patch, mask_patch, expend_choice=False)

        # Add a leading channel axis. The copy gives a contiguous array even
        # after swapaxes/augmentation views.
        new_image_patch = image_patch[np.newaxis, ...].copy()
        new_mask_patch = np.array(mask_patch[np.newaxis, ...] > 0).astype(np.uint8)
        return new_image_patch, new_mask_patch

    def _transImage(self, img, label):
        """Apply random flips and an optional axis swap to ``img`` and ``label``.

        ``img`` and ``label`` must provide ``.numpy()`` (presumably torch
        tensors -- confirm with callers). Up to three flips are applied, each
        along a random axis in {0, 1, 2}. Then axes 2 and 3 are swapped with
        probability 1/2.

        Returns:
            (img, label, trans_list), where ``trans_list`` is a dict with
            'flip_list' (axes flipped, in order) and 'swap' (bool).
        """
        img = img.numpy()
        label = label.numpy()
        trans_list = {}

        # Flip: keep flipping while a coin lands under `prob`, at most 3 times.
        flip_list = []
        cnt = 3
        prob = 0.5
        while random.random() < prob and cnt > 0:
            degree = random.choice([0, 1, 2])
            img = np.flip(img, axis=degree)
            label = np.flip(label, axis=degree)
            cnt = cnt - 1
            flip_list.append(degree)
        trans_list['flip_list'] = flip_list

        # Swap axis.
        # NOTE(review): this swaps axes (2, 3) while _transWithFunc swaps (1, 2)
        # but flips the same axis numbers; for arrays of different rank the flip
        # axes may not correspond -- verify the intended shapes.
        if np.random.choice([0, 1]):
            img = np.swapaxes(img, 2, 3)
            label = np.swapaxes(label, 2, 3)
            trans_list['swap'] = True
        else:
            trans_list['swap'] = False

        # TODO: add zoom augmentation.
        return img.copy(), label.copy(), trans_list

    def _transWithFunc(self, imgs, func_list):
        """Apply the transforms recorded in ``func_list`` to every array in ``imgs``.

        The swap (axes 1 and 2) is applied first, then the recorded flips.
        This is the reverse of the order used by ``_transImage``.

        Returns:
            A list of contiguous copies of the transformed arrays.
        """
        outputs = list(imgs)
        if func_list['swap']:
            outputs = [np.swapaxes(img, 1, 2) for img in outputs]
        for degree in func_list['flip_list']:
            outputs = [np.flip(img, axis=degree) for img in outputs]
        return [img.copy() for img in outputs]