import os
import sys
import glob
import numpy as np
import pandas as pd
import SimpleITK as sitk
from tqdm import tqdm

from form_process import *
from image_process import *

class ClsPatchCrop(object):
    """Crop fixed-size classification patches from CT volumes.

    Pipeline (run via ``__call__``):
      1. ``_GeneratePath``  - create output folders for uid lists and patches.
      2. ``_SplitDataset``  - split the pooled annotation table into
         train/val/test by series uid and persist the uid lists.
      3. ``_GenerateData``  - crop one patch per annotation (in parallel) and
         save data/info arrays as ``.npy`` per split.

    All heavy lifting (patch extraction) is delegated to
    ``GeneratePatchFromDataFrame`` from the project's image_process module.
    """

    def __init__(self, **kwargs):
        # Pooled annotation table assembled from one or more dataframe paths.
        self.df = LoadDFs(kwargs.get('paths'))
        # Target isotropic spacing in mm; None keeps the raw spacing.
        self.spacing = kwargs.get('spacing', None)
        # Cubic patch edge length in voxels.
        self.patch_size = kwargs.get('patch_size', 96)
        # Base folder for uid txt lists; replaced by a per-mode path list
        # in _GeneratePath (kept under the same attribute name).
        self.uid_save_paths = kwargs.get('uid_save_paths')
        self.base_path = kwargs.get('base_path')
        self.folder_name = kwargs.get('folder_name')
        # Dataset mode names, e.g. ['train', 'val', 'test'].
        self.modes = kwargs.get('kws')
        self.split_ratios = kwargs.get('split_ratios')
        self.image_base_path = kwargs.get('image_base_path', '/fileser/CT_RIB/data/image/res0/')
        # Columns carrying image path and lesion center/diameter coordinates.
        self.coord_kws = kwargs.get('coord_kws', ['image_path', 'coordZ', 'coordY', 'coordX', 'diameter_z', 'diameter_y', 'diameter_x', 'diameter'])
        # Columns carrying the classification labels.
        self.type_kws = kwargs.get('type_kws', ['frac_type', 'poroma'])

        # Two column keywords forwarded to GeneratePatchFromDataFrame as
        # kw1/kw2 (presumably spacing-related columns - see that function).
        self.spacing_kws = kwargs.get('spacing_kws')
        print (' self.uid_save_paths ', self.uid_save_paths ,self.folder_name)

    def __call__(self):
        """Run the full crop pipeline: paths, split, then patch generation."""
        self._GeneratePath()
        self._SplitDataset()
        self._GenerateData()

    def _GeneratePath(self):
        """Create the uid-list folder and the patch-output folder, then build
        one uid txt path per dataset mode.

        Side effect: ``self.uid_save_paths`` changes meaning from a base folder
        to a list of per-mode txt files (kept for backward compatibility).
        """
        self.uid_path = os.path.join(self.uid_save_paths, self.folder_name)
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(self.uid_path, exist_ok=True)
        self.data_save_base_path = os.path.join(self.base_path, self.folder_name)
        os.makedirs(self.data_save_base_path, exist_ok=True)
        self.uid_save_paths = [os.path.join(self.uid_path, '%s_uid.txt' % val) for val in self.modes]

    def _imagePathReplace(self, df):
        """Redirect every image path in *df* to ``self.image_base_path``.

        When an 'image_path' column exists, only the directory part is
        replaced (filename kept); otherwise the path is derived from 'uid'
        as '<base>/<uid>.nii.gz'. Mutates *df* in place and returns it.
        """
        map_dict = {}
        if 'image_path' in df.columns:
            for path in df['image_path'].drop_duplicates().values:
                filename = path.split('/')[-1]
                map_dict[path] = '%s/%s' % (self.image_base_path, filename)
            df['image_path'] = df['image_path'].replace(map_dict)
        else:
            for uid in df['uid'].drop_duplicates().values:
                map_dict[uid] = '%s/%s.nii.gz' % (self.image_base_path, uid)
            df['image_path'] = df['uid'].replace(map_dict)
        return df

    def _GenerateSingleRecord(self, df, uid, id=None):
        """Crop all patches for one series (*uid*).

        ``id`` is an optional worker/task index printed for progress tracking
        when running under multiprocessing (name kept for backward compat).
        Returns ``(data_list, info_list)``; ``data_list`` may be None when the
        series yields no valid patch.
        """
        if id is not None:
            print(id)
        self._imagePathReplace(df)
        pad_val = -1024.0  # HU of air: pads voxels that fall outside the volume
        output_dim = [self.patch_size for _ in range(3)]
        spacing = [self.spacing for _ in range(3)]
        result = GeneratePatchFromDataFrame(uid, df, output_dim, self.coord_kws, self.type_kws, pad_val,
                                            spacing=spacing, kw1=self.spacing_kws[0], kw2=self.spacing_kws[1])
        data_list, info_list = result[:2]
        return data_list, info_list

    @staticmethod
    def _AppendRecord(final_data_list, final_info_list, data_list, info_list):
        """Append one record's patches/infos to the running accumulators.

        The accumulators start as the int sentinel ``0`` before the first
        successful record. May raise (e.g. on shape mismatch); callers catch
        and skip that record.
        """
        if isinstance(final_info_list, int):
            return np.array(data_list), np.array(info_list)
        return (np.concatenate([final_data_list, data_list]),
                np.concatenate([final_info_list, info_list]))

    def _GenerateDataNoMulti(self, df):
        """Sequential fallback: crop patches uid by uid in this process.

        Returns ``(data, info)`` arrays, or ``(0, 0)`` when nothing was kept.
        """
        final_data_list, final_info_list = 0, 0
        for uid in tqdm(df['uid'].drop_duplicates().values):
            # reset_index on the slice result (not inplace on a view) avoids
            # pandas SettingWithCopyWarning; 'index' column kept as before.
            current_df = df[df['uid'] == uid].reset_index()
            data_list, info_list = self._GenerateSingleRecord(current_df, uid)
            if data_list is None:
                continue
            try:
                final_data_list, final_info_list = self._AppendRecord(
                    final_data_list, final_info_list, data_list, info_list)
            except Exception as e:
                # Best effort: one malformed record must not abort the run.
                print('skip uid %s: %s' % (uid, e))
                continue
        return final_data_list, final_info_list

    def _GenerateDataWithMulti(self, df):
        """Crop patches for all uids with a 30-worker process pool.

        Returns ``(data, info)`` arrays, or ``(0, 0)`` when nothing was kept.
        """
        from multiprocessing import Pool
        pool = Pool(30)
        pool_returns = []
        for task_id, uid in enumerate(df['uid'].drop_duplicates().values):
            current_df = df[df['uid'] == uid].reset_index()
            pool_returns.append(pool.apply_async(self._GenerateSingleRecord, (current_df, uid, task_id)))
        pool.close()
        pool.join()
        final_data_list, final_info_list = 0, 0
        for pool_return in tqdm(pool_returns):
            data_list, info_list = pool_return.get()
            if data_list is None:
                continue
            try:
                final_data_list, final_info_list = self._AppendRecord(
                    final_data_list, final_info_list, data_list, info_list)
            except Exception as e:
                # Best effort: one malformed record must not abort the run.
                print('skip record: %s' % e)
                continue
        return final_data_list, final_info_list

    def _GenerateData(self):
        """Crop and save data/info ``.npy`` arrays for each available split."""
        for idx in range(3):
            df = self.dfs[idx]
            if df is None:
                continue
            spacing = 'raw' if self.spacing is None else self.spacing
            img_path = '%s/%s_spacing_data_%s.npy' % (self.data_save_base_path, str(spacing), self.modes[idx])
            info_path = '%s/%s_spacing_info_%s.npy' % (self.data_save_base_path, str(spacing), self.modes[idx])
            data_list, info_list = self._GenerateDataWithMulti(df)
            np.save(img_path, data_list)
            np.save(info_path, info_list)

    def _SplitDataset(self):
        """Split uids per ``self.split_ratios``, persist each uid list to txt,
        and slice ``self.df`` into ``self.dfs = [train, val, test]``.

        val/test entries may be None when the ratios produce no such split.
        """
        uids = self.df['uid'].drop_duplicates().values

        uid_list = split_dataset(uids, self.split_ratios)
        for current_uids in uid_list:
            if current_uids is not None:
                print ('='*60)
                print ('length of current_uids %d'%(len(current_uids)))
                print ('='*60)

        for idx in range(len(uid_list)):
            if uid_list[idx] is not None:
                WriteTxt(uid_list[idx], self.uid_save_paths[idx])

        self.train_uids, self.val_uids, self.test_uids = uid_list
        self.train_df = self.df[self.df['uid'].isin(self.train_uids)]
        # test is only materialized when val exists too (matches split order).
        if self.val_uids is not None:
            self.val_df = self.df[self.df['uid'].isin(self.val_uids)]
            self.test_df = (self.df[self.df['uid'].isin(self.test_uids)]
                            if self.test_uids is not None else None)
        else:
            self.val_df, self.test_df = None, None
        self.dfs = [self.train_df, self.val_df, self.test_df]


if __name__=="__main__":
    import ast

    # Pipeline parameters live in a json file; some list-valued fields are
    # stored as Python-literal strings and must be parsed.
    json_path = "./files/NoduleCls/ForDataPrep/prep/20210107Test/para.json"
    params = ReadJson(json_path)
    # SECURITY: parse the list literals with ast.literal_eval instead of
    # eval() - same result for list/tuple literals, but no arbitrary code
    # execution if the config file is tampered with.
    for str_kw in ['kws','coord_kws','type_kws','spacing_kws']:
        params[str_kw] = ast.literal_eval(params[str_kw])
    # csv_path points at a second json listing the annotation dataframes.
    data_path = params['csv_path']
    data_json = ReadJson(data_path)
    paths = [record['df_path'] for record in data_json]
    params['paths'] = paths

    print ('params is ',params)

    patch_cropper = ClsPatchCrop(**params)
    patch_cropper()