# -*- coding:utf-8 -*-
import os
import sys
import glob
import math
import random
import numpy as np
import pandas as pd
import SimpleITK as sitk
import scipy.ndimage as nd
from scipy.ndimage import zoom
from keras.utils import to_categorical
from torch.utils.data import Dataset
from DataAna import GenerateRatios
from data_aug import random_crop,cropDataAndPadToTargetShape
from data_aug import data_augmentation
from collections import Counter
from ImagePro import ShowEdges,HFFilter
from skimage import exposure,data
from matplotlib import pyplot as plt


class PneuDataSet(Dataset):
    """Dataset yielding intensity-normalised image patches listed in a CSV.

    Each CSV row is read through the columns ``image_path_prem``,
    ``image_path_deux``, ``label_prem``, ``label_deux`` and ``label_troi``:
    two image paths, one label per image, and a third label.

    ``training_mode`` selects what ``__getitem__`` returns:

    * ``2`` (default): one randomly chosen image of the row and its own label.
    * ``1``: both images concatenated, with all three labels as ints.
    * anything else: both images concatenated, with ``label_troi`` only.
    """

    def __init__(self, csv_paths, **kwargs):
        """Read the CSV and store the configuration.

        Args:
            csv_paths: Sequence of CSV paths; only ``csv_paths[0]`` is read.
            **kwargs: Configuration values stored as attributes. Defaults
                are ``cls_types=None``, ``aug=False``, ``smooth_val=0``,
                ``training_mode=2``, ``mode='training'``,
                ``model_mode='basic'``, ``n_input_channel=1`` and
                ``regression_choice=False``; every other key defaults to
                ``None`` when omitted. ``HU_min``, ``HU_max`` and
                ``patch_shape`` are required in practice because
                ``_process_single_patch`` uses them unconditionally.
        """
        super(PneuDataSet, self).__init__()
        self.df = pd.read_csv(csv_paths[0])

        self.cls_index = kwargs.get('cls_index')
        self.cls_types = kwargs.get('cls_types', None)
        self.num_class = kwargs.get('num_class')
        self.HU_min = kwargs.get('HU_min')
        self.HU_max = kwargs.get('HU_max')
        self.aug = kwargs.get('aug', False)
        self.input_shape = kwargs.get('input_shape')
        self.smooth_val = kwargs.get('smooth_val', 0)
        self.patch_shape = kwargs.get('patch_shape')
        self.training_mode = kwargs.get('training_mode', 2)
        self.mode = kwargs.get('mode', 'training')
        self.model_mode = kwargs.get('model_mode', 'basic')
        self.n_input_channel = kwargs.get('n_input_channel', 1)
        self.regression_choice = kwargs.get('regression_choice', False)

        # Row order used in training mode; shuffled once at construction.
        self.indices = np.arange(len(self.df))
        random.shuffle(self.indices)
        print('=' * 60)
        print('length of df is ', len(self.df))

    def __getitem__(self, index):
        """Return ``(images, labels, labels)`` for one sample.

        In training mode ``index`` is wrapped modulo the number of rows and
        mapped through the shuffled ``self.indices``, because ``__len__``
        can report twice the row count. In any other mode ``index`` is used
        as the row index directly.

        The labels are returned twice, exactly as the original code did.
        """
        if self.mode == 'training':
            data_idx = self.indices[index % len(self.indices)]
        else:
            data_idx = index

        if self.training_mode == 2:
            base_image, base_label = self._get_single_case_V2(data_idx)
            # Only this mode wraps the image and label in one-element lists.
            net_input_imgs = [base_image]
            net_input_labels = [base_label]
        elif self.training_mode == 1:
            base_image, base_label = self._get_single_case(data_idx)
            net_input_imgs = base_image
            net_input_labels = base_label
        else:
            base_image, base_label = self._get_single_case_V3(data_idx)
            net_input_imgs = base_image
            net_input_labels = base_label

        return net_input_imgs, net_input_labels, net_input_labels

    def _read_row(self, data_idx):
        """Return ``(path_1, path_2, label_1, label_2, label_3)`` for a row."""
        return (
            self.df.image_path_prem[data_idx],
            self.df.image_path_deux[data_idx],
            self.df.label_prem[data_idx],
            self.df.label_deux[data_idx],
            self.df.label_troi[data_idx],
        )

    @staticmethod
    def _load_volume(path):
        """Load an image array from ``path``.

        Paths containing ``'.npy'`` are read with ``np.load``; every other
        path is read with SimpleITK and converted to an array.
        """
        if '.npy' in path:
            return np.load(path)
        return sitk.GetArrayFromImage(sitk.ReadImage(path))

    def _get_single_case_V2(self, index):
        """Return one randomly chosen image of the row and its own label.

        The processed patch is squeezed and given a single leading axis.
        """
        img_path_01, img_path_02, label_01, label_02, _ = self._read_row(index)

        if np.random.choice([0, 1]):
            path, label = img_path_01, label_01
        else:
            path, label = img_path_02, label_02

        # Uses the shared loader so non-.npy paths work here as they do in
        # the two-image modes (previously this was a bare np.load).
        img = self._load_volume(path)
        img = np.squeeze(self._process_single_patch(img))[np.newaxis, ...]
        output_img = img.copy()
        return output_img, label

    def _get_single_case_V3(self, index):
        """Return both images concatenated on axis 0, plus ``label_troi``.

        The order of the two images is swapped at random.
        """
        img_path_01, img_path_02, _, _, label_03 = self._read_row(index)

        if np.random.choice([0, 1]):
            img_path_01, img_path_02 = img_path_02, img_path_01

        img01 = self._load_volume(img_path_01)
        img02 = self._load_volume(img_path_02)
        img01 = self._process_single_patch(img01)
        img02 = self._process_single_patch(img02)

        imgs = np.concatenate([img01, img02], axis=0)
        return imgs, label_03

    def _get_single_case(self, index):
        """Return both images concatenated on axis 0, plus three int labels.

        The order of the two images is swapped at random; the first two
        labels are swapped with them so that they stay aligned.
        """
        img_path_01, img_path_02, label_01, label_02, label_03 = self._read_row(index)

        if np.random.choice([0, 1]):
            img_path_01, img_path_02 = img_path_02, img_path_01
            label_01, label_02 = label_02, label_01

        labels = [int(val) for val in (label_01, label_02, label_03)]

        img01 = self._load_volume(img_path_01)
        img02 = self._load_volume(img_path_02)
        img01 = self._process_single_patch(img01)
        img02 = self._process_single_patch(img02)

        imgs = np.concatenate([img01, img02], axis=0)
        return imgs, labels

    def _process_single_patch(self, image_single):
        """Crop, clip and rescale one image, then optionally augment it.

        Values are clipped to ``[HU_min, HU_max]`` and rescaled linearly to
        ``[0, 1]``. When ``self.aug`` is true, ``data_augmentation`` is
        applied with ``prob=0.7``.

        Returns:
            The processed array with two new leading axes.
        """
        if self.aug:
            delta_range = 5
            deltas = np.random.randint(low=-delta_range, high=delta_range, size=3)
        else:
            deltas = [0 for _ in range(3)]
        # NOTE(review): the original file had lost its indentation, so it is
        # unclear whether this reset sat inside the ``else`` (a no-op) or
        # after the if/else (forcing zero offsets). It is placed after the
        # if/else here; confirm the intent. The randint call above is kept
        # so the random-number stream is consumed as before.
        deltas = [0, 0, 0]

        image_single = random_crop(
            image_single, input_shape=self.patch_shape, deltas=deltas, aug=self.aug
        )
        # image_single = cropDataAndPadToTargetShape(image_single,self.patch_shape,deltas=deltas,aug=self.aug,pad_val=self.HU_min)

        image_single = np.clip(image_single, self.HU_min, self.HU_max)
        image_single = (image_single - self.HU_min) / float(self.HU_max - self.HU_min)

        if self.aug:
            image_single = data_augmentation([image_single], prob=0.7)[0]
            # if np.random.choice([0,1]):
            #     image_single = np.swapaxes(image_single,1,2)

        return image_single[np.newaxis, np.newaxis, ...]

    def __len__(self):
        """Return the number of samples per epoch.

        In training mode with ``training_mode != 1`` this is twice the row
        count; otherwise it is the row count.
        """
        if self.mode == 'training' and self.training_mode != 1:
            return len(self.indices) * 2
        return len(self.indices)

    def _labelSmooth(self, label):
        """Move ``label`` by ``self.smooth_val``: up from 0, down otherwise."""
        if label == 0:
            label += self.smooth_val
        else:
            label -= self.smooth_val
        return label

    def _GenerateRatios(self):
        """Return ``[fraction of label 1, fraction of label 0]``.

        Counts are taken from the ``label_troi`` column and normalised by
        the sum of the two counts. Raises ``ZeroDivisionError`` when the
        column holds neither label.
        """
        labels = self.df['label_troi'].values
        counts = Counter(labels)
        print('counts is ', counts)
        ratios = [counts[1], counts[0]]
        print('ratios before is', ratios)
        sums = sum(ratios)
        print('sums is ', sums)
        ratios = [val / float(sums) for val in ratios]
        return ratios