# SegGenCOVID.py 8.19 KB
# -*- coding:utf-8 -*-
import os
import sys
import glob
import math
import random
import numpy as np
import pandas as pd
import scipy.ndimage as nd
from scipy.ndimage import zoom
from keras.utils import to_categorical
from torch.utils.data import Dataset
from collections import Counter, defaultdict

from ImagePreprocess import *
from ReadData import load_data,load_single_data
from data_aug_helper import data_augmentation,random_add_val


class SegGenCOVID(Dataset):
    """Patch-sampling dataset for lesion segmentation inside a lung region.

    Each row of the CSV at `df_path` must provide the columns `image_path`,
    `mask_path` and `lung_mask_path`.  For every requested item the three
    volumes are loaded, cropped to the lung-mask bounding box (plus a margin),
    and a random patch of `input_size` is drawn, preferring patches in which
    at least 1% of the voxels are labelled in the mask.

    Keyword arguments (all optional unless noted):
        aug: enable random augmentation of the patch (default False).
        mode: 'training', 'val' or 'test' (default 'training').
        clip_choice: clip intensities to [HU_min, HU_max] (default True).
        HU_min, HU_max: intensity window used for clipping/normalisation.
        stride_ratio: step, in voxels, between candidate patch positions.
        input_size: patch size per axis; required, three entries.
        estimate_patch_per_img: how often each case is repeated per epoch.
        repeat_times: extra repetition factor applied in 'training' mode.
        max_patch_attempts: cap on rejection-sampling tries per item.

    NOTE(review): volumes are assumed to be 3-D arrays; `sitk` is expected to
    be brought into scope by the star import from ImagePreprocess -- confirm.
    """

    # Class-level default so the cap exists even on instances that were
    # created without running __init__.
    max_patch_attempts = 100

    def __init__(self, df_path, **kwargs):
        super(SegGenCOVID, self).__init__()

        self.aug = kwargs.get('aug', False)
        self.mode = kwargs.get('mode', 'training')

        self.clip_choice = kwargs.get('clip_choice', True)
        self.HU_min = kwargs.get('HU_min', -1024)
        self.HU_max = kwargs.get('HU_max', 1024)
        self.stride_ratio = kwargs.get('stride_ratio', 8)
        self.input_size = kwargs.get('input_size')
        self.estimate_patch_per_img = kwargs.get('estimate_patch_per_img', 12)
        self.max_patch_attempts = kwargs.get('max_patch_attempts', self.max_patch_attempts)
        print('self.estimate_patch_per_img is ', self.estimate_patch_per_img)

        repeat_times = kwargs.get('repeat_times', 4)

        self.df = pd.read_csv(df_path)
        print('columns', self.df.columns)

        self.indices = np.arange(len(self.df))
        print('length of self.indices before is ', len(self.indices))

        # Every case contributes several patches per epoch, so its row index
        # is repeated; 'test' mode keeps exactly one entry per case.
        if self.mode == 'training':
            print('=' * 60)
            print('repeat indices')
            self.indices = np.repeat(self.indices, repeats=repeat_times * self.estimate_patch_per_img)
        elif self.mode == 'val':
            self.indices = np.repeat(self.indices, repeats=self.estimate_patch_per_img)
        random.shuffle(self.indices)

    def __len__(self):
        """Return the number of items served per epoch."""
        return len(self.indices)

    def __getitem__(self, index):
        """Return one `(image_patch, mask_patch)` pair.

        In 'test' mode `index` is used directly as the CSV row; otherwise it
        is mapped through the shuffled, repeated index table.
        """
        if self.mode == 'test':
            data_idx = index
        else:
            data_idx = self.indices[index]
        return self.getSingleCase(data_idx)

    def _CalculateNumPatch(self, mask_shape):
        """Return the number of candidate patch positions along each axis.

        The count is clamped to at least 1 so that a volume no larger than
        `input_size` still offers one position instead of producing an empty
        sampling range.
        """
        region_size = [full - patch for full, patch in zip(mask_shape, self.input_size)]
        return [max(1, int(math.ceil(val / float(self.stride_ratio)))) for val in region_size]

    def _GeneratePatchIndex(self, num_patches):
        """Draw one random grid index per axis, each in `[0, num_patches[axis])`."""
        return [np.random.randint(low=0, high=count, size=1)[0] for count in num_patches]

    def _GetBound(self, lungmask):
        """Return `[min0, max0, min1, max1, min2, max2]` around the lung mask.

        The box encloses every positive voxel, grows by a fixed per-axis
        margin and is clipped to the volume.  Upper bounds are exclusive, so
        they can be used directly as slice ends.  An all-zero mask yields the
        full extent of the volume.
        """
        margins = [10, 20, 20]
        points = np.where(lungmask > 0)
        bounds = []
        for idx in range(len(points)):
            axis_len = lungmask.shape[idx]
            if points[idx].size == 0:
                # No lung voxels at all: keep the whole axis instead of
                # failing on min()/max() of an empty sequence.
                bounds += [0, axis_len]
                continue
            bound_min = max(int(points[idx].min()) - margins[idx], 0)
            # Far side: largest coordinate plus margin, clipped to the last
            # valid index, then +1 to make the bound exclusive.
            bound_max = min(int(points[idx].max()) + margins[idx], axis_len - 1) + 1
            bounds += [bound_min, bound_max]
        return bounds

    def _GetValidRegion(self, img, bound):
        """Crop a 3-D array to the box returned by `_GetBound`."""
        return img[bound[0]:bound[1], bound[2]:bound[3], bound[4]:bound[5]]

    def getSingleCase(self, data_idx):
        """Load the volumes of CSV row `data_idx` and sample one patch.

        Returns:
            `(image_patch, mask_patch)`, each with a leading axis of length 1.
        """
        mask_path = self.df.mask_path[data_idx]
        image_path = self.df.image_path[data_idx]
        lungmask_path = self.df.lung_mask_path[data_idx]

        # Both label volumes are binarised: any positive value counts.
        mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_path))
        mask = np.array(mask > 0).astype(np.uint8)
        lungmask = sitk.GetArrayFromImage(sitk.ReadImage(lungmask_path))
        lungmask = np.array(lungmask > 0).astype(np.uint8)
        image = sitk.GetArrayFromImage(sitk.ReadImage(image_path))

        return self._SamplePatchFromVolumes(image, mask, lungmask)

    def _SamplePatchFromVolumes(self, image, mask, lungmask):
        """Crop `image`/`mask` to the lung box and rejection-sample a patch.

        Patches are redrawn until at least 1% of the patch volume is labelled
        in the mask or `max_patch_attempts` draws have been made; in the
        latter case the last patch drawn is returned.
        """
        # Image and mask must be cropped with the same box, otherwise the
        # patch coordinates computed on the mask address different voxels in
        # the image.
        bounds = self._GetBound(lungmask)
        image = self._GetValidRegion(image, bounds)
        mask = self._GetValidRegion(mask, bounds)

        patch_volume = self.input_size[0] * self.input_size[1] * self.input_size[2]
        num_patches = self._CalculateNumPatch(mask.shape)

        # Bounded loop: a case with no (or too little) lesion can never meet
        # the threshold and must not stall the data loader forever.
        for _ in range(max(1, int(self.max_patch_attempts))):
            patch_indices = self._GeneratePatchIndex(num_patches)
            image_patch, mask_patch = self._GeneratePatchWithIndices(image, mask, patch_indices)
            if np.sum(mask_patch) / float(patch_volume) >= 0.01:
                break
        return image_patch, mask_patch

    @staticmethod
    def _PadPatchToShape(patch, target_shape):
        """Zero-pad the trailing side of each axis of `patch` up to `target_shape`."""
        pad_width = [(0, max(0, want - have)) for want, have in zip(target_shape, patch.shape)]
        pad_width += [(0, 0)] * (patch.ndim - len(pad_width))
        if not any(after for _, after in pad_width):
            return patch
        return np.pad(patch, pad_width, mode='constant', constant_values=0)

    def _GeneratePatchWithIndices(self, image, mask, patch_indices):
        """Cut, normalise and optionally augment one patch.

        Args:
            image: intensity volume, already cropped to the lung box.
            mask: label volume with the same shape as `image`.
            patch_indices: grid index per axis from `_GeneratePatchIndex`.

        Returns:
            `(image_patch, mask_patch)` with a leading axis of length 1; the
            image is scaled so that HU_min maps to 0 and HU_max to 1, the mask
            is binary uint8.
        """
        HU_min, HU_max = self.HU_min, self.HU_max
        half_stride = 0.5 * self.stride_ratio

        # Jitter each patch centre by up to half a stride around its grid
        # position while keeping the whole patch inside the volume.
        centers = []
        for idx in range(3):
            nominal = int(patch_indices[idx] * self.stride_ratio + self.input_size[idx] / 2.0)
            low = int(max(self.input_size[idx] / 2, nominal - half_stride))
            high = int(min(mask.shape[idx] - self.input_size[idx] / 2, nominal + half_stride))
            # randint needs low < high; the range collapses when the volume
            # is not larger than the patch along this axis.
            high = max(high, low + 1)
            centers.append(np.random.randint(low=low, high=high, size=1)[0])

        margins = [int(val / 2) for val in self.input_size]
        window = tuple(slice(center - margin, center + margin)
                       for center, margin in zip(centers, margins))
        image_patch = image[window]
        mask_patch = mask[window]

        if self.clip_choice:
            image_patch = np.clip(image_patch, HU_min, HU_max)
        image_patch = (image_patch - HU_min) / float(HU_max - HU_min)

        # A volume smaller than the patch yields a truncated slice; pad it to
        # the nominal size (0 == HU_min after normalisation, 0 == background).
        nominal_shape = [2 * margin for margin in margins]
        image_patch = self._PadPatchToShape(image_patch, nominal_shape)
        mask_patch = self._PadPatchToShape(mask_patch, nominal_shape)

        if self.aug:
            # Three independent coin flips: axis swap, intensity shift,
            # and the external augmentation helper.
            if np.random.choice([0, 1]):
                image_patch = np.swapaxes(image_patch, 1, 2)
                mask_patch = np.swapaxes(mask_patch, 1, 2)
            if np.random.choice([0, 1]):
                image_patch = random_add_val(image_patch)
            if np.random.choice([0, 1]):
                image_patch, mask_patch = data_augmentation(image_patch, mask_patch, expend_choice=False)

        image_patch, mask_patch = image_patch[np.newaxis, ...], mask_patch[np.newaxis, ...]
        # Copies guarantee contiguous, positively-strided arrays after the
        # view-producing slicing / swapaxes above.
        new_image_patch = image_patch.copy()
        new_mask_patch = np.array(mask_patch.copy() > 0).astype(np.uint8)
        return new_image_patch, new_mask_patch

    def _transImage(self, img, label):
        """Apply random flips and an optional axis swap to an image/label pair.

        `img` and `label` must expose `.numpy()`.  Returns the transformed
        arrays plus a dict (`flip_list`, `swap`) recording what was applied.
        """
        img = img.numpy()
        label = label.numpy()

        trans_list = {}

        # Up to three flips, each granted with probability 0.5; the first
        # failed draw stops the sequence.
        flip_list = []
        cnt = 3
        prob = 0.5
        while random.random() < prob and cnt > 0:
            degree = random.choice([0, 1, 2])
            img = np.flip(img, axis=degree)
            label = np.flip(label, axis=degree)
            cnt = cnt - 1
            flip_list.append(degree)
        trans_list['flip_list'] = flip_list

        # Swap axes 2 and 3 with probability 0.5.
        if np.random.choice([0, 1]):
            img = np.swapaxes(img, 2, 3)
            label = np.swapaxes(label, 2, 3)
            trans_list['swap'] = True
        else:
            trans_list['swap'] = False

        # TODO: zoom augmentation is not implemented yet.

        copy_img = img.copy()
        copy_label = label.copy()
        return copy_img, copy_label, trans_list

    def _transWithFunc(self, imgs, func_list):
        """Apply a recorded swap/flip description to every array in `imgs`.

        `func_list` is a dict with a boolean `swap` (swap axes 1 and 2) and a
        `flip_list` of axes to flip in order.  Returns a list of new arrays.
        """
        if func_list['swap']:
            outputs = [np.swapaxes(img, 1, 2) for img in imgs]
        else:
            outputs = list(imgs)
        for degree in func_list['flip_list']:
            outputs = [np.flip(img, axis=degree) for img in outputs]
        # flip/swapaxes return views; hand back independent arrays.
        return [val.copy() for val in outputs]