Commit 329f1f6c authored by exp

Classification training

import numpy as np
#from matplotlib_inline import backend_inline
from matplotlib import pyplot as plt
import os
import torch
class Accumulator:
"""在n个变量上累加"""
def __init__(self, n):
self.data = [0.0] * n
def add(self, *args):
self.data = [a + float(b) for a, b in zip(self.data, args)]
def reset(self):
self.data = [0.0] * len(self.data)
def __getitem__(self, idx):
return self.data[idx]
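# A minimal usage sketch (not part of the original code) for the Accumulator
# class above: keep a running loss sum and sample count, then read them back by index.
# metric = Accumulator(2)
# metric.add(0.7 * 32, 32)        # summed batch loss, batch size
# metric.add(0.5 * 32, 32)
# mean_loss = metric[0] / metric[1]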
class Animator:
def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
ylim=None, xscale='linear', yscale='linear',
fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
figsize=(5, 5)):
if legend is None:
legend = []
#Only here to switch the figure format to SVG for nicer display in Jupyter
#self.use_svg_display()
self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
#print(self.axes)
if nrows * ncols == 1:
self.axes = [self.axes, ]
print('success')
print(self.axes)
#Use a lambda to capture the axis-configuration arguments
print(self.axes[0])
self.config_axes = lambda: self.set_axes(
axes=self.axes[0], xlabel=xlabel, ylabel=ylabel,
xlim=xlim, ylim=ylim, xscale=xscale, yscale=yscale, legend=legend)
self.X, self.Y, self.fmts = None, None, fmts
'''
def use_svg_display(self):
"""
Display plots in SVG format in Jupyter.
"""
backend_inline.set_matplotlib_formats('svg')
'''
def set_axes(self, axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
#print(axes)
axes.set_xlabel(xlabel), axes.set_ylabel(ylabel)
axes.set_xscale(xscale), axes.set_yscale(yscale)
axes.set_xlim(xlim), axes.set_ylim(ylim)
if legend:
axes.legend(legend)
axes.grid()
def add(self, x, y):
#Add multiple data points to the figure
#hasattr(object, name) checks whether the object has the given attribute
if not hasattr(y, '__len__'):
y = [y]
n = len(y)
if not hasattr(x, '__len__'):
x = [x] * n
if not self.X:
self.X = [[] for _ in range(n)]
if not self.Y:
self.Y = [[] for _ in range(n)]
for i, (a, b) in enumerate(zip(x, y)):
if a is not None and b is not None:
self.X[i].append(a)
self.Y[i].append(b)
#Clear all drawing elements from the axes object
self.axes[0].cla()
for x, y, fmt in zip(self.X, self.Y, self.fmts):
print(x,' ', y)
self.axes[0].plot(x, y, fmt)
self.config_axes()
print('updated once')
#self.fig.canvas.draw()
#self.fig.canvas.flush_events()
#plt.draw()
plt.pause(2)
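# A minimal usage sketch (not part of the original code) for the Animator class
# above; the values are made up and only illustrate how add() refreshes one curve.
def _example_animator():
    animator = Animator(xlabel='epoch', ylabel='loss', legend=['train loss'],
                        xlim=[1, 10], ylim=[0, 1])
    for epoch in range(1, 11):
        animator.add(epoch, 1.0 / epoch)   # one scalar per call -> a single curve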
#Load the data from an .npy file
def load_npy_to_data(input_file=None):
original_data = np.load(input_file)
return original_data
#Display a fixed .npy array as an image
def show_img():
input_file = r'/home/lung/project/ai-project/cls_train/data/train_data/plus_3d_0818/npy_data/cls_2047/495.npy'
data = load_npy_to_data(input_file)
#print(data)
data = data[29].astype(np.float32)
print(data)
plt.imshow(data, cmap='gray')
plt.show()
#The slice at index 24 is the central plane
def show_img_2d():
input_file = r'/home/lung/ai-project/cls_train/data/train_data/plus_0512/npy_data/cls_2021/1893_10.npy'
data = load_npy_to_data(input_file)
data = data.astype(np.float32)
plt.imshow(data, cmap='gray')
plt.show()
#Test data
def run():
test = [[0.1, 0.2], [0.3, 0.6]]
thred = 0.5
test = torch.tensor(test)
result = test > thred
print(result.float())
#print(test > thred)
def test_sum():
test = [[[1, 1, 2], [2, 3, 4]],
[[2, 0, 0], [9, 9, 9]]]
test = np.array(test)
print(np.sum(test, axis=0))
#Get the indices of the entries of a multi-dimensional array that are greater than a specified value
def test():
mask = [[[ True, True, True],
[False, False, False],
[ True, True, True]],
[[ True, False, True],
[False, False, False],
[ True, True, True]]]
mask = np.array(mask)
indices = np.asarray(np.where(mask == 1))
print(indices)
print(indices.min(axis=1))
print(indices.max(axis=1))
#Subtract two images and save the resulting difference image for inspection
def check_image():
for index in range(22, 37):
data_1 = load_npy_to_data(f'/home/lung/project/ai-project/cls_train/log/npy/01/{index}.npy')
data_2 = load_npy_to_data(f'/home/lung/project/ai-project/cls_train/log/npy/03/{index}.npy')
#data = np.where(data_1 == data_2, data_1, 0)
data = data_2 - data_1
plt.imsave(f'/home/lung/project/ai-project/cls_train/log/image/temp_02/{index}.png', data, cmap='gray')
if __name__ == '__main__':
#test_sum()
#check_image()
show_img()
import random
import itertools
import numpy as np
def generate_general_indexs():
indexes = [[0, 2],
[1, 3],
[2, 0],
[3, 1],
[4, 6],
[5, 7],
[6, 4],
[7, 5]]
return indexes
def generate_general_permute_keys():
rotate_y = 0
transpose = 0
keys = [((rotate_y, 0), 0, 0, 0, transpose),
((rotate_y, 1), 0, 0, 0, transpose),
((rotate_y, 0), 0, 1, 1, transpose),
((rotate_y, 1), 0, 1, 1, transpose),
((rotate_y, 0), 1, 0, 0, transpose),
((rotate_y, 1), 1, 0, 0, transpose),
((rotate_y, 0), 1, 1, 1, transpose),
((rotate_y, 1), 1, 1, 1, transpose)]
return keys
def random_permute_key():
"""
Generates and randomly selects a permute key.
"""
return random.choice(list(generate_general_permute_keys()))
def generate_permutation_keys():
"""
This function returns a set of "keys" that represent the 48 unique rotations &
reflections of a 3D matrix.
Each item of the set is a tuple:
((rotate_y, rotate_z), flip_x, flip_y, flip_z, transpose)
As an example, ((0, 1), 0, 1, 0, 1) represents a permutation in which the data is
rotated 90 degrees around the z-axis, then reversed on the y-axis, and then
transposed.
48 unique rotations & reflections:
https://en.wikipedia.org/wiki/Octahedral_symmetry#The_isometries_of_the_cube
"""
return set(itertools.product(
itertools.combinations_with_replacement(range(2), 2), range(2), range(2), range(2), range(2)))
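# Sanity-check sketch (not part of the original code) for generate_permutation_keys
# above: the 3 (rotate_y, rotate_z) combinations times the four binary flags give
# 3 * 2**4 = 48 keys, matching the 48 octahedral symmetries cited in the docstring.
# >>> len(generate_permutation_keys())
# 48
# >>> ((0, 1), 0, 1, 0, 1) in generate_permutation_keys()
# True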
def random_permutation_key():
"""
Generates and randomly selects a permutation key. See the documentation for the
"generate_permutation_keys" function.
"""
return random.choice(list(generate_permutation_keys()))
def augment_data(data):
padded_data = -1000 * np.ones((256,256,256))
padded_data[104:152, :, :] = data
data = padded_data.transpose(1, 2, 0)
return data
def permute_data(data, key):
"""
Permutes the given data according to the specification of the given key. Input data
must be of shape (n_modalities, x, y, z).
Input key is a tuple: ((rotate_y, rotate_z), flip_x, flip_y, flip_z, transpose)
As an example, ((0, 1), 0, 1, 0, 1) represents a permutation in which the data is
rotated 90 degrees around the z-axis, then reversed on the y-axis, and then
transposed.
"""
data = np.copy(data)
#print(data.shape)
(rotate_y, rotate_z), flip_x, flip_y, flip_z, transpose = key
if rotate_y != 0:
data = np.rot90(data, rotate_y, axes=(1, 3))
if rotate_z != 0:
data = np.rot90(data, rotate_z, axes=(2, 3))
if flip_x:
data = data[:, ::-1]
if flip_y:
data = data[:, :, ::-1]
if flip_z:
data = data[:, :, :, ::-1]
if transpose:
for i in range(data.shape[0]):
data[i] = data[i].T
return data
def random_permutation_data(data):
key = random_permute_key()
return permute_data(data, key)
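# A small usage sketch (not part of the original code) for permute_data /
# random_permutation_data above: the input must be shaped (n_modalities, x, y, z)
# and, for the rotations to preserve the shape, cube-shaped in its spatial axes.
def _example_permute():
    volume = np.random.rand(1, 8, 8, 8)          # (modalities, x, y, z)
    key = ((0, 1), 0, 1, 0, 1)                   # rotate 90 deg around z, flip y, transpose
    permuted = permute_data(volume, key)
    assert permuted.shape == volume.shape
    return random_permutation_data(volume)       # random key from the reduced key set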
import os
import sys
import json
import numpy as np
import pandas as pd
import glob
import torch
import collections
import random
from torchvision import transforms
from matplotlib import pyplot as plt
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
from cls_utils.utils import check_and_makedirs
from cls_utils.augement import random_permutation_data, augment_data
from net.component_c import get_data
#Read the JSON file at the given json_path
def load_json(json_path):
with open(json_path) as f:
cfg = json.load(f)
return cfg
#Append new rows to the specified csv file
def save_supplement_data_csv(content, output_file):
data = {}
data['npy_path'] = [content[0]]
data['label'] = [content[1]]
if os.path.exists(output_file) is False:
check_and_makedirs(output_file=output_file)
data = pd.DataFrame(data)
#mode='a' appends; index controls the row index, header the column header
data.to_csv(output_file, mode='a', index=False, header=False)
#Save a numpy array to an .npy file
def save_data_to_npy(original_data=None, output_file=None):
check_and_makedirs(output_file)
np.save(output_file, original_data)
print('File', output_file, 'saved successfully')
#Load the data from an .npy file
def load_npy_to_data(input_file=None):
original_data = np.load(input_file)
return original_data
#Read the specified csv file
def load_data_csv(input_file):
ids = pd.read_csv(input_file, header=None, encoding='utf-8')
return ids
#Save data to a csv file
def save_data_csv(content, output_file):
data = pd.DataFrame(content)
data.to_csv(output_file, index=False, header=None, encoding='utf-8')
def save_data_npy(npy_path, file_prefix, data=None, truth=None, affine=None, patch_index=None):
check_and_makedirs(os.path.join(npy_path, file_prefix))
if data is not None:
np.save(os.path.join(npy_path, file_prefix + '_data.npy'),
data.astype(np.float32))
if truth is not None:
np.save(os.path.join(npy_path, file_prefix + '_truth.npy'),
truth.astype(np.uint8))
if affine is not None:
np.save(os.path.join(npy_path, file_prefix + '_affine.npy'),
affine.astype(np.float32))
if patch_index is not None:
np.save(os.path.join(npy_path, file_prefix + '_patch_index.npy'),
patch_index.astype(np.int16))
def load_all_dicom_file(dirname, prefix='' ,postfix=''):
file_path = os.path.join(dirname, prefix + '.' + postfix)
all_file_path = glob.glob(file_path)
return all_file_path
#Collect all label_ids from a given csv file
def train_all_label_id(csv_path):
data = load_data_csv(csv_path)
all_labels = data[0].tolist()
all_labels = [label.split('/')[1] for label in all_labels]
train_all_labels = [int(label.split('.')[0]) for label in all_labels]
return train_all_labels
#Look up the class for a given label id
def find_lable_by_id(lable_id, subject_all_df):
row = subject_all_df[subject_all_df['lable_id'] == lable_id]
lable = row.iloc[0, 1]
file_path = 'cls_' + str(lable) + '/' + lable_id + '.npy'
return file_path
#Add misclassified samples directly to the training set and generate a new (class-imbalanced) csv file; positive=True means positive samples are added
def add_label_ids(label_ids, csv_path, positive=True, cls_name=''):
add_lable = 1 if positive else 0
original_csv_path = os.path.join(csv_path, '08', cls_name, 'train.csv')
subject_all_csv_path = os.path.join(csv_path, 'subject_all.csv')
print(original_csv_path)
data = load_data_csv(original_csv_path)
subject_all_df = pd.read_csv(subject_all_csv_path, header=None, names=['lable_id', 'lable'])
#Split the string in the first column to extract the lable_id
def get_lable_id(path_str):
lable_file = path_str.split('/')[1]
lable_id = lable_file.split('.')[0]
return lable_id
subject_all_df['lable_id'] = subject_all_df['lable_id'].apply(get_lable_id)
for label_id in label_ids:
print(label_id)
result = find_lable_by_id(label_id, subject_all_df)
file_paths = [find_lable_by_id(label_id, subject_all_df) for label_id in label_ids]
#find_lable_by_id(label_ids[0], subject_all_df)
#label_ids = ['cls_2047/'+str(label_id)+'.npy' for label_id in label_ids]
node1_data = data[data[1] == 1]
node2_data = data[data[1] == 0]
add_data = pd.DataFrame(file_paths)
add_data[1] = add_lable
node_all_data = pd.concat([node1_data, node2_data, add_data])
"""node1_name = '2047'
node2_name = '1016'
cls_unite_csv_path = os.path.join(csv_path,
'cls_' + node1_name + '_' + node2_name,
'train.csv')
check_and_makedirs(cls_unite_csv_path)"""
save_data_csv(node_all_data, original_csv_path)
#Replace the label_ids of the corresponding class in a given csv file with the ids in the list
def replace_label_ids(label_ids, csv_path, tabel_id):
original_csv_path = os.path.join(csv_path, 'cls_2047_1016_1', 'train.csv')
print(original_csv_path)
data = load_data_csv(original_csv_path)
label_ids = ['cls_2047/'+str(label_id)+'.npy' for label_id in label_ids]
node1_data = data[data[1] == 1]
node2_data = data[data[1] == 0]
node1_data.iloc[:len(label_ids), 0] = label_ids
node_all_data = pd.concat([node1_data, node2_data])
node1_name = '2047'
node2_name = '1016'
cls_unite_csv_path = os.path.join(csv_path,
'cls_' + node1_name + '_' + node2_name,
'train.csv')
check_and_makedirs(cls_unite_csv_path)
save_data_csv(node_all_data, cls_unite_csv_path)
def create_cls_train_last_3d(node_times, tabel_id=None, csv_path=None, csv_name=None, pretrain_csv_path=''):
#Read the csv file from the previous training run
pretrain_data = load_data_csv(pretrain_csv_path)
pretrain_data_labels = pretrain_data[0].tolist()
subject_all_path = os.path.join(csv_path, csv_name)
data = load_data_csv(subject_all_path)
node1_data = data[data[1].isin(node_times[0])]
node1_data = node1_data[~node1_data[0].isin(pretrain_data_labels)]
node2_data = data[data[1].isin(node_times[1])]
node1_data.loc[:, 1] = 1
node2_data.loc[:, 1] = 0
node_all_data = pd.concat([node1_data, node2_data])
node1_name = '1'
#node1_name = str(node_times[0][0]) if len(node_times[0])==1 else '-'.join([str(time) for time in node_times[0]])
node2_name = str(node_times[1][0]) if len(node_times[1])==1 else '-'.join([str(time) for time in node_times[1]])
cls_unite_csv_path = os.path.join(csv_path, tabel_id,
'cls_' + node1_name + '_' + node2_name,
'train.csv')
check_and_makedirs(cls_unite_csv_path)
save_data_csv(node_all_data, cls_unite_csv_path)
def create_cls_train_csv_3d(node_times, tabel_id=None, csv_path=None, csv_name=None, max_len=None):
"""如果max_len不为None,则正负两个类别都取同样个数(max_len)的数据"""
subject_all_path = os.path.join(csv_path, csv_name)
data = load_data_csv(subject_all_path)
node1_data = pd.DataFrame()
for node_time in node_times[0]:
node_data = data[data[1] == node_time]
node_data = node_data[:(len(node_data) * 3)//4]
node1_data = pd.concat([node1_data, node_data])
node2_data = pd.DataFrame()
for node_time in node_times[1]:
node_data = data[data[1] == node_time]
node_data = node_data[:(len(node_data) * 3)//4]
node2_data = pd.concat([node2_data, node_data])
"""node1_data = data[data[1].isin(node_times[0])]
node2_data = data[data[1].isin(node_times[1])]"""
node1_data.loc[:, 1] = 1
node2_data.loc[:, 1] = 0
# node_all_data = pd.concat([node1_data, node2_data])
node_all_data = node1_data
node1_name = '20241112'
#node1_name = str(node_times[0][0]) if len(node_times[0])==1 else '-'.join([str(time) for time in node_times[0]])
node2_name = str(node_times[1][0]) if len(node_times[1])==1 else '-'.join([str(time) for time in node_times[1]])
cls_unite_csv_path = os.path.join(csv_path, tabel_id,
'cls_' + node1_name + '_' + node2_name,
'train.csv')
check_and_makedirs(cls_unite_csv_path)
save_data_csv(node_all_data, cls_unite_csv_path)
print(f"save_data_csv: {cls_unite_csv_path}")
def create_cls_train_all_csv(node_times, tabel_id=None, csv_path=None, csv_name=None):
subject_all_path = os.path.join(csv_path, csv_name)
data = load_data_csv(subject_all_path)
node1_data = data[data[1].isin(node_times[0])]
node2_data = data[data[1].isin(node_times[1])]
node1_data.loc[:, 1] = 1
node2_data.loc[:, 1] = 0
node_all_data = pd.concat([node1_data, node2_data])
node1_name = str(node_times[0][0]) if len(node_times[0])==1 else '-'.join([str(time) for time in node_times[0]])
node2_name = str(node_times[1][0]) if len(node_times[1])==1 else '-'.join([str(time) for time in node_times[1]])
cls_unite_csv_path = os.path.join(csv_path, tabel_id,
'cls_' + node1_name + '_' + node2_name,
'train.csv')
check_and_makedirs(cls_unite_csv_path)
#node_all_data.to_csv(cls_unite_csv_path, index=False, header=None, encoding='utf-8')
save_data_csv(node_all_data, cls_unite_csv_path)
def create_cls_train_csv(node_times, node2=None, csv_path=None, csv_name=None, tabel_id='', node1_end=False, node2_end=False, min=0):
"""
从subject_all.csv文件内找出指定类别的数据,生成数据集对应的csv文件
Parameters:
node_times: node_time[0]表示要进行二分类的第一个类别,除此以外都为另一个类别
"""
#Build the path to the subject_all.csv file
subject_all_path = os.path.join(csv_path, csv_name)
data = load_data_csv(subject_all_path)
data = data.sample(frac=1, replace=False)
""" indices = data.index.to_numpy()
indices = np.random.shuffle(indices)
data = data.reindex(indices).reset_index(drop=True)"""
#Use the class with fewer samples as the reference for the amount of training data
node1_data = data[data[1].isin(node_times[0])]
node2_data = data[data[1].isin(node_times[1])]
indices1, indices2 = node1_data.index.to_numpy(), node2_data.index.to_numpy()
np.random.shuffle(indices1)
np.random.shuffle(indices2)
node1_data = node1_data.reindex(indices1).reset_index(drop=True)
node2_data = node2_data.reindex(indices2).reset_index(drop=True)
if len(node1_data) < len(node2_data):
min = len(node1_data)
max_len = len(node2_data)
min_is_first = True
else:
min = len(node2_data)
max_len = len(node1_data)
min_is_first = False
#Take a fixed amount of data per split
for i in range(int(max_len/min) + 1):
only_one_data = node1_data if min_is_first else node2_data
if i == int(max_len/min) + 1:
second_data = node1_data[i*min:] if not min_is_first else node2_data[i*min:]
else:
second_data = node1_data[i*min:(i+1)*min] if not min_is_first else node2_data[i*min:(i+1)*min]
only_one_data.loc[:, 1] = 1 if min_is_first else 0
second_data.loc[:, 1] = 0 if min_is_first else 1
node_all_data = pd.concat([only_one_data, second_data])
#Save the data
node1_name = str(node_times[0][0]) if len(node_times[0])==1 else '-'.join([str(time) for time in node_times[0]])
node2_name = str(node_times[1][0]) if len(node_times[1])==1 else '-'.join([str(time) for time in node_times[1]])
cls_unite_csv_path = os.path.join(csv_path, tabel_id,
'cls_' + node1_name + '_' + node2_name,
'cls_' + node1_name + '_' + node2_name + '_' + str(i+1),
'train.csv')
check_and_makedirs(cls_unite_csv_path)
#node_all_data.to_csv(cls_unite_csv_path, index=False, header=None, encoding='utf-8')
save_data_csv(node_all_data, cls_unite_csv_path)
"""node1_data = node1_data.iloc[9*min :9*min+min]
#node2_data = node2_data.head(min) if node2_end is False else node2_data.tail(min)
node1_data = node1_data.head(min) if node1_end is False else node1_data.tail(min)
#data = data[data[1].isin(node_times)]
print('Total number of samples:', len(node1_data)+len(node2_data))
#Randomly sample half of the data
#data = data.sample(frac=0.5)
#print('Number of samples after sampling: ', len(data))
#Treat the first class in node_times as the positive class
node1_data.loc[:, 1] = 1
node2_data.loc[:, 1] = 0
node_all_data = pd.concat([node1_data, node2_data])
print("正类个数:", len(node1_data))
print("负类个数:", len(node2_data))
#将数据保存
if len(node_times) == 2:
node2 = node_times[1]
cls_unite_csv_path = os.path.join(csv_path,
'cls_' + str(node_times[0]) + '_' + str(node2),
'cls_' + str(node_times[0]) + '_' + str(node2),
'train.csv')
check_and_makedirs(cls_unite_csv_path)
#node_all_data.to_csv(cls_unite_csv_path, index=False, header=None, encoding='utf-8')
save_data_csv(node_all_data, cls_unite_csv_path)"""
def get_data_from_file(npy_path, subject_id):
"""
Load a single file from the given path
"""
subject_npy_path = os.path.join(npy_path, str(subject_id))
#print(subject_npy_path)
data = load_npy_to_data(subject_npy_path)
return data
def get_data_and_label(npy_path, subject_id, is_2d=False, augment=False, permute=False):
data = get_data_from_file(npy_path, subject_id)
data = torch.from_numpy(data)
#Randomly translate the data in-plane
lower_bound, upper_bound = -10, 10
shift_x, shift_y = random.randint(lower_bound, upper_bound), random.randint(lower_bound, upper_bound)
new_array = np.full(data.shape, -1000, dtype=float)
start_x, start_y = max(0, shift_x), max(0, shift_y)
end_x, end_y = min(256, 256 + shift_x), min(256, 256 + shift_y)
#print(start_x, start_y, end_x, end_y)
new_array[:, start_x:end_x, start_y:end_y] = data[:, max(0, -shift_x):min(256, 256-shift_x), max(0, -shift_y):min(256, 256-shift_y)]
data = new_array
#print(data.shape)
if is_2d:
data = torch.from_numpy(data).unsqueeze(0)  #convert back to a tensor so the torchvision transforms below can be applied
result_data = data
#Apply random flips and small affine jitter here
transform_data = transforms.Compose([transforms.RandomVerticalFlip(),
transforms.RandomHorizontalFlip(),
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), fill=-1000),
transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), fill=-1000)])
for _ in range(7):
transforms_new_data = transform_data(data)
result_data = torch.cat((result_data, transforms_new_data), dim=0)
return result_data.numpy()
else:
if augment:
#data = augment_data(data)
data = data[np.newaxis]
if permute:
"""if data.shape[-3] != data.shape[-2] or data.shape[-2] != data.shape[-1]:
raise ValueError('To utilize permutations, data array must be in 3d cube shape with all dimentions having '
'the same length')"""
data = random_permutation_data(data)
#data = data.transpose(0, 3, 1, 2)
#data = data[:, 104:152, :, :]
#print(data.shape)
return data
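# A standalone sketch (not part of the original code) of the translate-with-padding
# augmentation used in get_data_and_label above; the -1000 fill value (air in HU)
# and the function name are illustrative only.
def _example_random_shift(volume, max_shift=10, fill=-1000):
    """Randomly translate a (D, H, W) numpy volume in-plane, padding with `fill`."""
    d, h, w = volume.shape
    sx = random.randint(-max_shift, max_shift)
    sy = random.randint(-max_shift, max_shift)
    out = np.full_like(volume, fill)
    out[:, max(0, sx):min(h, h + sx), max(0, sy):min(w, w + sy)] = \
        volume[:, max(0, -sx):min(h, h - sx), max(0, -sy):min(w, w - sy)]
    return out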
def save_model_ckpt(ckpt_path, ckpt_name, new_ckpt_name, model):
"""
Extract the parameters saved from a previously trained model into a new parameter file.
Every element of model_names must have a corresponding key in the parameter file.
"""
pretrain_ckpt_path = os.path.join(ckpt_path, ckpt_name)
new_ckpt_path = os.path.join(ckpt_path, new_ckpt_name)
#Load the ckpt file
model_param = torch.load(pretrain_ckpt_path)
new_model_param = dict()
#Create a collections.OrderedDict() to hold the new parameters
state_dict = collections.OrderedDict()
for name, param in model.named_parameters():
if name not in model_param['state_dict']:
raise ValueError('Ensure that the model structure is consistent with the parameter file model structure')
state_dict[name] = model_param['state_dict'][name]
#Write out the result
new_model_param['state_dict'] = state_dict
#print(new_model_param['state_dict']['diff_classifier.weight'])
torch.save(new_model_param, new_ckpt_path)
#After training, render the loss curves as an image under the specified path
def save_summary_data(summary_trains=None, summary_valids=None, result_img_path=None):
#Make sure the target folder exists first
check_and_makedirs(result_img_path)
plt.figure()
plt.xlim(1, len(summary_trains))
plt.ylim(0, max(summary_trains))
plt.xlabel('epoch')
plt.ylabel('loss')
x = np.arange(1, len(summary_trains)+1, 1)
plt.plot(x, summary_trains, 'b--', label='train loss')
if summary_valids != None:
plt.plot(x, summary_valids, 'r-', label='valid loss')
plt.legend()
plt.savefig(result_img_path, dpi=120)
#plt.show()
#This function is not used
#Rename the parameters of a trained model to the corresponding names (strip the 'module.' prefix)
def save_new_ckpt(ckpt_path, ckpt_name, new_ckpt_name):
pretrain_ckpt_path = os.path.join(ckpt_path, ckpt_name)
new_ckpt_path = os.path.join(ckpt_path, new_ckpt_name)
#Load the parameter file to be processed
model_param = torch.load(pretrain_ckpt_path)
new_model_param = dict()
#Create a collections.OrderedDict() to hold the new parameters
state_dict = collections.OrderedDict()
for key in model_param['state_dict'].keys():
#Rename the key
name = key.replace('module.', '')
state_dict[name] = model_param['state_dict'][key]
print(name)
new_model_param['state_dict'] = state_dict
torch.save(new_model_param, new_ckpt_path)
#Read a csv file and collect the label_ids whose label is 1 or 0
def get_csv_all_label_ids_bylabel(csv_path, node_times, label=1):
node_times = [str(node_time) for node_time in node_times]
all_datas = load_data_csv(csv_path)
all_datas_label = all_datas[all_datas[1] == label][0].tolist()
all_nodes_list = [x.split('/')[0].split('_')[1] for x in all_datas_label]
all_label_ids_list = [x.split('/')[1].split('.')[0] for x in all_datas_label]
result_label_ids = []
for index in range(len(all_nodes_list)):
if all_nodes_list[index] in node_times:
result_label_ids.append(int(all_label_ids_list[index]))
return result_label_ids
#-------------------------------------------------------------------
#Test functions
def test_load_data():
npy_path = './cls_train/data/train_data/plus_0815/npy_data'
subject_id = 'cls_1016/2268_49.npy'
data = get_data_and_label(npy_path=npy_path, subject_id=subject_id)
print(data.shape)
#Process a trained parameter file: keep the parameters the model needs and drop the rest
def test_save_ckpt():
ckpt_path = './cls_train/best_cls'
ckpt_name = 'train_test.ckpt'
new_ckpt_name = 'test.ckpt'
#save_model_ckpt(ckpt_path=ckpt_path, ckpt_name=ckpt_name, new_ckpt_name=new_ckpt_name, model=model)
save_new_ckpt(ckpt_path=ckpt_path, ckpt_name=ckpt_name, new_ckpt_name=new_ckpt_name)
#Read a ckpt file and inspect its parameters
def read_ckpt():
ckpt_path = './cls_train/best_cls'
ckpt_name = 'train_test.ckpt'
pretrain_ckpt_path = os.path.join(ckpt_path, ckpt_name)
model_param = torch.load(pretrain_ckpt_path)
#print(type(model_param['state_dict']))
for key in model_param['state_dict'].keys():
print(model_param['state_dict'][key].shape)
#print(model_param['state_dict']['diff_classifier.weight'])
#Delete the rows with the given node_time from subject_all.csv
def delete_node_csv(node_time):
csv_path = os.path.join("./cls_train/data/train_data/plus_0617", "subject_all_csv", "subject_all.csv")
data = load_data_csv(csv_path)
data = data[data[1] != node_time]
save_data_csv(data, csv_path)
if __name__ == '__main__':
#delete_node_csv(node_time=2046)
"""csv_path = '/home/lung/project/ai-project/cls_train/data/train_data/plus_3d_0818/subject_all_csv/test/cls_1_5001-6001/train.csv'
train_all_id = train_all_label_id(csv_path)
print(train_all_id)"""
#test_load_data()
"""npy_path = '/home/lung/project/ai-project/cls_train/data/train_data/plus_3d_0818/npy_data'
subject_id = 'cls_2047/432.npy'
data = get_data_and_label(npy_path, subject_id, is_2d=False, augment=True, permute=True)
data = data[0, 29].astype(np.float32)
print(data)
plt.imshow(data, cmap='gray')
plt.show()"""
csv_path = '/home/lung/project/ai-project/cls_train/data/train_data/plus_3d_0818/subject_all_csv/08/cls_234567_1016/train.csv'
node_times = [2041]
label = 1
result = get_csv_all_label_ids_bylabel(csv_path, node_times, label)
print(result)
import os
import numpy as np
"""
Describe:
根据传过来的中心点对数据进行切割
Parameters:
data: 所有的切面数据
crop_start: 在切面中选择出来的大小为(48, 256, 256)区域的三个维度的中心点的坐标
"""
def get_crop_data(data, crop_start, crop_size, mode='constant', constant_values=-1000):
if data is None:
return data
data_shape = np.array(data.shape)
crop_data = data[
max(crop_start[0], 0):min(crop_start[0] + crop_size[0], data_shape[0]),
max(crop_start[1], 0):min(crop_start[1] + crop_size[1], data_shape[1]),
max(crop_start[2], 0):min(crop_start[2] + crop_size[2], data_shape[2])]
# pad = np.zeros((3, 2), np.int)
pad = np.zeros((3, 2), np.int32)
#np.maximum(x1, x2) compares the elements pairwise and returns the larger of each pair
#how much padding is needed before the start of each axis
pad[:, 0] = np.maximum(-crop_start, 0)
pad[:, 1] = np.maximum(crop_start + crop_size - data_shape, 0)
if not np.all(pad == 0):
#pad holds the amount of padding for each dimension
if mode == 'edge':
crop_data = np.pad(crop_data, pad, mode=mode)
else:
crop_data = np.pad(crop_data, pad, mode=mode, constant_values=constant_values)
return crop_data
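# A small usage sketch (not part of the original code) for get_crop_data above:
# cropping partly outside the volume pads the missing voxels with -1000 (air),
# so the result always has exactly crop_size voxels per axis.
def _example_get_crop_data():
    volume = np.random.randint(-1000, 400, size=(10, 10, 10)).astype(np.int16)
    crop = get_crop_data(volume,
                         crop_start=np.array([-2, 3, 3]),   # starts two slices above the volume
                         crop_size=np.array([6, 6, 6]))
    assert crop.shape == (6, 6, 6)
    assert np.all(crop[:2] == -1000)                        # the two padded z-slices
    return crop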
def get_crop_data_padding_2d_opt(ct_data=None, select_box=None, crop_size=None, mode='constant', constant_values=-1000):
# Get the raw data
data = ct_data.get_raw_image()
data_shape = np.array(data.shape)
crop_size = np.array(crop_size)
crop_start = np.int0((select_box[:, 0] + select_box[:, 1])//2 - crop_size // 2)
crop_data = get_crop_data(data=data,
crop_start=crop_start,
crop_size=crop_size,
mode=mode,
constant_values=constant_values)
return crop_data
"""
只保留一层,将该层数据放置中心点,别的层都填充为-1000
crop_size=[48, 256, 256]
"""
def get_crop_data_padding(ct_data=None, select_box=None, crop_size=None, mode='constant', constant_values=-1000):
data = ct_data.get_raw_image()
data_shape = np.array(data.shape)
crop_size = np.array(crop_size)
z_index = int(select_box[0][0])
crop_start = np.int0((select_box[1:, 0] + select_box[1:, 1])//2 - crop_size[1:] // 2)
crop_data = data[z_index:z_index+1,
max(crop_start[0], 0) : min(crop_start[0] + crop_size[1], data_shape[1]),
max(crop_start[1], 0) : min(crop_start[1] + crop_size[2], data_shape[2])]
# pad = np.zeros((3, 2), np.int)
pad = np.zeros((3, 2), np.int32)
pad[1:, 0] = np.maximum(-crop_start, 0)
pad[1:, 1] = np.maximum(crop_start + crop_size[1:] - data_shape[1:], 0)
if not np.all(pad == 0):
crop_data = np.pad(crop_data, pad, mode=mode, constant_values=constant_values)
result_data = np.full(crop_size, -1000)
result_data[crop_size[0]//2] = crop_data
return result_data
#Extract the region at the given coordinates from the current (1024, 1024) image; if the region is smaller than (256, 256), pad the missing part with -1000
#select_box=[[z_index, z_index], [y_min, y_max], [x_min, x_max]]
#crop_size=[256, 256]
def get_crop_data_2d(data=None, select_box=None, crop_size=None, mode='constant', constant_values=-1000):
data_shape = np.array(data.shape)
crop_size = np.array(crop_size)
crop_start = np.int0((select_box[1:, 0] + select_box[1:, 1])//2 - crop_size[:]//2)
crop_data = data[max(crop_start[0], 0) : min(crop_start[0] + crop_size[0], data_shape[0]),
max(crop_start[1], 0) : min(crop_start[1] + crop_size[1], data_shape[1])]
# pad = np.zeros((2, 2), np.int)
pad = np.zeros((2, 2), np.int32)
pad[:, 0] = np.maximum(-crop_start, 0)
pad[:, 1] = np.maximum(crop_start + crop_size[:] - data_shape[:], 0)
if not np.all(pad == 0):
crop_data = np.pad(crop_data, pad, mode=mode, constant_values=constant_values)
"""result_data = np.full(crop_size, -1000)
result_data[crop_size[0]//2] = crop_data"""
result_data = crop_data
return result_data
'''
#Process the slice data further and return data of size (48, 256, 256)
def crop_ct_data(ct_data=None, select_box=None, crop_size=None):
data = ct_data.get_raw_image()
#print(data.shape)
spacing = ct_data.get_raw_spacing()
z_center = int((select_box[0, 0] + select_box[0, 1]) // 2)
crop_size = np.array(crop_size)
#The operations below keep the original box contained and keep z_center in the middle
#z_diff=crop_size[0] - (select_box[0, 1] - select_box[0, 0])
z_diff = crop_size[0] - (select_box[0, 1] - select_box[0, 0]) - 1
#The number of selected slices is less than 48
if z_diff > 0:
#Number of slices to add in front
add_front_z = 0
#Number of slices to add behind
add_behind_z = 0
#Number of slices between the middle slice of select_box and z_min
z_front_num = z_center - select_box[0, 0]
#Number of slices between the middle slice of select_box and z_max
z_behind_num = select_box[0, 1] - z_center
#Here add_front_z is either 1 or 0
add_front_z = min(z_behind_num - z_front_num, z_diff)
#When z_diff != 1
if add_front_z < z_diff:
add_front_z = add_front_z + (z_diff - add_front_z) // 2
add_behind_z = z_diff - add_front_z
else:
add_front_z = add_front_z
add_behind_z = 0
select_box[0, 0] = select_box[0, 0] - add_front_z
select_box[0, 1] = select_box[0, 1] + add_behind_z
scale = (1, 1, 1)
#scale = (1, 1, 1) if spacing[0] > 0.5 else (2, 1, 1)
crop_size = np.int0(crop_size * scale)
crop_start = np.int0((select_box[:, 0] + select_box[:, 1])//2 - crop_size // 2)
#Crop the data according to the requested size
original_data = get_crop_data(data=data, crop_start=crop_start, crop_size=crop_size)
return original_data
'''
def crop_ct_data(ct_data=None, select_box=None, crop_size=None):
data = ct_data.get_raw_image()
#print(data.shape)
spacing = ct_data.get_raw_spacing()
z_center = int((select_box[0, 0] + select_box[0, 1]) // 2)
crop_size = np.array(crop_size)
#The operations below keep the original box contained and keep z_center in the middle
#z_diff=crop_size[0] - (select_box[0, 1] - select_box[0, 0])
z_diff = crop_size[0] - (select_box[0, 1] - select_box[0, 0])
#The number of selected slices is less than 48
if z_diff > 0:
#Number of slices to add in front
add_front_z = 0
#Number of slices to add behind
add_behind_z = 0
#Number of slices between the middle slice of select_box and z_min
z_front_num = z_center - select_box[0, 0]
#Number of slices between the middle slice of select_box and z_max
z_behind_num = select_box[0, 1] - z_center
if z_behind_num >= z_front_num > 0:
add_front_z = min(z_behind_num - z_front_num, z_diff)
if add_front_z < z_diff:
add_front_z = add_front_z + (z_diff - add_front_z) // 2
add_behind_z = z_diff - add_front_z
elif z_front_num > z_behind_num > 0:
add_behind_z = min(z_front_num - z_behind_num, z_diff)
if add_behind_z < z_diff:
add_behind_z = add_behind_z + (z_diff - add_behind_z) // 2
add_front_z = z_diff - add_behind_z
select_box[0, 0] = select_box[0, 0] - add_front_z
select_box[0, 1] = select_box[0, 1] + add_behind_z
scale = (1, 1, 1)
#scale = (1, 1, 1) if spacing[0] > 0.5 else (2, 1, 1)
crop_size = np.int0(crop_size * scale)
crop_start = np.int0((select_box[:, 0] + select_box[:, 1])//2 - crop_size // 2)
#Crop the data according to the requested size
original_data = get_crop_data(data=data, crop_start=crop_start, crop_size=crop_size)
return original_data
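# Worked example (not part of the original code) of the z-padding arithmetic in
# crop_ct_data above: with crop_size[0]=48 and a select_box z-range of [10, 20],
# z_center=15 and z_diff = 48 - (20 - 10) = 38. Since z_front_num == z_behind_num == 5,
# add_front_z = min(0, 38) = 0, then add_front_z += 38 // 2 = 19 and
# add_behind_z = 38 - 19 = 19, so the z-range becomes [-9, 39]; get_crop_data
# later pads the out-of-range slices with -1000.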
#For automatically segmented data, take every slice in turn as the center plane, pad the rest to (48, 256, 256), and run prediction
import pathlib
import sys
import os
current_file = pathlib.Path(__file__).resolve()
project_root = current_file.parent
project_dir_name = 'cls_train'
while project_root.name != project_dir_name and project_root != project_root.parent:
project_root = project_root.parent
if project_root.name != project_dir_name:
raise Exception(f"没有找到项目路径: {project_dir_name}")
sys.path.append(str(project_root))
import argparse
import os
import sys
from pathlib import Path
from loguru import logger  # import loguru
def get_logger(log_file, rotation="500 MB", compression="zip", enqueue=True):
log_format = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <4} | [{module}].[{function}]:{line} - {message}"
logger.remove()
logger.add(log_file, format=log_format, rotation=rotation, compression=compression, enqueue=enqueue)
return logger
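# Minimal usage sketch (not part of the original code) for get_logger above;
# the log path is illustrative.
# log = get_logger('./log/train.log')
# log.info('training started')
# Because logger.remove() is called first, only the file sink configured here
# receives messages; enqueue=True (the default above) keeps it safe across processes.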
import pathlib
import sys
def setup_project_path(project_name='cls_train'):
cur_path = pathlib.Path(__file__).parent.resolve()
while cur_path.name != project_name and cur_path != cur_path.parent:
cur_path = cur_path.parent
if cur_path.name != project_name:
raise ValueError(f"找不到项目根目录: {project_name}")
sys.path.append(str(cur_path))
return cur_path
import os
import logging
import time
import numpy as np
import scipy.ndimage
import SimpleITK as sitk
import pydicom
from cls_utils.utils import check_and_makedirs
from collections import namedtuple
NoduleBox = namedtuple('NoduleBox', ['z', 'y', 'x', 'diameter', 'uid', 'dicom_path'])
def get_cts(bundle, dicom_path):
cts = CTSeries()
start_time = time.time()
#Read the DICOM files
cts.load_dicoms(dicom_path)
print('start load dicoms time {:.5f}(s)'.format(time.time() - start_time))
return cts
def get_nodule_y_pixel_length(spacing, nodule_box):
return int(np.ceil(nodule_box.diameter / spacing[1]))
def get_diameter_pixel_length(spacing, diameter):
length = np.zeros(3, np.int32)
for i in range(3):
length[i] = int(np.ceil(diameter / spacing[i]))
return length
def get_nodule_rect(spacing, nodule_box):
diameter_pixel_length = get_diameter_pixel_length(spacing, nodule_box.diameter)
center = np.array([nodule_box.z, nodule_box.y, nodule_box.x])
rect = np.zeros((3, 2), np.int32)
for i in range(3):
rect[i][0] = center[i] - diameter_pixel_length[i] // 2
rect[i][1] = rect[i][0] + diameter_pixel_length[i]
return rect
def resample(data, spacing, new_spacing=[1.0, 1.0, 1.0], order=1):
if data is None:
return None
new_shape = np.round(data.shape * spacing / new_spacing)
resample_spacing = spacing * data.shape / new_shape
resize_factor = new_shape / data.shape
new_data = scipy.ndimage.interpolation.zoom(data, resize_factor, mode='nearest', order=order)
return new_data, resample_spacing
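# Worked example (not part of the original code) for resample above: a volume of
# shape (100, 512, 512) with spacing [2.5, 0.7, 0.7] mm resampled to [1.0, 1.0, 1.0] mm
# gives new_shape = round([250, 358.4, 358.4]) = [250, 358, 358] and
# resample_spacing = spacing * shape / new_shape ~= [1.0, 1.001, 1.001] mm.
# Note: scipy.ndimage.interpolation.zoom is the legacy import path; on recent SciPy
# releases the equivalent call is scipy.ndimage.zoom.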
def resample_nodule_box(nodule_box, spacing, new_spacing=[1.0, 1.0, 1.0]):
if nodule_box is None:
return None
z = int(np.ceil(nodule_box.z * spacing[0] / new_spacing[0]))
y = int(np.ceil(nodule_box.y * spacing[1] / new_spacing[1]))
x = int(np.ceil(nodule_box.x * spacing[2] / new_spacing[2]))
new_box = NoduleBox(int(z), int(y), int(x), nodule_box.diameter, nodule_box.uid, nodule_box.dicom_path)
return new_box
def maxip(input, level_num=2):
output = np.zeros(input.shape, input.dtype)
for i in range(len(input)):
length = level_num if i >= level_num else i
output[i] = np.max(input[i-length:i+level_num+1], axis=0)
return output
def minip(input, level_num=2):
output = np.zeros(input.shape, input.dtype)
for i in range(len(input)):
length = level_num if i >= level_num else i
output[i] = np.min(input[i-length:i+level_num+1], axis=0)
return output
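# Behavior sketch (not part of the original code) for maxip/minip above: with
# level_num=2, output slice i is the voxel-wise max (or min) over input slices
# [i-2, i+2], clipped at the volume borders, i.e. a sliding 5-slice MIP/MinIP along z.
# vol = np.random.randint(-1000, 400, size=(30, 64, 64)).astype(np.int16)
# mip = maxip(vol, level_num=2)      # mip.shape == vol.shape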
class CTSeries(object):
def __init__(self):
self._SeriesInstanceUID = None
self._SOPInstanceUIDs = None
self._raw_image = None
self._raw_origin = None
self._raw_spacing = None
self._raw_direction = None
self._dicoms_is_loaded = False
def load_dicoms(self, folder_path):
logging.info('{}, Loading dicoms from {}...'.format(
time.strftime('%Y-%m-%d %H:%M:%S'), folder_path))
#print(folder_path)
dicom_names = [f for f in os.listdir(folder_path) if '.xml' not in f]
dicom_paths = list(map(lambda x: os.path.join(folder_path, x), dicom_names))
dicoms = list(map(lambda x: pydicom.read_file(x), dicom_paths))
# slice_locations = list(map(lambda x: float(x.SliceLocation), dicoms))
# sort slices by their z coordinates from large to small
# idx_z_sorted = np.argsort(slice_locations)[::-1]
try:
slice_locations = list(map(lambda x: float(x.ImagePositionPatient[2]), dicoms))
except AttributeError:
try:
slice_locations = list(map(lambda x: float(x.SliceLocation), dicoms))
except AttributeError:
slice_locations = []
for i in range(len(dicoms)):
try:
slice_locations.append(float(dicoms[i].ImagePositionPatient[2]))
except AttributeError:
print(i, dicoms[i].SeriesInstanceUID)
patient_position = dicoms[0].PatientPosition
self._SeriesInstanceUID = dicoms[0].SeriesInstanceUID
if patient_position in ['FFP', 'FFS']:
idx_z_sorted = np.argsort(slice_locations)[::-1]
else:
idx_z_sorted = np.argsort(slice_locations)[::-1]
#Reorder the DICOM files according to idx_z_sorted
dicoms = list(map(lambda x: dicoms[x], idx_z_sorted))
self._SOPInstanceUIDs = np.array(list(map(lambda x: x.SOPInstanceUID, dicoms)))
#dicom_path_before = np.array(dicom_paths)
#print(dicom_path_before)
dicom_paths = np.array(dicom_paths)[idx_z_sorted]
#print(dicom_paths)
#print(self._SOPInstanceUIDs)
#dicoms = np.array(dicoms)[idx_z_sorted]
reader = sitk.ImageSeriesReader()
reader.SetFileNames(dicom_paths)
image_itk = reader.Execute()
# all in [z, y, x] order
self._raw_image = sitk.GetArrayFromImage(image_itk)
self._raw_origin = np.array(list(reversed(image_itk.GetOrigin())))
self._raw_spacing = np.array(list(reversed(image_itk.GetSpacing())))
self._raw_direction = image_itk.GetDirection()
# print('raw_image', self._raw_image.shape, 'raw_spacing', self._raw_spacing)
# print('raw_origin', self._raw_origin, 'raw_direction', self._raw_direction)
self._dicoms_is_loaded = True
def load_single_file(self, file_path):
logging.info('{}, Loading file from {}...'.format(
time.strftime('%Y-%m-%d %H:%M:%S'), file_path))
image_itk = sitk.ReadImage(file_path)
# all in [z, y, x] order
self._raw_image = sitk.GetArrayFromImage(image_itk)
self._raw_origin = np.array(list(reversed(image_itk.GetOrigin())))
self._raw_spacing = np.array(list(reversed(image_itk.GetSpacing())))
self._raw_direction = image_itk.GetDirection()
# print('raw_image', self._raw_image.shape, 'raw_spacing', self._raw_spacing)
# print('raw_origin', self._raw_origin, 'raw_direction', self._raw_direction)
self._dicoms_is_loaded = True
def get_raw_image(self):
return self._raw_image
def set_raw_image(self, data):
self._raw_image = data
def get_raw_origin(self):
return self._raw_origin
def get_raw_spacing(self):
return self._raw_spacing
def get_raw_image_affine(self):
affine = np.diag(list(self._raw_spacing) + [1])
affine[:, 3][:3] = np.array(list(self._raw_origin))
return affine
def save_raw_image(self, output_file):
if self._dicoms_is_loaded:
self.save_image(self._raw_image, output_file)
def save_image(self, data, output_file):
if self._dicoms_is_loaded:
image = sitk.GetImageFromArray(data)
image.SetSpacing(np.array(list(reversed(self._raw_spacing))))
image.SetOrigin(np.array(list(reversed(self._raw_origin))))
image.SetDirection(self._raw_direction)
check_and_makedirs(output_file)
sitk.WriteImage(image, output_file)
def transform_file_type(self, input_file, output_file):
image = sitk.ReadImage(input_file)
sitk.WriteImage(image, output_file)
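# A minimal usage sketch (not part of the original code) for the CTSeries class
# above; the paths are illustrative.
# cts = CTSeries()
# cts.load_dicoms('/data/dicom/series_001')       # or cts.load_single_file('scan.mhd')
# volume = cts.get_raw_image()                    # numpy array in [z, y, x] order
# spacing = cts.get_raw_spacing()                 # [z, y, x] spacing in mm
# affine = cts.get_raw_image_affine()             # 4x4 affine built from spacing/origin
# cts.save_raw_image('/tmp/series_001.nii.gz')    # re-export through SimpleITK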
# -*- coding: utf-8 -*-
import os
import time
import logging
import numpy as np
import SimpleITK as sitk
import pydicom
import traceback
from concurrent.futures import ThreadPoolExecutor
from multiprocessing.pool import Pool
from testing.utils.box_utils import NoduleBox, nodule_raw2standard
from testing.utils.data_utils import check_and_makedirs
from testing.utils.data_utils import clip_data, downsample_data, upsample_mask, resample_data, resample_mask
def lung_segment(data_npy):
def morphology_opening_2d(mask, data_npy, pool_size=5):
mask_npy = np.zeros(data_npy.shape, data_npy.dtype)
chest_mask_npy = np.zeros(data_npy.shape, data_npy.dtype)
def morphology_opening_2d_task(i, mask_2d):
mask_slice = sitk.BinaryMorphologicalOpening(mask_2d, 2)
mask_npy[i] = sitk.GetArrayFromImage(mask_slice)
chest_mask_npy[i] = sitk.GetArrayFromImage(sitk.BinaryFillhole(mask_slice))
with ThreadPoolExecutor(pool_size) as executor:
for i in range(data_npy.shape[0]):
executor.submit(morphology_opening_2d_task, i, mask[:, :, i])
return sitk.GetImageFromArray(mask_npy), sitk.GetImageFromArray(chest_mask_npy)
def morphology_closing_2d(mask, data_npy, pool_size=5):
mask_npy = np.zeros(data_npy.shape, data_npy.dtype)
def morphology_closing_2d_task(i, mask_2d):
mask_slice = sitk.BinaryMorphologicalClosing(mask_2d, 15)
mask_slice = sitk.BinaryDilate(mask_slice, 15)
mask_npy[i] = sitk.GetArrayFromImage(sitk.BinaryFillhole(mask_slice))
mask_npy[i] = np.bitwise_or(mask_npy[i], mask_npy[i][:, ::-1])
with ThreadPoolExecutor(pool_size) as executor:
for i in range(data_npy.shape[0]):
executor.submit(morphology_closing_2d_task, i, mask[:, :, i])
return sitk.GetImageFromArray(mask_npy)
data_npy = data_npy.astype(np.int32)
data_npy = clip_data(data_npy)
data = sitk.GetImageFromArray(data_npy)
mask = 1 - sitk.OtsuThreshold(data)
mask, chest_mask = morphology_opening_2d(mask, data_npy)
lung_mask = sitk.Subtract(chest_mask, mask)
# Remove areas not in the chest, when CT covers regions below the chest
eroded_mask = sitk.BinaryErode(lung_mask, 10)
seed_npy = sitk.GetArrayFromImage(eroded_mask)
seed_npy = np.array(seed_npy.nonzero())[[2, 1, 0]]
seeds = seed_npy.T.tolist()
lung_mask = sitk.ConfidenceConnected(lung_mask, seeds, multiplier=2.5)
lung_mask = morphology_closing_2d(lung_mask, data_npy)
return sitk.GetArrayFromImage(lung_mask)
def lung_segment_enhance(data):
z_max = max(int(np.ceil(data.shape[0] / 100)), 2)
if data.shape[1] <= 512 and data.shape[2] <= 512:
scale = (z_max, 2, 2)
else:
scale = (z_max, 4, 4)
new_data_shape = np.array(data.shape) // scale * scale
new_data = data[:new_data_shape[0], :new_data_shape[1], :new_data_shape[2]]
new_data = downsample_data(new_data, scale)
new_data = new_data.astype(data.dtype)
new_mask = lung_segment(new_data)
mask = upsample_mask(new_mask, scale)
pad = np.zeros((3, 2), np.int32)
pad[:, 1] = np.array(data.shape) - np.array(mask.shape)
mask = np.pad(mask, pad, mode='edge')
return mask
def get_lung_mask_and_box(data, uid, segment=False, segment_data=False, segment_margin=[0, 0, 0], min_points=10000):
lung_box = np.zeros((3, 2), np.int32)
lung_box[:, 1] = data.shape
found_lung = False
if segment:
try:
start_time = time.time()
lung_mask = lung_segment_enhance(data)
logging.info('{}, {}, Lung segment run time {:.2f}(s)'.format(
time.strftime("%Y-%m-%d %H:%M:%S"), uid, (time.time() - start_time)))
if np.sum(lung_mask == 1) > min_points:
found_lung = True
if found_lung and segment_data:
segment_margin = np.array(segment_margin)
coords = np.asarray(np.where(lung_mask == 1))
lung_box[:, 0] = np.maximum(coords.min(axis=1) - segment_margin, 0)
lung_box[:, 1] = np.minimum(coords.max(axis=1) + segment_margin + 1, data.shape)
data = data[lung_box[0, 0]:lung_box[0, 1],
lung_box[1, 0]:lung_box[1, 1],
lung_box[2, 0]:lung_box[2, 1]]
lung_mask = lung_mask[lung_box[0, 0]:lung_box[0, 1],
lung_box[1, 0]:lung_box[1, 1],
lung_box[2, 0]:lung_box[2, 1]]
except Exception as e:
traceback.print_exc()
if not found_lung:
lung_mask = np.ones(data.shape, np.uint8)
return data, lung_mask, lung_box
def transform_file_type(input_file, output_file):
image = sitk.ReadImage(input_file)
sitk.WriteImage(image, output_file)
def load_single_dicom(input_file):
image = sitk.ReadImage(input_file)
slice = sitk.GetArrayFromImage(image)[0, :, :]
return slice
class CTSeries(object):
def __init__(self):
self._PatientID = None
self._SeriesInstanceUID = None
self._SOPInstanceUIDs = None
self._ReconstructionDiameter = None
self._Rows = None
self._Columns = None
self._AcquisitionDate = None
self._Manufacturer = None
self._InstitutionName = None
self._raw_data = None
self._raw_spacing = None
self._raw_origin = None
self._raw_direction = None
self._dicoms_is_loaded = False
self._lung_mask = None
self._lung_box = None
self._standard_data = None
self._standard_spacing = None
self._dicoms_is_preprocessed = False
self._raw_labels = None
self._standard_labels = None
self._label_is_loaded = False
self._standard_is_loaded = False
def load_dicoms(self, dicom_dir_path):
logging.info('{}, Loading dicoms from {}...'.format(time.strftime('%Y-%m-%d %H:%M:%S'), dicom_dir_path))
dicom_names = [f for f in os.listdir(dicom_dir_path) if '.xml' not in f]
dicom_paths = list(map(lambda x: os.path.join(dicom_dir_path, x), dicom_names))
dicoms = list(map(lambda x: pydicom.read_file(x, stop_before_pixels=True), dicom_paths))
try:
slice_locations = list(map(lambda x: float(x.ImagePositionPatient[2]), dicoms))
except AttributeError:
slice_locations = list(map(lambda x: float(x.SliceLocation), dicoms))
# sort slices by their z coordinates from large to small
if dicoms[0].get('PatientPosition') is None:
patient_position = 'HFS'
else:
patient_position = dicoms[0].PatientPosition
if patient_position in ['FFP', 'FFS']:
idx_z_sorted = np.argsort(slice_locations)[::-1]
else:
idx_z_sorted = np.argsort(slice_locations)[::-1]
dicom_paths = np.asarray(dicom_paths)[idx_z_sorted]
self._SeriesInstanceUID = dicoms[0].SeriesInstanceUID
self._SOPInstanceUIDs = np.array(list(map(lambda x: x.SOPInstanceUID, dicoms)))[idx_z_sorted]
try:
self._PatientID = dicoms[0].PatientID
self._ReconstructionDiameter = dicoms[0].ReconstructionDiameter
self._Rows = dicoms[0].Rows
self._Columns = dicoms[0].Columns
self._AcquisitionDate = dicoms[0].AcquisitionDate
self._Manufacturer = dicoms[0].Manufacturer
self._InstitutionName = dicoms[0].InstitutionName
except AttributeError:
self._PatientID = None
self._ReconstructionDiameter = None
self._Rows = None
self._Columns = None
self._AcquisitionDate = None
self._Manufacturer = None
self._InstitutionName = None
reader = sitk.ImageSeriesReader()
reader.SetFileNames(dicom_paths)
image_itk = reader.Execute()
# all in [z, y, x] order
self._raw_data = sitk.GetArrayFromImage(image_itk)
self._raw_spacing = np.array(list(reversed(image_itk.GetSpacing())))
self._raw_origin = np.array(list(reversed(image_itk.GetOrigin())))
self._raw_direction = image_itk.GetDirection()
self._dicoms_is_loaded = True
def load_dicoms_mp(self, dicom_dir_path):
logging.info('{}, Loading dicoms from {}...'.format(time.strftime('%Y-%m-%d %H:%M:%S'), dicom_dir_path))
dicom_names = [f for f in os.listdir(dicom_dir_path) if '.xml' not in f]
dicom_paths = list(map(lambda x: os.path.join(dicom_dir_path, x), dicom_names))
dicoms = list(map(lambda x: pydicom.read_file(x, stop_before_pixels=True), dicom_paths))
try:
slice_locations = list(map(lambda x: float(x.ImagePositionPatient[2]), dicoms))
except AttributeError:
slice_locations = list(map(lambda x: float(x.SliceLocation), dicoms))
# sort slices by their z coordinates from large to small
if dicoms[0].get('PatientPosition') is None:
patient_position = 'HFS'
else:
patient_position = dicoms[0].PatientPosition
if patient_position in ['FFP', 'FFS']:
idx_z_sorted = np.argsort(slice_locations)[::-1]
else:
idx_z_sorted = np.argsort(slice_locations)[::-1]
dicom_paths = np.asarray(dicom_paths)[idx_z_sorted]
self._SeriesInstanceUID = dicoms[0].SeriesInstanceUID
self._SOPInstanceUIDs = np.array(list(map(lambda x: x.SOPInstanceUID, dicoms)))[idx_z_sorted]
try:
self._PatientID = dicoms[0].PatientID
self._ReconstructionDiameter = dicoms[0].ReconstructionDiameter
self._Rows = dicoms[0].Rows
self._Columns = dicoms[0].Columns
self._AcquisitionDate = dicoms[0].AcquisitionDate
self._Manufacturer = dicoms[0].Manufacturer
self._InstitutionName = dicoms[0].InstitutionName
except AttributeError:
self._PatientID = None
self._ReconstructionDiameter = None
self._Rows = None
self._Columns = None
self._AcquisitionDate = None
self._Manufacturer = None
self._InstitutionName = None
pool = Pool(processes=20)
slice_list = pool.map(load_single_dicom, dicom_paths)
pool.close()
pool.join()
raw_data = np.zeros((len(dicom_paths), dicoms[0].Rows, dicoms[0].Columns))
for idx, slice_data in enumerate(slice_list):
raw_data[idx] = slice_data
image0 = sitk.ReadImage(dicom_paths[0])
dicom0 = pydicom.read_file(dicom_paths[0], stop_before_pixels=True)
# all in [z, y, x] order
self._raw_data = raw_data
self._raw_spacing = np.array(list(reversed(image0.GetSpacing())))
self._raw_origin = np.array(list(reversed(dicom0.ImagePositionPatient)))
self._raw_direction = image0.GetDirection()
self._dicoms_is_loaded = True
def load_single_file(self, file_path, uid=None):
logging.info('{}, Loading file from {}...'.format(time.strftime('%Y-%m-%d %H:%M:%S'), file_path))
self._SeriesInstanceUID = uid
if self._SeriesInstanceUID is None:
self._SeriesInstanceUID = os.path.splitext(os.path.basename(file_path))[0]
image_itk = sitk.ReadImage(file_path)
# all in [z, y, x] order
self._raw_data = sitk.GetArrayFromImage(image_itk)
self._raw_spacing = np.array(list(reversed(image_itk.GetSpacing())))
self._raw_origin = np.array(list(reversed(image_itk.GetOrigin())))
self._raw_direction = image_itk.GetDirection()
self._dicoms_is_loaded = True
def load_file(self, file_path, uid=None):
if os.path.splitext(os.path.basename(file_path))[-1] in ['.mhd']:
self.load_single_file(file_path, uid)
else:
self.load_dicoms(file_path)
def is_hrct(self):
"""
Whether this is a high-definition (HRCT) scan
"""
if self._Rows is not None and self._Rows >= 1024:
return True
return False
def is_ultra_hrct(self, max_reconstruction_diameter=250):
"""
Whether this is a high-definition targeted (small-FOV) scan
"""
if self._Rows is not None and self._Rows >= 1024 and \
self._ReconstructionDiameter is not None and \
self._ReconstructionDiameter < max_reconstruction_diameter:
return True
return False
def get_patient_id(self):
return self._PatientID
def get_series_instance_uid(self):
return self._SeriesInstanceUID
def get_reconstruction_diameter(self):
return self._ReconstructionDiameter
def get_rows(self):
return self._Rows
def get_columns(self):
return self._Columns
def get_acquisition_date(self):
return self._AcquisitionDate
def get_manufacturer(self):
return self._Manufacturer
def get_institution_name(self):
return self._InstitutionName
def get_raw_data(self):
return self._raw_data
def get_raw_spacing(self):
return self._raw_spacing
def get_raw_origin(self):
return self._raw_origin
def get_raw_direction(self):
return self._raw_direction
def get_lung_mask(self):
return self._lung_mask
def get_lung_box(self):
return self._lung_box
def get_standard_data(self):
return self._standard_data
def get_standard_spacing(self):
return self._standard_spacing
def get_raw_labels(self):
return self._raw_labels
def get_standard_labels(self):
return self._standard_labels
def save_raw_data(self, output_file):
if self._dicoms_is_loaded:
image = sitk.GetImageFromArray(self._raw_data)
image.SetSpacing(np.array(list(reversed(self._raw_spacing))))
image.SetOrigin(np.array(list(reversed(self._raw_origin))))
image.SetDirection(self._raw_direction)
check_and_makedirs(output_file, is_file=True)
sitk.WriteImage(image, output_file)
def world_to_voxel_coord(self, world_coord):
voxel_coord = np.absolute(world_coord - self._raw_origin) / self._raw_spacing
return voxel_coord
def voxel_to_world_coord(self, voxel_coord):
world_coord = voxel_coord * self._raw_spacing + self._raw_origin
return world_coord
def preprocess(self, segment=False, scale=(1, 1, 1)):
logging.info('{}, {}, Preprocessing ...'.format(
time.strftime("%Y-%m-%d %H:%M:%S"), self._SeriesInstanceUID))
standard_data = downsample_data(self._raw_data, scale)
standard_spacing = self._raw_spacing * np.array(scale)
standard_data, lung_mask, lung_box = \
get_lung_mask_and_box(standard_data, self.get_series_instance_uid(), segment=segment)
self._lung_mask = lung_mask
self._lung_box = lung_box
self._standard_data = standard_data
self._standard_spacing = standard_spacing
self._dicoms_is_preprocessed = True
def preprocess_1024u(self, check_spacing=False, segment=False, scale=(1, 1, 1)):
logging.info('{}, {}, Preprocessing 1024u...'.format(
time.strftime("%Y-%m-%d %H:%M:%S"), self._SeriesInstanceUID))
if check_spacing and \
(float(self._raw_spacing[0]) > 1.5 or self._ReconstructionDiameter is None or
float(self._ReconstructionDiameter) > 190):
logging.info('{}, {}, Preprocessing 1024u resample data...'.format(
time.strftime("%Y-%m-%d %H:%M:%S"), self._SeriesInstanceUID))
# If the z spacing is greater than 1.5 mm, or the FOV is missing or larger than 190 mm, resample to a z spacing of 1 mm with a 180 mm FOV
'''
Standard 1024-resolution CT
# original spacing: [1.25, 0.3515625, 0.3515625] mm
# original shape: [200, 1024, 1024]
# if z spacing > 1.5 mm:
#     resample the z axis to 1 mm
new_spacing = [1.0, 0.3515625, 0.3515625]
# otherwise keep the original spacing
# High-definition targeted 1024-resolution CT
original spacing: [1.80, 0.17578125, 0.17578125] mm
# original shape: [200, 1024, 1024]
# ReconstructionDiameter: 180 mm (small FOV, targeted scan)
# if any of the following conditions holds:
# - z spacing > 1.5 mm
# - ReconstructionDiameter is missing
# - ReconstructionDiameter > 190 mm
new_spacing = [1.0, 0.17578125, 0.17578125] # resample the z axis to 1 mm
# otherwise keep the original spacing
A standard 1024 CT has a pixel spacing of about 0.35 mm.
A high-definition targeted 1024 CT has a pixel spacing of about 0.18 mm, i.e. higher resolution.
High-definition targeted scans usually have a small FOV (<=190 mm) and are used for local fine-detail imaging.
In both cases the z axis is resampled to 1 mm when its spacing exceeds 1.5 mm.
High-definition targeted scans additionally check the FOV; above 190 mm resampling is triggered.
'''
new_spacing = np.array([1.0, 0.17578125, 0.17578125]) * np.array(scale)
standard_data, standard_spacing = resample_data(self._raw_data, self._raw_spacing, new_spacing)
else:
standard_data = downsample_data(self._raw_data, scale)
standard_spacing = self._raw_spacing * np.array(scale)
standard_data, lung_mask, lung_box = \
get_lung_mask_and_box(standard_data, self.get_series_instance_uid(), segment=segment)
self._lung_mask = lung_mask
self._lung_box = lung_box
self._standard_data = standard_data
self._standard_spacing = standard_spacing
self._dicoms_is_preprocessed = True
def preprocess_1024(self, check_spacing=False, segment=False, scale=(1, 1, 1)):
logging.info('{}, {}, Preprocessing 1024...'.format(
time.strftime("%Y-%m-%d %H:%M:%S"), self._SeriesInstanceUID))
if check_spacing and \
float(self._raw_spacing[0]) > 1.5:
logging.info('{}, {}, Preprocessing 1024 resample data...'.format(
time.strftime("%Y-%m-%d %H:%M:%S"), self._SeriesInstanceUID))
# If the z spacing is greater than 1.5 mm, resample it to 1 mm
new_spacing = np.array([1.0, 0.3515625, 0.3515625]) * np.array(scale)
standard_data, standard_spacing = resample_data(self._raw_data, self._raw_spacing, new_spacing)
else:
standard_data = downsample_data(self._raw_data, scale)
standard_spacing = self._raw_spacing * np.array(scale)
standard_data, lung_mask, lung_box = \
get_lung_mask_and_box(standard_data, self.get_series_instance_uid(), segment=segment)
self._lung_mask = lung_mask
self._lung_box = lung_box
self._standard_data = standard_data
self._standard_spacing = standard_spacing
self._dicoms_is_preprocessed = True
def preprocess_512(self, check_spacing=False, segment=False, scale=(1, 1, 1)):
logging.info('{}, {}, Preprocessing 512...'.format(
time.strftime("%Y-%m-%d %H:%M:%S"), self._SeriesInstanceUID))
if check_spacing and \
float(self._raw_spacing[0]) > 1.5:
logging.info('{}, {}, Preprocessing 512 resample data...'.format(
time.strftime("%Y-%m-%d %H:%M:%S"), self._SeriesInstanceUID))
# If the z spacing is greater than 1.5 mm, resample it to 1 mm
new_spacing = np.array([1.0, 0.703125, 0.703125]) * np.array(scale)
standard_data, standard_spacing = resample_data(self._raw_data, self._raw_spacing, new_spacing)
else:
standard_data = downsample_data(self._raw_data, scale)
standard_spacing = self._raw_spacing * np.array(scale)
standard_data, lung_mask, lung_box = \
get_lung_mask_and_box(standard_data, self.get_series_instance_uid(), segment=segment)
self._lung_mask = lung_mask
self._lung_box = lung_box
self._standard_data = standard_data
self._standard_spacing = standard_spacing
self._dicoms_is_preprocessed = True
def load_labels(self, label_path):
"""
Load labels from label_path.
label_path: path to the label file, which is a csv with 5 fields:
[z, y, x, diameter, is_pos]
"""
if not self._dicoms_is_loaded:
raise Exception('DICOM files have not been loaded yet')
if not self._dicoms_is_preprocessed:
raise Exception('DICOM files have not been preprocessed yet')
logging.info('{}, {}, Loading labels from {}...'.format(
time.strftime("%Y-%m-%d %H:%M:%S"), self._SeriesInstanceUID, label_path))
nodule_boxes = []
with open(label_path) as f:
for line in f:
z, y, x, diameter, is_pos = \
line.strip('\n').replace('"', '').split(',')[0:5]
nodule_boxes.append(
NoduleBox(float(z), float(y), float(x), float(diameter), float(is_pos), 1.0))
self._raw_labels = nodule_boxes
self._standard_labels = nodule_raw2standard(
self._raw_labels, self._raw_spacing, self._standard_spacing, start=self._lung_box[:, 0])
self._label_is_loaded = True
def load_nodule_boxes(self, nodule_boxes):
"""
Load labels from nodule_boxes.
"""
if not self._dicoms_is_loaded:
raise Exception('DICOM files have not been loaded yet')
if not self._dicoms_is_preprocessed:
raise Exception('DICOM files have not been preprocessed yet')
self._raw_labels = nodule_boxes
self._standard_labels = nodule_raw2standard(
self._raw_labels, self._raw_spacing, self._standard_spacing, start=self._lung_box[:, 0])
self._label_is_loaded = True
def save_standard_npy(self, npy_path, uid):
"""
Save *_standard data in numpy format.
npy_path: str, path to save.
uid: str, prefix for the files.
"""
if not self._dicoms_is_preprocessed:
raise Exception('DICOM files have not been preprocessed yet')
check_and_makedirs(npy_path)
np.save(os.path.join(npy_path, uid + '_standard_data.npy'),
self._standard_data.astype(np.float32))
np.save(os.path.join(npy_path, uid + '_standard_spacing.npy'),
self._standard_spacing.astype(np.float32))
# Check whether any label is out of bounds
standard_data_shape = np.array(self._standard_data.shape)
standard_labels = []
for i in range(len(self._standard_labels)):
nodule_box = self._standard_labels[i]
if not (0 <= nodule_box.z <= standard_data_shape[0] - 1 and
0 <= nodule_box.y <= standard_data_shape[1] - 1 and
0 <= nodule_box.x <= standard_data_shape[2] - 1):
logging.error('{}, {}, label index={} error.'.format(time.strftime("%Y-%m-%d %H:%M:%S"), uid, i))
standard_labels.append(np.array(self._standard_labels[i]))
standard_labels = np.array(standard_labels)
np.save(os.path.join(npy_path, uid + '_standard_labels.npy'),
standard_labels.astype(np.float32))
def load_standard_npy(self, npy_path, uid, mmap_mode='r'):
"""
Load *_standard data in numpy format.
npy_path: str, path to load.
uid: str, prefix for the files.
"""
if self._dicoms_is_preprocessed:
raise Exception('DICOM files have already been preprocessed')
self._SeriesInstanceUID = uid
self._standard_data = np.load(
os.path.join(npy_path, uid + '_standard_data.npy'), mmap_mode=mmap_mode)
self._standard_spacing = np.load(
os.path.join(npy_path, uid + '_standard_spacing.npy'))
self._standard_labels = []
standard_labels = np.load(
os.path.join(npy_path, uid + '_standard_labels.npy'))
for i in range(len(standard_labels)):
z, y, x, diameter, is_pos = standard_labels[i][0:5]
self._standard_labels.append(
NoduleBox(z, y, x, diameter, is_pos, 1.0))
self._standard_is_loaded = True
def save_standard_mask_npy(self, npy_path, uid, nodule_index, mask):
"""
        Save the nodule mask as npy (resampled and cropped to the lung box if needed).
"""
check_and_makedirs(npy_path)
standard_mask = mask
if np.any(self._standard_spacing != self._raw_spacing):
standard_mask = resample_mask(mask, self._standard_data.shape)
lung_box = self._lung_box
standard_mask = standard_mask[lung_box[0, 0]:lung_box[0, 1],
lung_box[1, 0]:lung_box[1, 1],
lung_box[2, 0]:lung_box[2, 1]]
np.save(os.path.join(npy_path, uid + '_' + str(nodule_index) + '_standard_mask.npy'),
standard_mask.astype(np.uint8))
def load_standard_mask_npy(self, npy_path, uid, nodule_index, mmap_mode='r'):
"""
        Load the nodule mask npy.
"""
standard_mask = np.load(
os.path.join(npy_path, uid + '_' + str(nodule_index) + '_standard_mask.npy'), mmap_mode=mmap_mode)
return standard_mask
import os
import time
from binascii import b2a_hex
import cv2
import numpy as np
#import nibabel as nib
import base64
from Crypto.Cipher import AES
def encrypt(text):
    key = '1234561234561234'
    # AES needs bytes for the key, IV and plaintext; the plaintext is
    # zero-padded to a multiple of the key length (assumes ASCII text)
    cryptor = AES.new(key.encode('utf-8'), AES.MODE_CBC, key.encode('utf-8'))
    count = len(text)
    add = len(key) - (count % len(key))
    text = text + ('\0' * add)
    ciphertext = cryptor.encrypt(text.encode('utf-8'))
    return b2a_hex(ciphertext)
def decrypt_oralce(text):
key = '123456'
aes = AES.new(add_to_16(key), AES.MODE_ECB)
base64_decrypted = base64.decodebytes(text.encode(encoding='utf-8'))
decrypted_text = str(aes.decrypt(base64_decrypted), encoding='utf-8').replace('\0', '')
return decrypted_text
def add_to_16(value):
while len(value) % 16 != 0:
value += '\0'
    return value.encode('utf-8')  # return bytes
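# A minimal usage sketch for the AES helpers above (illustrative only; it uses
# the hard-coded demo key and assumes pycryptodome provides Crypto.Cipher.AES):
if __name__ == '__main__':
    # add_to_16 pads with '\0' up to a 16-byte multiple before building a key
    print(len(add_to_16('123456')))  # -> 16
    # encrypt zero-pads the plaintext and returns hex-encoded AES-CBC ciphertext
    print(encrypt('hello world'))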
def index_to_list(indexs):
    if indexs:
        str_list = indexs[:-1].split('],')
        point_list = []
        for item in str_list:
            xy = item[1:].split(',')
            point_list.append({'x': xy[0], 'y': xy[1]})
        return point_list
def index_to_meta(indexs):
    point_list = index_to_list(indexs)
    x = min_by_key(point_list, 'x')
    y = min_by_key(point_list, 'y')
    w = max_by_key(point_list, 'x') - x + 1
    h = max_by_key(point_list, 'y') - y + 1
    spot = ''.zfill(w * h)
    spot_len = len(spot)
    for m in point_list:
flag = (int(m['y']) - y) * w + (int(m['x']) - x)
spot = spot[0:flag] + '1' + spot[flag+1:spot_len]
strs = get_str_list(spot, 32)
array = '['
for s in strs:
array = array + '%d,' % (parse_unsigned_int(s, 2))
array = array[:-1] + ']'
meta = '{"x":%s,"y":%s,"w":%s,"h":%s,"delineation":%s}' % (x, y, w, h, array)
return meta
def parse_unsigned_int(s, radix):
length = len(s)
if length > 0:
firstChar = int(s[0])
if firstChar > 0:
newstr=''
for str in s[1:]:
if str == '0':
newstr = newstr + '1'
else:
newstr = newstr + '0'
return -(int(newstr, radix) + 1)
else:
return int(s, radix)
def get_str_list(instr, length):
size = len(instr) / length
if len(instr) % length != 0:
s = '00000000000000000000000000000000000'
instr = instr + s[len(instr) % length:]
size += 1
return get_str_list_size(instr, length, size)
def get_str_list_size(instr, length, size):
list = []
for i in range(int(size)):
childStr = substring(instr, i * length, (i + 1) * length)
list.append(childStr)
return list
def substring(str, f, t):
if f > len(str):
return None
if t > len(str):
return str[f:len(str)]
else:
return str[f:t]
def sub(str, p, c):
newstr = []
for s in str:
newstr.append(s)
newstr[p] = c
return ''.join(newstr)
def min_by_key(coll, key):
if coll:
candidate = coll[0][key]
for map in coll:
if int(map[key]) < int(candidate):
candidate = map[key]
return int(candidate)
def max_by_key(coll, key):
if coll:
candidate = coll[0][key]
for map in coll:
if int(map[key]) > int(candidate):
candidate = map[key]
return int(candidate)
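# Hedged worked example for the delineation helpers above. The index string is
# assumed to look like '[x1,y1],[x2,y2],...'; the four points below form a 2x2
# square whose 32-bit row-major bitmap (1111 followed by 28 zeros) is the signed
# integer -268435456.
if __name__ == '__main__':
    demo_indexs = '[1,1],[1,2],[2,1],[2,2]'
    print(index_to_list(demo_indexs))  # [{'x': '1', 'y': '1'}, {'x': '1', 'y': '2'}, ...]
    print(index_to_meta(demo_indexs))  # {"x":1,"y":1,"w":2,"h":2,"delineation":[-268435456]}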
def sigmoid(X):
return 1 / (1 + np.exp(-X))
'''
def load_image(input_file):
return nib.load(input_file)
def get_image(data, affine, nib_class=nib.Nifti1Image):
return nib_class(dataobj=data, affine=affine)
'''
def hu_value_clip(image, hu_min=-1200.0, hu_max=600.0, hu_nan=-2000.0):
image_new = np.asarray(image)
image_new[np.isnan(image_new)] = hu_nan
image_new = np.clip(image_new, hu_min, hu_max)
return image_new
# Normalize the raw CT pixel values (HU) into the [0, 255] range
def hu_value_to_uint8(original_data, hu_min=-1200.0, hu_max=600.0, hu_nan=-2000.0):
image_new = np.asarray(original_data)
image_new[np.isnan(image_new)] = hu_nan
# normalize to [0, 1]
image_new = (image_new - hu_min)
image_new = image_new / (hu_max - hu_min)
image_new = np.clip(image_new, 0, 1)
image_new = (image_new * 255).astype(np.float32)
return image_new
# Scale the data to (-1, 1)
def normalize(image):
image_new = np.asarray(image)
image_new = (image_new - 128.0) / 128.0
return image_new
def hu_normalize(original_data, hu_nan=-1000):
image_new = np.asarray(original_data)
image_new[np.isnan(image_new)] = hu_nan
image_new = image_new / 1000
image_new = np.clip(image_new, -1, 1)
image_new = image_new.astype(np.float32)
return image_new
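# Small sketch of the HU helpers above on made-up values (-1000 air, 0 water,
# 600 dense tissue, NaN missing); copies are passed because the helpers modify
# a float input array in place.
if __name__ == '__main__':
    sample = np.array([-1000.0, 0.0, 600.0, np.nan], dtype=np.float32)
    print(hu_value_clip(sample.copy()))      # NaN -> -2000, then clipped to [-1200, 600]
    print(hu_value_to_uint8(sample.copy()))  # mapped to [0, 255] (returned as float32)
    print(hu_normalize(sample.copy()))       # divided by 1000 and clipped to [-1, 1]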
def check_and_makedirs(output_file, is_file=True):
    if is_file:
        output_dir = os.path.dirname(output_file)
    else:
        output_dir = output_file
    # remove a stale file at the target path (existing directories are left in place)
    if os.path.isfile(output_file):
        os.remove(output_file)
    if output_dir is not None and output_dir != '' and not os.path.exists(output_dir):
        os.makedirs(output_dir)
def get_z_start_end_from_bound(data, z_center, bound_size):
z_start, z_end = max(z_center - bound_size, 0), min(z_center + bound_size, data.shape[0])
return z_start, z_end
def get_crop_start_end(start, end, crop_size):
center = (start + end) // 2
crop_start = center - np.asarray(crop_size) // 2
crop_start = np.maximum(crop_start, 0)
return crop_start, crop_start + crop_size
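# Worked example for the crop helpers (illustrative values): with
# start=[10, 20, 30], end=[20, 40, 60] and crop_size=(48, 256, 256) the center is
# [15, 30, 45], the unclipped crop start would be [-9, -98, -83], and after
# clamping to 0 the helper returns ([0, 0, 0], [48, 256, 256]); note that the
# crop end is not clipped to the data shape here.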
def print_data_mask(data, mask, filename=None, group=False):
coords = np.asarray(np.where(mask > 0))
if len(coords[0]) > 0:
new_start = coords.min(axis=1)
new_end = coords.max(axis=1) + 1
print_mask = mask[new_start[0]:new_end[0], new_start[1]:new_end[1]]
print_data = data[new_start[0]:new_end[0], new_start[1]:new_end[1]]
if filename is not None:
with open(filename, 'w') as f:
for temp_data in print_data:
f.write(np.array2string(temp_data).replace('\n', '') + '\n')
for temp_mask in print_mask:
f.write(np.array2string(temp_mask).replace('\n', '') + '\n')
if group:
for temp_data, temp_mask in zip(print_data, print_mask):
temp_group = np.array([str(i) + '(' + str(j) + ')' for i, j in zip(temp_data, temp_mask)])
f.write(np.array2string(temp_group).replace('\n', '').replace('\'', '') + '\n')
else:
for temp_data in print_data:
print(np.array2string(temp_data).replace('\n', '') + '\n')
for temp_mask in print_mask:
print(np.array2string(temp_mask).replace('\n', '') + '\n')
if group:
for temp_data, temp_mask in zip(print_data, print_mask):
temp_group = np.array([str(i) + '(' + str(j) + ')' for i, j in zip(temp_data, temp_mask)])
print(np.array2string(temp_group).replace('\n', '').replace('\'', '') + '\n')
def base64_to_list(base64_str):
indexs = ''
list = []
img_np = None
if base64_str:
time_now = time.time()
img_data = base64.b64decode(base64_str[22:])
        nparr = np.frombuffer(img_data, np.uint8)
img_np = cv2.imdecode(nparr, 0)
img_np[img_np != 0] = 1
point_list = np.where(img_np != 0)
if len(point_list[0]) > 0:
y_list = point_list[0]
x_list = point_list[1]
indexs = '{'
for point_idx, x in enumerate(x_list):
raw_x = x
raw_y = y_list[point_idx]
map = {'x': raw_x, 'y': raw_y}
list.append(map)
indexs = indexs + '[%s,%s],' % (raw_x, raw_y)
indexs = indexs[:-1] + '}'
#print('run time {:.5f}(s)'.format(time.time() - time_now))
return list, indexs, img_np
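# Hedged round-trip sketch for base64_to_list: build a tiny mask image, wrap it
# in the 'data:image/png;base64,' prefix (the 22 characters the function strips)
# and parse it back; values are illustrative only.
if __name__ == '__main__':
    demo_mask = np.zeros((4, 4), np.uint8)
    demo_mask[1:3, 1:3] = 255
    _, png = cv2.imencode('.png', demo_mask)
    data_uri = 'data:image/png;base64,' + base64.b64encode(png.tobytes()).decode('ascii')
    point_list, indexs, img_np = base64_to_list(data_uri)
    print(indexs)      # '{[1,1],[2,1],[1,2],[2,2]}'
    print(point_list)  # the same four foreground pixels as {'x': ..., 'y': ...} dicts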
[conf]
MODEL_PATH = D:/plus_new/python/checkpoint/
ZIP_DEMO = D:/plus_new/python/service/demo.txt
AI_PATH_PREFIX = D:/plus_new/ai
HTTP_SERVER_PREFIX = http://127.0.0.1/dcm
KAFKA_SERVER = 127.0.0.1:9092
MYSQL_SERVER = mysql+pymysql://lung:lung1qaz2wsx@127.0.0.1:3306/ct_file?charset=utf8
DICOM_RAW_DATA_PATH = D:/plus_new/raw-data
REDIS_SYNC_KEY = sync-server
REDIS_SERVER = 127.0.0.1
REDIS_DB = 5
import argparse
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
from configparser import ConfigParser
import json
#from testing.get_model import get_all_model
"""parser = argparse.ArgumentParser(description='load env')
parser.add_argument('--profile', default='dev', type=str,
help='profile information')
args = parser.parse_args()
profile = args.profile"""
ws_cfg = ConfigParser()
ws_cfg.read(os.path.dirname(os.path.abspath(__file__)) + '/app-qa.properties')
"""if profile == 'qa':
ws_cfg.read(os.path.dirname(os.path.abspath(__file__)) + '/app-qa.properties')
else:
ws_cfg.read((os.path.dirname(os.path.abspath(__file__)) + '/app-dev.properties'))"""
# global var
MYSQL_SERVER = str(ws_cfg.get('conf', 'MYSQL_SERVER'))
AI_PATH_PREFIX = str(ws_cfg.get('conf', 'AI_PATH_PREFIX'))
HTTP_SERVER_PREFIX = str(ws_cfg.get('conf', 'HTTP_SERVER_PREFIX'))
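# Other keys from the [conf] section of app-qa.properties could be exposed the
# same way if needed, e.g. (sketch):
# KAFKA_SERVER = str(ws_cfg.get('conf', 'KAFKA_SERVER'))
# REDIS_SERVER = str(ws_cfg.get('conf', 'REDIS_SERVER'))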
{
"pretrain_ckpt": "./cls_train/cls_ckpt/best_cls_3d_1025/08/cls_234567_2031/cls_234567_203120240815-0048.ckpt",
"csv_path" : "/home/lung/ai-project/cls_train/data/train_data/plus_3d_0818/subject_all_csv/08/cls_234567_2031/train_original.csv",
"node_times" : [[2031], [2011,2021,2041]],
"pretrain_folder": "./cls_train/cls_ckpt/best_cls_replenish/10/cls_1234567_1010-1016-1020",
"dicom_folder": "/opt/lung/ai" ,
"train_crop_size": [48, 256, 256],
"train_crop_size_2d": [256, 256],
"validation_filename": "cls_train/log/validation/cls_234567_2031/20240815/log_validation_20240815.log",
"compute_accuracy_filename": "cls_train/log/accuracy/cls_234567_2031/accuracy.log"
}
{
"pretrain_ckpt": "/df_lung/ai-project/cls_train/cls_ckpt/best_cls_3d_1025/08/cls_234567_2041/cls_1_204120230927-1256.ckpt",
"csv_path" : "/df_lung/ai-project/cls_train/data/train_data/plus_3d_0818/subject_all_csv/08/cls_234567_2041/train.csv",
"node_times" : [[2041]],
"pretrain_folder": "/df_lung/ai-project/cls_train/cls_ckpt/best_cls_replenish/10/cls_1234567_1010-1016-1020",
"dicom_folder": "/opt/lung/ai" ,
"train_crop_size": [48, 256, 256],
"train_crop_size_2d": [256, 256],
"validation_filename": "/df_lung/ai-project/cls_train/log/validation/cls_234567_2041/20241106/log_validation_20241106.log",
"compute_accuracy_filename": "/df_lung/ai-project/cls_train/log/accuracy/cls_234567_2041/20241106/accuracy.log"
}
{
"pretrain_ckpt": "/df_lung/ai-project/cls_train/cls_ckpt/best_cls_3d_1025/08/cls_234567_2041/cls_234567_204120240105-0248.ckpt",
"csv_path" : "/df_lung/ai-project/cls_train/data/train_data/plus_3d_0818/subject_all_csv/08/cls_234567_2041/train.csv",
"node_times" : [[2041]],
"pretrain_folder": "/df_lung/ai-project/cls_train/cls_ckpt/best_cls_replenish/10/cls_1234567_1010-1016-1020",
"dicom_folder": "/opt/lung/ai" ,
"train_crop_size": [48, 256, 256],
"train_crop_size_2d": [256, 256],
"validation_filename": "/df_lung/ai-project/cls_train/log/validation/cls_234567_2041/20241112/log_validation_20241112.log",
"compute_accuracy_filename": "/df_lung/ai-project/cls_train/log/accuracy/cls_234567_2041/20241112/accuracy.log"
}
{
"train_crop_size": [48, 256, 256],
"train_crop_size_2d": [256, 256] ,
"dicom_folder": "/opt/lung/ai" ,
"train_data_path": "./cls_train/data/train_data/plus_3d_0818" ,
"npy_folder": "npy_data",
"csv_path": "subject_all_csv" ,
"subject_all_csv": "subject_all.csv",
"image_path": "train_image",
"n_channels": 1 ,
"n_diff_classes": 1 ,
"n_base_filters": 16 ,
"batch_size": 8,
"lr": 1e-4,
"momentum": 0.9,
"epoch": 1000 ,
"patience": 200 ,
"learning_rate_drop": 0.1,
"early_stop": 1000 ,
"train_csv_file": "" ,
"ckpt_save_path": "./cls_train/cls_ckpt/best_cls_3d_1025/" ,
"ckpt_pretrain_path": "./cls_train/best_cls_3d_1_5001-6001_0824/temp",
"ckpt_file": "cls_1_5001-6001/cls_1_5001-600120230825-0208.ckpt",
"training_filename": "cls_train/log/train/log_training_3d_234567_2031_20240725.log"
}
{
"train_crop_size": [48, 256, 256],
"train_crop_size_2d": [256, 256] ,
"dicom_folder": "/opt/lung/ai" ,
"train_data_path": "./cls_train/data/train_data/plus_3d_1104" ,
"npy_folder": "npy_data",
"csv_path": "subject_all_csv" ,
"subject_all_csv": "subject_all.csv",
"image_path": "train_image",
"n_channels": 1 ,
"n_diff_classes": 1 ,
"n_base_filters": 16 ,
"batch_size": 8,
"lr": 1e-4,
"momentum": 0.9,
"epoch": 1000 ,
"patience": 200 ,
"learning_rate_drop": 0.1,
"early_stop": 1000 ,
"train_csv_file": "" ,
"ckpt_save_path": "./cls_train/cls_ckpt/best_cls_3d_1104/" ,
"ckpt_pretrain_path": "./cls_train/best_cls_3d_1_5001-6001_0824/temp",
"ckpt_file": "cls_1_5001-6001/cls_1_5001-600120230825-0208.ckpt",
"training_filename": "cls_train/log/train/log_training_3d_234567_2031_20241104.log"
}
{
"train_crop_size": [48, 256, 256],
"train_crop_size_2d": [256, 256] ,
"dicom_folder": "/opt/lung/ai" ,
"train_data_path": "data/train_data/plus_3d_0818" ,
"npy_folder": "npy_data",
"csv_path": "subject_all_csv" ,
"subject_all_csv": "subject_all.csv",
"image_path": "train_image",
"n_channels": 1 ,
"n_diff_classes": 1 ,
"n_base_filters": 16 ,
"batch_size": 8,
"lr": 1e-4,
"momentum": 0.9,
"epoch": 5 ,
"patience": 200 ,
"learning_rate_drop": 0.1,
"early_stop": 1000 ,
"train_csv_file": "" ,
"ckpt_save_path": "./cls_ckpt/best_cls_3d_20241112_2041/" ,
"ckpt_pretrain_path": "/df_lung/ai-project/cls_train/cls_ckpt/best_cls_3d_1025/08/cls_234567_2041/cls_234567_204120240105-0248.ckpt",
"ckpt_file": "cls_1_5001-6001/cls_1_5001-600120230825-0208.ckpt",
"training_filename": "./log/train/log_training_3d_234567_2041_20241112.log"
}
# -*- coding: utf-8 -*-
import os, sys
import pathlib
current_dir = pathlib.Path(__file__).parent.resolve()
while "cls_train" != current_dir.name:
current_dir = current_dir.parent
sys.path.append(current_dir.as_posix())
import numpy as np
import logging
import time
import traceback
import itertools
import mahotas
import cv2
from scipy import ndimage
from skimage import measure
from skimage import morphology
from multiprocessing.pool import Pool
from data.data_process_utils.test_data_utils import get_crop_start_end
from data.data_process_utils.test_data_utils import filter_data
from data.data_process_utils.test_image_vis import plot_shrink_extend_mask
def get_region_data(data, point, kernel_r):
region_data = data[point[0] - kernel_r:point[0] + kernel_r + 1, point[1] - kernel_r:point[1] + kernel_r + 1]
return region_data
def get_contour_points(contour):
contour = contour.copy()
contour = contour.reshape((-1, 2))
contour = np.flip(contour, axis=1)
return contour
def get_around_points(mask, point, increase_points, select_value=-1):
around_points = point + increase_points
around_points[:, 0] = np.clip(around_points[:, 0], 0, mask.shape[0] - 1)
around_points[:, 1] = np.clip(around_points[:, 1], 0, mask.shape[1] - 1)
return select_around_points(mask, around_points, select_value)
def batch_get_around_points(mask, contour, increase_points, select_value=-1):
contour = get_contour_points(contour)
contour_point = contour.reshape((-1, 1, 2))
contour_point = contour_point + increase_points
around_points = contour_point.reshape((-1, 2))
around_points[:, 0] = np.clip(around_points[:, 0], 0, mask.shape[0] - 1)
around_points[:, 1] = np.clip(around_points[:, 1], 0, mask.shape[1] - 1)
return select_around_points(mask, around_points, select_value)
def select_around_points(mask, around_points, select_value=-1):
if select_value >= 0:
around_mask = mask[around_points[:, 0], around_points[:, 1]]
        # only keep points whose mask value equals select_value
return around_points[around_mask == select_value]
return around_points
def calculate_increase_points(length, include_center_point=True):
length = int(length)
increase_length = max(length * 2 + 1, 3)
increase_points = list(itertools.product(range(increase_length), repeat=2))
if not include_center_point:
increase_points = [point for point in increase_points if point != (length, length)]
increase_points = np.array(increase_points) - length
return increase_points
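# Worked example: calculate_increase_points(1) returns the 3x3 neighbourhood
# offsets [-1, -1], [-1, 0], ..., [1, 1] (9 points); with
# include_center_point=False the central [0, 0] offset is dropped, leaving 8.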
# allowed error range for density values
density_allow_error_range = 30
density_reject_error_range = 80
increase_points2 = calculate_increase_points(2)
increase_points1 = calculate_increase_points(1)
kernel3 = np.array([[1, 1, 1],
[1, 1, 1],
[1, 1, 1]], np.uint8)
kernel5 = np.array([[0, 0, 1, 0, 0],
[0, 1, 1, 1, 0],
[1, 1, 1, 1, 1],
[0, 1, 1, 1, 0],
[0, 0, 1, 0, 0]], np.uint8)
kernel7 = np.array([[0, 0, 0, 1, 0, 0, 0],
[0, 1, 1, 1, 1, 1, 0],
[0, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1],
[0, 1, 1, 1, 1, 1, 0],
[0, 1, 1, 1, 1, 1, 0],
[0, 0, 0, 1, 0, 0, 0]], np.uint8)
def get_diameter_pixel(diameter, spacing, default_pixel=3):
"""
    Compute the number of pixels corresponding to the given diameter.
"""
return max(np.round(diameter / spacing[1], 2), default_pixel)
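# Worked example: with an in-plane spacing of 0.703125 mm a 2.0 mm diameter spans
# 2.0 / 0.703125 ≈ 2.84 pixels, which is below the default_pixel floor, so
# get_diameter_pixel(2.0, [1.0, 0.703125, 0.703125]) returns 3.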
def batch_shrink_extend_edge(data, spacing, uid, nodule_mask_list,
amend_head_tail=False, gen_image=False, image_path=''):
"""
    Batch-amend the edges of the given nodule masks.
"""
for nodule_index, one_nodule_mask in enumerate(nodule_mask_list):
start_time = time.time()
ret, result_mask, extend_mask, shrink_mask = \
shrink_extend_edge(data,
one_nodule_mask,
spacing,
amend_head_tail=amend_head_tail,
file_prefix=uid + '_' + str(nodule_index),
log=True)
print(uid + '_' + str(nodule_index), 'amend edge run time {:.2f}(s)'.format(time.time() - start_time))
if gen_image:
file_path = os.path.join(image_path, uid + '_' + str(nodule_index) + '_result.png')
plot_shrink_extend_mask(data, one_nodule_mask, result_mask, extend_mask, shrink_mask, spacing, file_path)
nodule_mask_list[nodule_index] = result_mask
def shrink_extend_edge(data, mask, spacing,
amend_head_tail=True, shrink_extend_mask=True, min_pixel=5,
file_prefix='', log=False):
"""
    Amend the nodule mask edge (shrink then extend); falls back to the original mask on error.
"""
try:
return _shrink_extend_edge(data, mask, spacing,
amend_head_tail, shrink_extend_mask, min_pixel,
file_prefix, log)
except Exception as e:
        logging.error('shrink_extend_edge error: %s', e)
traceback.print_exc()
return False, mask, mask, mask
def _shrink_extend_edge(data, mask, spacing,
amend_head_tail, shrink_extend_mask, min_pixel,
file_prefix, log):
"""
    Amend the nodule mask edge (core implementation).
"""
trachea_pixel = get_diameter_pixel(diameter=2.0, spacing=spacing)
amend_pixel = get_diameter_pixel(diameter=2.0, spacing=spacing)
amend_pixel = min(int(amend_pixel), 5)
print(file_prefix, 'amend_pixel =', amend_pixel, 'trachea_pixel =', trachea_pixel)
min_pixel = max(0, min_pixel, 3 * trachea_pixel)
# 获取mask区块
coords = np.asarray(np.where(mask == 1))
if len(coords[0]) <= min_pixel:
return False, mask, mask, mask
coord_start = coords.min(axis=1)
coord_end = coords.max(axis=1) + 1
z_num = coord_end[0] - coord_start[0] + 4
y_num = coord_end[1] - coord_start[1] + 5 * amend_pixel
x_num = coord_end[2] - coord_start[2] + 5 * amend_pixel
crop_start, crop_end = get_crop_start_end(coord_start, coord_end, (z_num, y_num, x_num))
# 使计算量减少,只取包括mask的区块
new_data = data[crop_start[0]:crop_end[0],
crop_start[1]:crop_end[1],
crop_start[2]:crop_end[2]]
new_mask = mask[crop_start[0]:crop_end[0],
crop_start[1]:crop_end[1],
crop_start[2]:crop_end[2]].copy()
# 每层mask点数
z_point_count = np.sum(new_mask == 1, axis=(1, 2))
# 最大面层
z_max = np.argmax(z_point_count)
# 层数
z_length = len(new_mask)
# 最前面层
z_start = coords.min(axis=1)[0]
# 最后面层
z_end = coords.max(axis=1)[0]
# 最大面层到最前面层
front_z_list = list(range(z_max, z_start - crop_start[0] - 1, -1))
# 最大面层下一层到最后面层
behind_z_list = list(range(z_max + 1, z_end - crop_start[0] + 1, 1))
# 解决少层问题,目前只新增并修正前后各一层
# 新增的前面层
add_front_z_list = []
# 新增的后面层
add_behind_z_list = []
# 计算平均密度、最小密度、最大密度
average_density, min_density, max_density = \
calculate_edge_hu_threshold(new_data, new_mask, True, file_prefix=file_prefix, log=log)
if min_density < -900:
return False, mask, mask, mask
# 小于最小密度点的mask置为0
new_mask[new_data < min_density - density_reject_error_range] = 0
# 从最大面层到最前面层,检查层是否存在异常
for z in front_z_list:
if np.all(new_mask[z] == 0):
continue
# 检查层是否存在异常
if z != z_max and 0 <= z + 1 < z_length:
current_point_num = z_point_count[z]
next_point_num = z_point_count[z + 1]
# 检查最前面层
if z == z_start - crop_start[0]:
if current_point_num <= min_pixel or \
(current_point_num <= 2 * min_pixel and 5 * current_point_num < next_point_num):
new_mask[z] = 0
add_front_z_list = [z]
elif 0 <= z - 1 < z_length:
add_front_z_list = [z - 1]
break
elif current_point_num <= min_pixel:
for temp_z in range(z, -1, -1):
new_mask[temp_z] = 0
break
elif 8 * current_point_num < next_point_num:
# 不足属于异常
new_mask[z] = get_maybe_mask_2d(new_mask[z], new_mask[z + 1], trachea_pixel)
# 从最大面层下一层到最后面层,检查层是否存在异常
for z in behind_z_list:
if np.all(new_mask[z] == 0):
continue
# 检查层是否存在异常
if z != z_max and 0 <= z - 1 < z_length:
current_point_num = z_point_count[z]
prev_point_num = z_point_count[z - 1]
# 检查最后面层
if z == z_end - crop_start[0]:
if current_point_num <= min_pixel or \
(current_point_num <= 2 * min_pixel and 5 * current_point_num < prev_point_num):
new_mask[z] = 0
add_behind_z_list = [z]
elif 0 <= z + 1 < z_length:
add_behind_z_list = [z + 1]
break
elif current_point_num <= min_pixel:
for temp_z in range(z, z_length, 1):
new_mask[temp_z] = 0
break
elif 8 * current_point_num < prev_point_num:
# 不足属于异常
new_mask[z] = get_maybe_mask_2d(new_mask[z], new_mask[z - 1], trachea_pixel)
ref_mask = np.max(new_mask, axis=0)
ret, ref_mask = get_erode_mask_2d(ref_mask, trachea_pixel)
if 'linux' in sys.platform:
new_mask, shrink_mask = shrink_extend_edge_2d_pool(new_data, new_mask, ref_mask,
average_density, min_density, max_density,
crop_start, z_length,
front_z_list, behind_z_list,
amend_pixel, trachea_pixel, min_pixel)
else:
shrink_mask = np.zeros(new_mask.shape, np.uint8)
new_mask, shrink_mask = shrink_extend_edge_2d(new_data, new_mask, shrink_mask, ref_mask,
average_density, min_density, max_density,
crop_start, z_length,
front_z_list, behind_z_list,
amend_pixel, trachea_pixel, min_pixel)
# 解决少层问题
if amend_head_tail and (add_front_z_list or add_behind_z_list):
for z in add_front_z_list:
# 用后面层的mask替换当前层的mask
new_mask[z] = get_maybe_mask_2d(new_mask[z], new_mask[z + 1], trachea_pixel)
# 小于最小密度点的mask置为0
new_mask[z][new_data[z] < min_density - density_allow_error_range] = 0
logging.info('{}, {}, add_front_z_list {}'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), file_prefix, add_front_z_list + crop_start[0]))
for z in add_behind_z_list:
# 用前面层的mask替换当前层的mask
new_mask[z] = get_maybe_mask_2d(new_mask[z], new_mask[z - 1], trachea_pixel)
# 小于最小密度点的mask置为0
new_mask[z][new_data[z] < min_density - density_allow_error_range] = 0
logging.info('{}, {}, add_behind_z_list {}'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), file_prefix, add_behind_z_list + crop_start[0]))
new_mask, shrink_mask = shrink_extend_edge_2d(new_data, new_mask, shrink_mask, ref_mask,
average_density, min_density, max_density,
crop_start, z_length,
add_front_z_list, add_behind_z_list,
amend_pixel, trachea_pixel, min_pixel, True)
# 空洞问题
# 原mask
original_mask = mask[crop_start[0]:crop_end[0],
crop_start[1]:crop_end[1],
crop_start[2]:crop_end[2]]
# 在shrink_mask或原mask中没有选中
condition = np.logical_or(shrink_mask == 0, original_mask == 0)
# 在new_mask中选中
condition = np.logical_and(new_mask == 1, condition)
# 密度小于最小密度
condition = np.logical_and(new_data < min_density, condition)
label = measure.label(condition, connectivity=1)
props = measure.regionprops(label)
no_select_label = [prop.label for prop in props if prop.area > trachea_pixel * trachea_pixel]
if len(no_select_label) > 0:
no_select_condition = np.logical_and(np.isin(label, no_select_label), condition)
new_mask[no_select_condition] = 0
# 删除独立的小区块
label = measure.label(new_mask, connectivity=1)
props = measure.regionprops(label)
if len(props) > 1:
props = sorted(props, key=lambda x: x.area, reverse=True)
max_area = props[0].area
no_select_label = [prop.label for prop in props if prop.area < 0.1 * max_area]
if len(no_select_label) > 0:
no_select_condition = np.isin(label, no_select_label)
new_mask[no_select_condition] = 0
# 删除异常层
remove_small_point_layer_3d(new_data, new_mask,
amend_head_tail, add_front_z_list, add_behind_z_list, min_pixel)
new_mask_result = np.zeros(mask.shape, np.uint8)
new_mask_result[crop_start[0]:crop_end[0],
crop_start[1]:crop_end[1],
crop_start[2]:crop_end[2]] = new_mask
if not shrink_extend_mask:
return True, new_mask_result, new_mask_result, new_mask_result
shrink_mask_result = np.zeros(mask.shape, np.uint8)
shrink_mask_result[crop_start[0]:crop_end[0],
crop_start[1]:crop_end[1],
crop_start[2]:crop_end[2]] = shrink_mask
return True, new_mask_result, new_mask_result, shrink_mask_result
def shrink_extend_edge_2d_pool(new_data, new_mask, ref_mask,
average_density, min_density, max_density,
crop_start, z_length,
front_z_list, behind_z_list,
amend_pixel, trachea_pixel, min_pixel):
def callback(result):
(z, mask) = result
new_mask[z] = mask
pool_size = min(len(front_z_list) + len(behind_z_list), 5)
pool = Pool(pool_size)
# 从最大面层到最前面层,收缩不为结节的点
for z in front_z_list:
original_z = z + crop_start[0]
pool.apply_async(batch_shrink_edge_2d,
args=(new_data[z], new_mask[z], original_z, z,
average_density, min_density, max_density,
amend_pixel, trachea_pixel),
callback=callback)
# 从最大面层下一层到最后面层,收缩不为结节的点
for z in behind_z_list:
original_z = z + crop_start[0]
pool.apply_async(batch_shrink_edge_2d,
args=(new_data[z], new_mask[z], original_z, z,
average_density, min_density, max_density,
amend_pixel, trachea_pixel),
callback=callback)
pool.close()
pool.join()
shrink_mask = new_mask.copy()
pool = Pool(pool_size)
# 从最大面层到最前面层,扩展为结节的点
for z in front_z_list:
original_z = z + crop_start[0]
small_data = new_data[z - 1] if 0 <= z - 1 < z_length else new_data[z]
small_mask = new_mask[z - 1] if 0 <= z - 1 < z_length else new_mask[z]
big_data = new_data[z + 1] if 0 <= z + 1 < z_length else new_data[z]
big_mask = new_mask[z + 1] if 0 <= z + 1 < z_length else new_mask[z]
new_min_pixel = max(0.075 * np.sum(new_mask[z + 1] == 1) if 0 <= z + 1 < z_length else 0, min_pixel)
new_mask[z] = batch_extend_edge_3d(new_data[z], new_mask[z], original_z,
average_density, min_density, max_density,
small_data, small_mask, big_data, big_mask)
pool.apply_async(batch_extend_edge_2d,
args=(new_data[z], new_mask[z], original_z, z,
average_density, min_density, max_density,
big_mask, ref_mask, shrink_mask[z],
amend_pixel, trachea_pixel, new_min_pixel),
callback=callback)
# 从最大面层下一层到最后面层,扩展为结节的点
for z in behind_z_list:
original_z = z + crop_start[0]
small_data = new_data[z + 1] if 0 <= z + 1 < z_length else new_data[z]
small_mask = new_mask[z + 1] if 0 <= z + 1 < z_length else new_mask[z]
big_data = new_data[z - 1] if 0 <= z - 1 < z_length else new_data[z]
big_mask = new_mask[z - 1] if 0 <= z - 1 < z_length else new_mask[z]
new_min_pixel = max(0.075 * np.sum(new_mask[z - 1] == 1) if 0 <= z - 1 < z_length else 0, min_pixel)
new_mask[z] = batch_extend_edge_3d(new_data[z], new_mask[z], original_z,
average_density, min_density, max_density,
small_data, small_mask, big_data, big_mask)
pool.apply_async(batch_extend_edge_2d,
args=(new_data[z], new_mask[z], original_z, z,
average_density, min_density, max_density,
big_mask, ref_mask, shrink_mask[z],
amend_pixel, trachea_pixel, new_min_pixel),
callback=callback)
pool.close()
pool.join()
return new_mask, shrink_mask
def shrink_extend_edge_2d(new_data, new_mask, shrink_mask, ref_mask,
average_density, min_density, max_density,
crop_start, z_length,
front_z_list, behind_z_list,
amend_pixel, trachea_pixel, min_pixel, is_head_tail=False):
# 从最大面层到最前面层,收缩不为结节的点
for z in front_z_list:
original_z = z + crop_start[0]
_, new_mask[z] = batch_shrink_edge_2d(new_data[z], new_mask[z], original_z, z,
average_density, min_density, max_density,
amend_pixel, trachea_pixel)
shrink_mask[z] = new_mask[z].copy()
# 从最大面层下一层到最后面层,收缩不为结节的点
for z in behind_z_list:
original_z = z + crop_start[0]
_, new_mask[z] = batch_shrink_edge_2d(new_data[z], new_mask[z], original_z, z,
average_density, min_density, max_density,
amend_pixel, trachea_pixel)
shrink_mask[z] = new_mask[z].copy()
# 从最大面层到最前面层,扩展为结节的点
for z in front_z_list:
original_z = z + crop_start[0]
small_data = new_data[z - 1] if 0 <= z - 1 < z_length else new_data[z]
small_mask = new_mask[z - 1] if 0 <= z - 1 < z_length else new_mask[z]
big_data = new_data[z + 1] if 0 <= z + 1 < z_length else new_data[z]
big_mask = new_mask[z + 1] if 0 <= z + 1 < z_length else new_mask[z]
new_min_pixel = max(0.075 * np.sum(new_mask[z + 1] == 1) if 0 <= z + 1 < z_length else 0, min_pixel)
new_mask[z] = batch_extend_edge_3d(new_data[z], new_mask[z], original_z,
average_density, min_density, max_density,
small_data, small_mask, big_data, big_mask)
_, new_mask[z] = batch_extend_edge_2d(new_data[z], new_mask[z], original_z, z,
average_density, min_density, max_density,
big_mask, ref_mask, shrink_mask[z],
amend_pixel, trachea_pixel, new_min_pixel, is_head_tail)
# 从最大面层下一层到最后面层,扩展为结节的点
for z in behind_z_list:
original_z = z + crop_start[0]
small_data = new_data[z + 1] if 0 <= z + 1 < z_length else new_data[z]
small_mask = new_mask[z + 1] if 0 <= z + 1 < z_length else new_mask[z]
big_data = new_data[z - 1] if 0 <= z - 1 < z_length else new_data[z]
big_mask = new_mask[z - 1] if 0 <= z - 1 < z_length else new_mask[z]
new_min_pixel = max(0.075 * np.sum(new_mask[z - 1] == 1) if 0 <= z - 1 < z_length else 0, min_pixel)
new_mask[z] = batch_extend_edge_3d(new_data[z], new_mask[z], original_z,
average_density, min_density, max_density,
small_data, small_mask, big_data, big_mask)
_, new_mask[z] = batch_extend_edge_2d(new_data[z], new_mask[z], original_z, z,
average_density, min_density, max_density,
big_mask, ref_mask, shrink_mask[z],
amend_pixel, trachea_pixel, new_min_pixel, is_head_tail)
return new_mask, shrink_mask
def get_kernel_and_iterations(mask_point_count, trachea_pixel, min_iterations=2):
"""
    Choose the erosion kernel and iteration count from the mask point count and the trachea pixel size.
"""
if mask_point_count <= 10 ** 2:
kernel = kernel3
next_kernel = kernel5
iterations = 2
elif mask_point_count <= 17 ** 2:
kernel = kernel5
next_kernel = kernel7
iterations = 2
elif mask_point_count <= 25 ** 2:
kernel = kernel5
next_kernel = kernel7
iterations = trachea_pixel / 4
else:
kernel = kernel7
next_kernel = kernel7
iterations = trachea_pixel / 6
iterations = max(int(np.ceil(iterations)), min_iterations)
return kernel, next_kernel, iterations
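# Worked example: for a mask of about 200 points (<= 17**2) the helper picks
# (kernel5, kernel7) with 2 iterations; for 1000 points it picks
# (kernel7, kernel7) and, with trachea_pixel = 3, max(ceil(3 / 6), 2) = 2 iterations.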
def get_maybe_mask_2d(small_mask, big_mask, trachea_pixel):
"""
    Build an estimated mask that contains small_mask, based on small_mask and big_mask.
"""
ret, erode_mask = get_erode_mask_2d(big_mask, trachea_pixel)
if ret:
small_mask = np.bitwise_or(erode_mask, small_mask)
return small_mask
def get_erode_mask_2d(mask, trachea_pixel, min_iterations=2, min_ratio=0.2, ratio=0.8):
"""
    Return the intersection of the mask with its eroded version.
"""
mask_point_count = np.sum(mask == 1)
if mask_point_count > 8 ** 2:
kernel, next_kernel, original_iterations = get_kernel_and_iterations(mask_point_count, trachea_pixel)
for iterations in range(original_iterations, 0, -1):
next_mask = cv2.erode(mask, next_kernel, iterations=iterations)
new_mask = cv2.erode(mask, kernel, iterations=iterations)
new_mask = np.bitwise_or(next_mask, new_mask)
new_mask = np.bitwise_and(mask, new_mask)
new_mask_point_count = np.sum(new_mask == 1)
if new_mask_point_count > ratio * mask_point_count or \
(new_mask_point_count > min_ratio * mask_point_count and iterations == min_iterations):
return True, new_mask
return False, mask
def get_erode_dilate_mask_2d(mask, trachea_pixel, min_iterations=2, min_ratio=0.2, ratio=0.8):
"""
    Return the intersection of the mask with its eroded-then-dilated version.
"""
mask_point_count = np.sum(mask == 1)
if mask_point_count > 8 ** 2:
kernel, next_kernel, original_iterations = get_kernel_and_iterations(mask_point_count, trachea_pixel)
for iterations in range(original_iterations, 0, -1):
next_mask = cv2.erode(mask, next_kernel, iterations=iterations)
next_mask = cv2.dilate(next_mask, next_kernel, iterations=iterations)
new_mask = cv2.erode(mask, kernel, iterations=iterations)
new_mask = cv2.dilate(new_mask, kernel, iterations=iterations)
new_mask = np.bitwise_or(next_mask, new_mask)
new_mask = np.bitwise_and(mask, new_mask)
new_mask_point_count = np.sum(new_mask == 1)
if new_mask_point_count > ratio * mask_point_count or \
(new_mask_point_count > min_ratio * mask_point_count and iterations == min_iterations):
return True, new_mask
return False, mask
def fill_min_mask_2d(max_mask, min_mask, default_distance=-1.0):
"""
    Fill the edge of the minimum mask using points from the maximum mask.
"""
change_mask = max_mask - min_mask
if np.sum(change_mask == 1) > 0:
change_coords = np.asarray(np.where(change_mask == 1)).T
_, contours, _ = cv2.findContours(min_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
for contour in contours:
for point in change_coords:
distance = cv2.pointPolygonTest(contour, (point[1], point[0]), True)
if distance > default_distance:
min_mask[point[0], point[1]] = 1
return min_mask
def fill_error_range_point(data, mask, z, density, error_density=0, edge_mask=None):
"""
    Fill the points that fall inside the convex hull and are within the density error range.
"""
_, contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
for contour in contours:
if len(contours) > 1:
# 设置contour获取包含的所有点
temp_mask = np.zeros(mask.shape, np.uint8)
cv2.drawContours(temp_mask, [contour], -1, 1, -1)
else:
temp_mask = mask
temp_mask = morphology.convex_hull_object(temp_mask, neighbors=8)
temp_mask = temp_mask.astype(np.uint8)
if edge_mask is not None:
temp_mask = np.bitwise_and(edge_mask, temp_mask)
condition = np.logical_and(data >= density - error_density, temp_mask == 1)
mask[condition] = 1
def fill_3d(mask):
"""
    Fill the mask slice by slice.
"""
for z in range(len(mask)):
if np.all(mask[z] == 0):
continue
fill_2d(mask[z], z)
def fill_2d(mask, z):
"""
    Fill the 2D mask by drawing its external contours filled.
"""
_, contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
cv2.drawContours(mask, contours, -1, 1, -1)
def remove_small_objects_3d(mask, ratio=0.15, min_pixel=5):
"""
    Remove small isolated regions, slice by slice.
"""
for z in range(len(mask)):
if np.all(mask[z] == 0):
continue
mask[z] = remove_small_objects_2d(mask[z], z, ratio=ratio, min_pixel=min_pixel)
return mask
def remove_small_objects_2d(mask, z, ratio=0.15, min_pixel=5):
"""
    Remove small isolated regions in a 2D slice.
"""
mask_point_count = int(np.sum(mask == 1))
min_size = int(max(ratio * mask_point_count, min_pixel, 5))
    input_mask = mask.astype(bool)
result_mask = morphology.remove_small_objects(input_mask, min_size=min_size)
result_mask = result_mask.astype(np.uint8)
if np.sum(result_mask == 1) > 0:
mask = result_mask
return mask
def remove_small_area_2d(mask, z, ratio=0.15):
"""
    Check regions and remove those whose area is small relative to the largest one.
"""
_, contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
if len(contours) > 1:
contours = sorted(contours, key=cv2.contourArea, reverse=True)
for contour_index, contour in enumerate(contours):
if contour_index == 0:
max_area = cv2.contourArea(contour)
elif cv2.contourArea(contour) < ratio * max_area:
cv2.drawContours(mask, [contour], -1, 0, -1)
return mask
def remove_new_extend_area_2d(new_mask, original_mask, z, ratio=0.15):
"""
    Remove newly extended isolated regions (those not overlapping the original mask).
"""
_, original_contours, _ = cv2.findContours(original_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
if len(original_contours) == 0:
return new_mask
_, new_contours, _ = cv2.findContours(new_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
if len(new_contours) > 1:
new_contours = sorted(new_contours, key=cv2.contourArea, reverse=True)
for new_contour_index, new_contour in enumerate(new_contours):
new_check_mask = np.zeros(new_mask.shape, np.uint8)
cv2.drawContours(new_check_mask, [new_contour], -1, 1, -1)
found = False
for original_contour_index, original_contour in enumerate(original_contours):
original_check_mask = np.zeros(new_mask.shape, np.uint8)
cv2.drawContours(original_check_mask, [original_contour], -1, 1, -1)
check_mask = np.bitwise_and(new_check_mask, original_check_mask)
intersection_count = np.sum(check_mask == 1)
check_mask = np.bitwise_or(new_check_mask, original_check_mask)
union_count = np.sum(check_mask == 1)
if intersection_count > 0 and union_count > 0 and intersection_count / union_count > ratio:
found = True
break
# 删除新扩展的独立区块
if not found:
cv2.drawContours(new_mask, [new_contour], -1, 0, -1)
return new_mask
def remove_trachea_2d(mask, z, trachea_pixel, retain_first=False):
"""
    Check region shapes and remove those with an abnormal aspect ratio.
"""
_, contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
if len(contours) > 1:
contours = sorted(contours, key=cv2.contourArea, reverse=True)
for contour_index, contour in enumerate(contours):
if retain_first and contour_index == 0:
continue
if len(contours) > 1:
# 设置contour获取包含的所有点
check_mask = np.zeros(mask.shape, np.uint8)
cv2.drawContours(check_mask, [contour], -1, 1, -1)
else:
check_mask = mask
if is_trachea(check_mask, z, trachea_pixel, contour_index, contour):
cv2.drawContours(mask, [contour], -1, 0, -1)
return mask
def is_trachea(mask, z, trachea_pixel, contour_index, contour):
"""
    Judge from its shape whether a region is abnormal (elongated, trachea/vessel-like).
"""
rect_min = cv2.minAreaRect(contour)
((_, _), (width, height), angle) = rect_min
if width < 2 or height < 2:
return True
min_length = min(width, height)
max_length = max(width, height)
if (contour_index == 0 and min_length <= trachea_pixel and max_length / min_length > 2.0) or \
(contour_index > 0 and min_length <= 1.5 * trachea_pixel and max_length / min_length > 2.0) or \
max_length / min_length > 3.0:
return True
mask_point_count = np.sum(mask == 1)
ret, erode_mask = get_erode_mask_2d(mask, trachea_pixel)
if not ret:
if min_length <= 1.5 * trachea_pixel and max_length / min_length > 2.0:
print('z =', z,
'min_length =', np.round(min_length, 2),
'max_length =', np.round(max_length, 2),
'mask_point_count =', mask_point_count,
'area =', np.round(width * height, 2))
return True
return False
def remove_small_point_layer_3d(data, mask,
amend_head_tail, add_front_z_list, add_behind_z_list, min_pixel):
"""
    Remove abnormal first/last slices that contain too few points.
"""
# 每层mask点数
z_point_count = np.sum(mask == 1, axis=(1, 2))
# 最大面层
z_max = np.argmax(z_point_count)
coords = np.asarray(np.where(mask == 1))
if len(coords[0]) > 0:
# 最前面层
z_start = coords.min(axis=1)[0]
# 最后面层
z_end = coords.max(axis=1)[0]
if z_end - z_start > 2 and z_start != z_max:
z = z_start
if z != z_max and 0 <= z + 1 < len(mask):
current_point_num = z_point_count[z]
next_point_num = z_point_count[z + 1]
if amend_head_tail and add_front_z_list and len(add_front_z_list) > 0 and add_front_z_list[0] == z:
if current_point_num < 2 * min_pixel or 5 * current_point_num < next_point_num:
mask[z] = 0
elif current_point_num < 3 * min_pixel and 8 * current_point_num < next_point_num:
mask[z] = 0
if z_end - z_start > 2 and z_end != z_max:
z = z_end
if z != z_max and 0 <= z - 1 < len(mask):
current_point_num = z_point_count[z]
prev_point_num = z_point_count[z - 1]
if amend_head_tail and add_behind_z_list and len(add_behind_z_list) > 0 and add_behind_z_list[0] == z:
if current_point_num < 2 * min_pixel or 5 * current_point_num < prev_point_num:
mask[z] = 0
elif current_point_num < 3 * min_pixel and 8 * current_point_num < prev_point_num:
mask[z] = 0
def sum_nodule_point(data, min_density, max_density):
"""
    Count the points whose density lies within the given range.
"""
return np.sum(np.logical_and(data >= min_density, data <= max_density))
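# Worked example: sum_nodule_point(np.array([-900, -300, 0, 700]), -760, 400)
# counts only -300 and 0, so it returns 2.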
def is_near_distance(source, target1, target2, ratio=1.0):
"""
    Return True if source is closer to target1 than to target2.
"""
if (ratio * (source - target1)) ** 2 < (source - target2) ** 2:
return True
return False
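# Worked example: is_near_distance(-600, -650, -400) compares 50**2 against
# 200**2 and returns True (the source density is closer to target1); with
# ratio=2.0 the left side becomes 100**2, which is still smaller, so still True.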
def calculate_density(data, mask, z, point, kernel_radius, min_density=None):
"""
    Compute the mean density of the selected and unselected neighbours of a point.
"""
density = data[point[0], point[1]]
around_data = get_region_data(data, point, kernel_radius)
if len(around_data.ravel()) < (2 * kernel_radius + 1) ** 2:
return density, density, 0, 0
around_mask = get_region_data(mask, point, kernel_radius).copy()
# 当前点不参与计算
around_mask[kernel_radius, kernel_radius] = -1
if min_density is not None:
# 小于最小密度的点不参与计算
around_mask[around_data < min_density] = -1
num1 = np.sum(around_mask == 1)
num0 = np.sum(around_mask == 0)
density1 = np.mean(around_data[around_mask == 1]) if num1 > 0 else density
density0 = np.mean(around_data[around_mask == 0]) if num0 > 0 else density
return density1, density0, num1, num0
def is_shrink_point_density(data, mask, z, point, kernel_radius, min_density=None, error_thred=0):
"""
    Decide whether this point can be shrunk (removed from the mask).
"""
density = data[point[0], point[1]]
density1, density0, num1, num0 = calculate_density(data, mask, z, point, kernel_radius, min_density)
if num1 <= 0 or num0 <= 0:
# 数据不正确
return 0
if is_near_distance(density, density0, density1):
# 离未选中区块更近
return 1
if np.abs(density - density0) <= error_thred:
# 在未选中区块的误差范围内
return 1
return 0
def is_extend_point_density(data, mask, z, point, kernel_radius, min_density=None, error_thred=0, ref_mask=None):
"""
    Decide whether this point can be extended (added to the mask).
"""
density = data[point[0], point[1]]
density1, density0, num1, num0 = calculate_density(data, mask, z, point, kernel_radius, min_density)
if num1 <= 0 or num0 <= 0:
# 数据不正确
return 0
if ref_mask is not None and ref_mask[point[0], point[1]] == 0 and \
density0 > density1 + 20 and density > density1 + 20:
# 不在ref_mask范围内,不向密度大的方向上扩展
return 0
if (density > density1 + 20 and density > density0 + 20) or \
(density < density1 - 20 and density < density0 - 20):
# 该点的密度与两边反向,不处理该点
return 0
if is_near_distance(density, density1, density0):
# 离选中区块更近
return 1
if np.abs(density - density1) <= error_thred:
# 在选中区块的误差范围内
return 1
return 0
def shrink_edge_point_11(data, mask, z,
average_density, min_density, max_density,
amend_pixel):
"""
    Shrink edge points whose density is the largest among the selected points around them.
"""
for _ in range(amend_pixel):
change = False
_, contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
for contour in contours:
around_points = get_contour_points(contour)
around_data = data[around_points[:, 0], around_points[:, 1]]
idx_sorted = np.argsort(around_data)[::-1]
for point in around_points[idx_sorted]:
if mask[point[0], point[1]] == 0:
continue
density = data[point[0], point[1]]
if density < max_density - 20:
# 该点密度小于最大密度,跳过
continue
check_data = get_region_data(data, point, 2)
if np.sum(check_data < min_density) > 0:
# 周围有点密度小于最小密度,跳过
continue
check_mask = get_region_data(mask, point, 2)
check_data_1 = check_data[check_mask == 1]
check_data_0 = check_data[check_mask == 0]
if np.sum(check_data_1 > density) == 0 and np.sum(check_data_0 > density) > 0:
# 周围选中的点中密度都小于该点密度并且周围未选中的点中密度存在大于该点密度,mask置于0
mask[point[0], point[1]] = 0
change = True
continue
if not change:
break
def shrink_edge_point_21(data, mask, z,
average_density, min_density, max_density,
amend_pixel, error_thred=0):
"""
    Shrink edge points.
"""
for _ in range(amend_pixel):
change = False
_, contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
for contour in contours:
around_points = get_contour_points(contour)
for point in around_points:
if mask[point[0], point[1]] == 0:
continue
density = data[point[0], point[1]]
if min_density + density_reject_error_range <= density <= max_density - density_reject_error_range:
continue
check_data = get_region_data(data, point, 1)
check_num = len(check_data.ravel())
nodule_point_num = sum_nodule_point(check_data, min_density, max_density)
if min_density + density_allow_error_range <= density <= max_density - density_allow_error_range and \
nodule_point_num > 0.5 * check_num:
continue
elif min_density <= density <= max_density and \
nodule_point_num > 0.8 * check_num:
continue
if density < min_density and np.sum(check_data < min_density) > 0.5 * check_num:
# 该点密度小于最小密度,并且周围的点中密度有一半以上小于最小密度,mask置于0
mask[point[0], point[1]] = 0
change = True
continue
# 判断该点是否可以收缩
new_error_thred = error_thred
if density < min_density - density_allow_error_range:
new_error_thred += density_allow_error_range
vote_num = is_shrink_point_density(data, mask, z, point, 1,
error_thred=new_error_thred) + \
is_shrink_point_density(data, mask, z, point, 2,
error_thred=new_error_thred)
if vote_num >= 2:
# 该点可以收缩,mask置于0
mask[point[0], point[1]] = 0
change = True
continue
if not change:
break
def batch_shrink_edge_2d(data, mask, z, index,
average_density, min_density, max_density,
amend_pixel, trachea_pixel):
"""
    Batch-shrink the points that do not belong to the nodule (single 2D slice).
"""
# 检查形态,删除长宽比异常的区块
mask = remove_trachea_2d(mask, z, trachea_pixel, retain_first=True)
# 填充mask
fill_2d(mask, z)
# 获取核mask
ret, kernel_mask = get_erode_mask_2d(mask, trachea_pixel)
if ret:
# 填充在误差范围内凸集中的点
fill_error_range_point(data, kernel_mask, z, min_density, error_density=density_allow_error_range)
mask[kernel_mask == 1] = 1
# 收缩前mask
original_shrink_mask = mask.copy()
# 收缩边缘点,该点在它的周围选中的点中为最大密度的点
shrink_edge_point_11(data, mask, z,
average_density, min_density, max_density,
2)
# 收缩边缘点
shrink_edge_point_21(data, mask, z,
average_density, min_density, max_density,
amend_pixel)
# 获取最小mask
ret, min_mask = get_erode_dilate_mask_2d(mask, trachea_pixel, ratio=0.9)
if ret:
# 收缩与扩展边缘点, 扩展因侵蚀膨胀后丢失的为结节的点
extend_edge_point_32(data, min_mask, z,
average_density, min_density, max_density,
edge_mask=mask, shrink_mask=None, error_thred=50)
mask = min_mask
# 填充mask边缘
mask = fill_min_mask_2d(original_shrink_mask, mask)
# 检查区块,删除比率小的区块
mask = remove_small_area_2d(mask, z)
return index, mask
def batch_extend_edge_3d(data, mask, z,
average_density, min_density, max_density,
small_data, small_mask, big_data, big_mask):
"""
    Batch-extend nodule points based on the adjacent slices above and below.
"""
if np.sum(small_mask == 1) > 0 and np.sum(big_mask == 1) > 0:
base_condition_data = np.logical_and(data >= min_density, data <= max_density)
base_condition_log1 = np.logical_and(base_condition_data, mask == 0)
big_condition_data = np.logical_and(big_data >= min_density, big_data <= max_density)
big_condition_log1 = np.logical_and(big_condition_data, big_mask == 1)
small_condition_data = np.logical_and(small_data >= min_density, small_data <= max_density)
small_condition_log1 = np.logical_and(small_condition_data, small_mask == 1)
condition = np.logical_and(big_condition_log1, base_condition_log1)
condition = np.logical_and(small_condition_log1, condition)
mask[condition] = 1
return mask
def extend_edge_point_11(data, mask, z,
average_density, min_density, max_density,
amend_pixel):
"""
    Extend points that belong to the nodule.
"""
increase_points = calculate_increase_points(amend_pixel)
_, contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
for contour in contours:
# 获取相关的点,只处理mask为0的点
around_points = batch_get_around_points(mask, contour, increase_points, select_value=0)
around_data = data[around_points[:, 0], around_points[:, 1]]
# 获取大于等于最小密度与小于等于最大密度的点
around_min_points = around_points[np.logical_and(around_data >= min_density + density_allow_error_range,
around_data <= max_density - density_allow_error_range)]
# 大于等于最小密度与小于等于最大密度的点mask置为1
mask[around_min_points[:, 0], around_min_points[:, 1]] = 1
# 获取在误差范围内的小于最小密度的点,根据条件设置mask的值
around_error_points = around_points[np.logical_and(around_data > min_density - density_allow_error_range,
around_data < min_density + density_allow_error_range)]
if len(around_error_points) > 0:
around_error_points = np.unique(around_error_points, axis=0)
for point in around_error_points:
check_data = get_region_data(data, point, 1)
if sum_nodule_point(check_data, min_density - density_allow_error_range, max_density) >= 5:
mask[point[0], point[1]] = 1
# 获取在误差范围内的大于最大密度的点,根据条件设置mask的值
around_error_points = around_points[np.logical_and(around_data > max_density - density_allow_error_range,
around_data < max_density + density_allow_error_range)]
if len(around_error_points) > 0:
around_error_points = np.unique(around_error_points, axis=0)
for point in around_error_points:
check_data = get_region_data(data, point, 1)
if sum_nodule_point(check_data, min_density, max_density + density_allow_error_range) >= 5:
mask[point[0], point[1]] = 1
def extend_edge_point_12(data, mask, z,
average_density, min_density, max_density,
amend_pixel):
"""
    Extend points that belong to the nodule.
"""
increase_points = calculate_increase_points(amend_pixel)
_, contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
for contour in contours:
# 获取相关的点,只处理mask为0的点
around_points = batch_get_around_points(mask, contour, increase_points, select_value=0)
around_data = data[around_points[:, 0], around_points[:, 1]]
# 获取大于等于最小密度与小于等于最大密度的点
around_min_points = around_points[np.logical_and(around_data >= min_density - density_allow_error_range,
around_data <= max_density + density_allow_error_range)]
# 大于等于最小密度与小于等于最大密度的点mask置为1
mask[around_min_points[:, 0], around_min_points[:, 1]] = 1
def extend_edge_point_21(data, mask, z,
average_density, min_density, max_density,
extend_mask):
"""
    Extend nodule points along the edge.
"""
extend_coords = np.asarray(np.where(extend_mask == 1)).T
for point in extend_coords:
if mask[point[0], point[1]] == 1:
continue
check_data = get_region_data(data, point, 1)
check_mask = get_region_data(mask, point, 1)
lt_min_density_num1 = np.sum(check_data < min_density)
lt_min_density_num2 = np.sum(check_data < min_density + density_allow_error_range)
if (lt_min_density_num1 > 0 or lt_min_density_num2 > 2) and \
sum_nodule_point(check_data, min_density, max_density) > 2:
# 周围有点的密度小于最小密度,该点设为边缘点,mask置于1
mask[point[0], point[1]] = 1
# 周围所有点mask也置于1
check_mask[...] = 1
elif lt_min_density_num2 + \
np.sum(check_data > max_density - density_allow_error_range) == len(check_data.ravel()):
# 周围点的密度要么小于最小密度要么大于最大密度,mask置于1
mask[point[0], point[1]] = 1
def extend_edge_point_31(data, mask, z,
average_density, min_density, max_density,
amend_pixel, edge_mask=None, ref_mask=None, error_thred=0):
"""
    Extend points that belong to the nodule.
"""
for _ in range(amend_pixel):
change = False
_, contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
for contour in contours:
# 获取相关的点,只处理mask为0的点
around_points = batch_get_around_points(mask, contour, increase_points1, select_value=0)
if len(around_points) == 0:
continue
around_points = np.unique(around_points, axis=0)
for point in around_points:
if mask[point[0], point[1]] == 1:
continue
if edge_mask is not None and edge_mask[point[0], point[1]] == 0:
continue
check_mask = get_region_data(mask, point, 1)
if np.sum(check_mask == 1) <= 2:
continue
density = data[point[0], point[1]]
if not (min_density - density_allow_error_range <= density <= max_density + density_allow_error_range):
# 该点密度不在密度范围内,跳过
continue
check_data = get_region_data(data, point, 1)
check_num = len(check_data.ravel())
nodule_point_num = sum_nodule_point(check_data, min_density, max_density)
check_data_1 = check_data[check_mask == 1]
nodule_point_num_1 = sum_nodule_point(check_data_1,
min_density + density_allow_error_range,
max_density - density_allow_error_range)
if min_density + density_allow_error_range <= density <= max_density - density_allow_error_range and \
nodule_point_num_1 > 3:
# 周围的点密度在密度范围内,mask置于1
mask[point[0], point[1]] = 1
change = True
continue
elif min_density <= density <= max_density and \
nodule_point_num == check_num and \
nodule_point_num_1 > 3:
# 周围的点密度在密度范围内,mask置于1
mask[point[0], point[1]] = 1
change = True
continue
# 判断该点是否可以扩展
new_error_thred = error_thred
vote_num = is_extend_point_density(data, mask, z, point, 1,
error_thred=new_error_thred, ref_mask=ref_mask) + \
is_extend_point_density(data, mask, z, point, 2,
error_thred=new_error_thred, ref_mask=ref_mask)
if vote_num >= 2:
# 该点可以扩展,mask置于1
mask[point[0], point[1]] = 1
change = True
continue
if not change:
break
def blank_outside_edge_3d(mask, times=1):
for z in range(len(mask)):
if np.sum(mask[z] == 1) == 0:
continue
blank_outside_edge_2d(mask[z], times)
def blank_outside_edge_2d(mask, times=1):
for _ in range(times):
_, contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
for contour in contours:
around_points = get_contour_points(contour)
if len(around_points) > 0:
mask[around_points[:, 0], around_points[:, 1]] = 0
def extend_edge_point_32(data, mask, z,
average_density, min_density, max_density,
edge_mask=None, shrink_mask=None, error_thred=0):
"""
    Shrink and extend edge points.
"""
mask_point_count = np.sum(mask == 1)
if mask_point_count < 8 ** 2:
return
if edge_mask is None:
edge_mask = mask.copy()
times = 1
blank_outside_edge_2d(mask, times)
for i in range(times + 1):
change = False
_, contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
for contour in contours:
# 获取相关的点,只处理mask为0的点
around_points = batch_get_around_points(mask, contour, increase_points1, select_value=0)
if len(around_points) == 0:
continue
around_points = np.unique(around_points, axis=0)
for point in around_points:
if mask[point[0], point[1]] == 1:
continue
if edge_mask is not None and edge_mask[point[0], point[1]] == 0:
continue
check_mask = get_region_data(mask, point, 1)
if np.sum(check_mask == 1) <= 2:
continue
if shrink_mask is not None and shrink_mask[point[0], point[1]] == 1:
# 该点在shrink_mask内,mask置于1
mask[point[0], point[1]] = 1
continue
density = data[point[0], point[1]]
if min_density + density_allow_error_range <= density <= max_density - density_allow_error_range:
# 该点密度在密度范围内,mask置于1
mask[point[0], point[1]] = 1
change = True
continue
# 判断该点是否可以扩展
new_error_thred = error_thred
vote_num = is_extend_point_density(data, mask, z, point, 1,
error_thred=new_error_thred) + \
is_extend_point_density(data, mask, z, point, 2,
error_thred=new_error_thred)
if vote_num >= 2:
# 该点可以扩展,mask置于1
mask[point[0], point[1]] = 1
change = True
continue
if not change:
break
def batch_extend_edge_2d(data, mask, z, index,
average_density, min_density, max_density,
big_mask, ref_mask, shrink_mask,
amend_pixel, trachea_pixel, min_pixel, is_head_tail=False):
"""
    Batch-extend the points that belong to the nodule (single 2D slice).
"""
# 获取最大扩展区块
max_mask = mask.copy()
extend_edge_point_11(data, max_mask, z,
average_density, min_density, max_density,
amend_pixel)
# 删除小区块
max_mask = remove_small_objects_2d(max_mask, z, min_pixel=min_pixel)
# 删除新扩展的独立区块
max_mask = remove_new_extend_area_2d(max_mask, shrink_mask, z)
# 检查形态,删除长宽比异常的区块
max_mask = remove_trachea_2d(max_mask, z, trachea_pixel, retain_first=not is_head_tail)
if np.sum(max_mask == 1) == 0:
return index, max_mask
# 获取核mask
ret, kernel_mask = get_erode_mask_2d(max_mask, trachea_pixel)
if ret:
# 填充在误差范围内凸集中的点
fill_error_range_point(data, kernel_mask, z, min_density, error_density=density_allow_error_range,
edge_mask=big_mask)
max_mask[kernel_mask == 1] = 1
# 多扩展几圈计算边缘
new_mask = max_mask.copy()
extend_edge_point_12(data, new_mask, z,
average_density, min_density, max_density,
2)
# 删除新扩展的独立区块
new_mask = remove_new_extend_area_2d(new_mask, shrink_mask, z)
# 获取所有扩展的点,并判断扩展的点是否为边缘点
total_mask = mask.copy()
extend_mask = new_mask - total_mask
extend_mask[np.logical_and(data > min_density + density_reject_error_range,
data < max_density - density_reject_error_range)] = 0
extend_edge_point_21(data, total_mask, z,
average_density, min_density, max_density,
extend_mask)
# 填充边缘内的所有点
fill_2d(total_mask, z)
# 获取最大扩展区块与边缘内区块的交集
edge_mask = np.bitwise_and(max_mask, total_mask)
if is_head_tail and not np.all(max_mask == edge_mask):
# 没有找到边界
mask[...] = 0
shrink_mask[...] = 0
return index, mask
# 真正扩展为结节的点
extend_edge_point_31(data, mask, z,
average_density, min_density, max_density,
amend_pixel, edge_mask=edge_mask, ref_mask=ref_mask)
# 获取核mask
ret, kernel_mask = get_erode_mask_2d(mask, trachea_pixel)
if ret:
# 填充在误差范围内凸集中的点
fill_error_range_point(data, kernel_mask, z, min_density, error_density=density_allow_error_range,
edge_mask=big_mask)
mask[kernel_mask == 1] = 1
# 再次扩展为结节的点
extend_edge_point_31(data, mask, z,
average_density, min_density, max_density,
2, edge_mask=edge_mask)
# 收缩边缘点,该点在它的周围选中的点中为最大密度的点
shrink_edge_point_11(data, mask, z,
average_density, min_density, max_density,
2)
# 收缩与扩展边缘点
extend_edge_point_32(data, mask, z,
average_density, min_density, max_density,
shrink_mask=shrink_mask, error_thred=0)
# 获取最小mask
ret, min_mask = get_erode_dilate_mask_2d(mask, trachea_pixel)
# 填充mask边缘
mask = fill_min_mask_2d(mask, min_mask)
mask[shrink_mask == 1] = 1
# 删除小区块
mask = remove_small_objects_2d(mask, z, min_pixel=min_pixel)
# 删除新扩展的独立区块
mask = remove_new_extend_area_2d(mask, shrink_mask, z)
# 检查区块,删除比率小的区块
mask = remove_small_area_2d(mask, z)
# 填充mask
fill_2d(mask, z)
return index, mask
def analysis_edge_hu_threshold(nodule_data, file_prefix='', log=False):
average_density = int(np.round(np.mean(nodule_data)))
min_density = int(np.min(nodule_data))
max_density = int(np.max(nodule_data))
left_data = nodule_data[nodule_data <= average_density]
left_point_count = len(left_data)
sorted_left_data = left_data[np.argsort(left_data)]
max_threshold_index = max(int(0.3 * left_point_count), 20 if left_point_count > 20 else left_point_count - 1)
min_threshold = min_density
max_threshold = sorted_left_data[max_threshold_index]
left_average_space = np.round((np.max(left_data) - np.min(left_data)) / (left_point_count + 1), 5)
left_max_space = max(4 * left_average_space, 10 if left_average_space > 0.5 else 5)
if left_average_space < 0.05:
left_max_space = 50 * left_average_space
elif left_average_space < 0.2:
left_max_space = 20 * left_average_space
left_max_space = np.round(left_max_space, 5)
left_values, left_counts = np.unique(left_data, return_counts=True)
left_spaces = (left_values[1:] - left_values[:-1]) / left_counts[1:]
left_coords = np.where(left_spaces >= left_max_space)[0]
if log:
print(file_prefix, 'left_min_density =', np.min(left_data),
'left_max_density =', np.max(left_data),
'left_average_space =', left_average_space,
'left_max_space =', left_max_space,
'left_point_count =', left_point_count)
for left_min_index in left_coords:
left_min_value = left_values[left_min_index + 1]
left_remove_num = np.sum(left_data < left_min_value)
if left_remove_num < max(0.1 * left_point_count, 10):
min_density = left_min_value
if log:
print(file_prefix, 'left_min_index =', left_min_index,
'left_space =', np.round(left_spaces[left_min_index], 5),
'left_min_value =', left_min_value,
'left_remove_num =', left_remove_num)
if left_remove_num < max(0.05 * left_point_count, 5):
min_threshold = min_density
if max_threshold < min_threshold:
max_threshold = min_threshold
if log:
print(file_prefix, 'min_threshold =', min_threshold, 'max_threshold =', max_threshold)
right_data = nodule_data[nodule_data >= average_density]
right_point_count = len(right_data)
right_average_space = np.round((np.max(right_data) - np.min(right_data)) / (right_point_count + 1), 5)
right_max_space = max(4 * right_average_space, 10 if right_average_space > 0.5 else 5)
if right_average_space < 0.05:
right_max_space = 50 * right_average_space
elif right_average_space < 0.2:
right_max_space = 20 * right_average_space
right_max_space = np.round(right_max_space, 5)
right_values, right_counts = np.unique(right_data, return_counts=True)
right_spaces = (right_values[1:] - right_values[:-1]) / right_counts[1:]
right_coords = np.where(right_spaces >= right_max_space)[0]
if log:
print(file_prefix, 'right_min_density =', np.min(right_data),
'right_max_density =', np.max(right_data),
'right_average_space =', right_average_space,
'right_max_space =', right_max_space,
'right_point_count =', right_point_count)
for right_max_index in right_coords:
right_max_value = right_values[right_max_index]
right_remove_num = np.sum(right_data > right_max_value)
if right_remove_num < max(0.1 * right_point_count, 10):
max_density = right_max_value
if log:
print(file_prefix, 'right_max_index =', right_max_index,
'right_space =', np.round(right_spaces[right_max_index], 5),
'right_max_value =', right_max_value,
'right_remove_num =', right_remove_num)
break
return min_density, max_density, min_threshold, max_threshold
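# Illustrative usage sketch, not part of the original module: it builds a synthetic,
# roughly bimodal HU sample (cluster centers, sizes and the random seed are assumptions
# chosen only for demonstration) and runs analysis_edge_hu_threshold on it.
def _demo_analysis_edge_hu_threshold():
    rng = np.random.RandomState(0)
    nodule_data = np.concatenate([
        rng.normal(-650, 40, 300),   # background lung tissue around -650 HU
        rng.normal(-100, 30, 500),   # denser nodule tissue around -100 HU
    ]).astype(np.int32)
    # Returns (min_density, max_density, min_threshold, max_threshold).
    return analysis_edge_hu_threshold(nodule_data, file_prefix='demo', log=True)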
def calculate_edge_hu_threshold(data, mask, otsu=True, hu_max=2000, file_prefix='', log=False):
# Get the nodule data
nodule_data = filter_data(data[mask == 1])
if len(nodule_data) == 0:
return hu_max, hu_max, hu_max
min_density, max_density, min_threshold, max_threshold = \
analysis_edge_hu_threshold(nodule_data, file_prefix=file_prefix, log=log)
nodule_data = nodule_data[nodule_data >= min_density]
nodule_data = nodule_data[nodule_data <= max_density]
if len(nodule_data) == 0:
return hu_max, hu_max, hu_max
average_density = int(np.round(np.mean(nodule_data)))
if otsu:
otsu_min_density = otsu_edge_hu_threshold(data, mask,
average_density, min_threshold, max_threshold,
file_prefix=file_prefix, log=log)
logging.info('{}, {}, otsu min_density = {}'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), file_prefix, otsu_min_density))
if otsu_min_density != min_density:
if otsu_min_density < min_density:
# Re-fetch the nodule data
nodule_data = filter_data(data[mask == 1])
min_density = otsu_min_density
nodule_data = nodule_data[nodule_data >= min_density]
nodule_data = nodule_data[nodule_data <= max_density]
if len(nodule_data) == 0:
return hu_max, hu_max, hu_max
average_density = int(np.round(np.mean(nodule_data)))
if log:
logging.info('{}, {}, average_density = {} '
'min_density = {} max_density = {} '
'min_threshold = {} max_threshold = {}'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), file_prefix, average_density,
min_density, max_density,
min_threshold, max_threshold))
return average_density, min_density, max_density
def otsu_edge_hu_threshold(data, mask,
average_density, min_threshold, max_threshold,
hu_min=-1000, hu_max=2000, file_prefix='', log=False):
new_masks = np.zeros((4, mask.shape[0], mask.shape[1], mask.shape[2]), np.uint8)
for z in range(len(mask)):
if np.all(mask[z] == 0):
continue
_, contours, _ = cv2.findContours(mask[z], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
if len(contours) > 1:
contours = sorted(contours, key=cv2.contourArea, reverse=True)
for contour_index, contour in enumerate(contours):
if contour_index > 0:
break
(x, y), radius = cv2.minEnclosingCircle(contour)
(x, y, radius) = np.int0((x, y, radius))
if radius >= 15:
increase_length = 3
elif radius >= 12:
increase_length = 2
else:
increase_length = 1
for new_mask_index, new_mask in enumerate(new_masks):
increase_points = calculate_increase_points(new_mask_index + increase_length)
around_points = batch_get_around_points(mask[z], contour, increase_points)
new_mask[z][around_points[:, 0], around_points[:, 1]] = 1
# Points belonging to the original nodule
original_mask = mask.copy()
original_mask[data <= hu_min] = 0
original_mask[data >= hu_max] = 0
original_mask[data >= average_density] = 0
nodule_datas = []
for new_mask in new_masks:
# Points on the nodule edge and in its surroundings
new_mask[data <= hu_min] = 0
new_mask[data >= hu_max] = 0
new_mask[data >= average_density] = 0
# Points on the nodule edge
is_nodule_mask = np.bitwise_and(new_mask, original_mask)
# Points surrounding the nodule
no_nodule_mask = new_mask - is_nodule_mask
no_nodule_mask[data >= min_threshold] = 0
no_nodule_mask[data >= max_threshold] = 0
# Pad the edge samples and the surrounding samples to the same count so the histogram shows two peaks
is_nodule_data = data[is_nodule_mask == 1]
no_nodule_data = data[no_nodule_mask == 1]
is_nodule_point_count = len(is_nodule_data)
no_nodule_point_count = len(no_nodule_data)
if is_nodule_point_count < no_nodule_point_count:
no_nodule_data = no_nodule_data[0:is_nodule_point_count]
no_nodule_point_count = len(no_nodule_data)
if is_nodule_point_count <= 0 or no_nodule_point_count <= 0:
continue
is_nodule_density = int(np.round(np.mean(is_nodule_data)))
no_nodule_density = int(np.round(np.mean(no_nodule_data)))
if is_nodule_point_count > no_nodule_point_count:
add_data = [no_nodule_density] * (is_nodule_point_count - no_nodule_point_count)
no_nodule_data = np.concatenate((no_nodule_data, add_data))
elif is_nodule_point_count < no_nodule_point_count:
add_data = [is_nodule_density] * (no_nodule_point_count - is_nodule_point_count)
is_nodule_data = np.concatenate((is_nodule_data, add_data))
add_count = max(int(0.1 * (is_nodule_point_count + no_nodule_point_count)), 20)
add_data = [no_nodule_density] * add_count
no_nodule_data = np.concatenate((no_nodule_data, add_data))
add_data = [is_nodule_density] * add_count
is_nodule_data = np.concatenate((is_nodule_data, add_data))
nodule_data = np.concatenate((is_nodule_data, no_nodule_data))
nodule_datas.append(nodule_data)
if len(nodule_datas) == 0:
return min_threshold
hu_thresholds = []
for n in range(1, len(nodule_datas) + 1):
for tup in itertools.combinations(list(range(len(nodule_datas))), n):
nodule_data = np.empty(0, np.float32)
for index in tup:
nodule_data = np.concatenate((nodule_data, nodule_datas[index]))
nodule_data = nodule_data + np.abs(hu_min)
nodule_data = nodule_data.astype(np.uint)
hu_threshold = mahotas.thresholding.otsu(nodule_data) - np.abs(hu_min)
hu_thresholds.append(hu_threshold)
hu_thresholds = np.sort(np.array(hu_thresholds))
min_density = get_hu_threshold(hu_thresholds, min_threshold, max_threshold)
if log:
logging.info('{}, {}, min_density = {} {}'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), file_prefix, min_density, hu_thresholds))
return min_density
def get_hu_threshold(hu_thresholds, min_threshold, max_threshold):
hu_thresholds = hu_thresholds.copy()
hu_thresholds = hu_thresholds[hu_thresholds >= (min_threshold - 50)]
if len(hu_thresholds) == 0:
return min_threshold
hu_thresholds = hu_thresholds[hu_thresholds <= max_threshold]
if len(hu_thresholds) == 0:
return max_threshold
for ratio in [5, 4, 3]:
hu_thresholds_exclude = exclude_exception_hu(hu_thresholds, ratio=ratio)
if len(hu_thresholds_exclude) > 0:
hu_thresholds = hu_thresholds_exclude
else:
break
values, counts = np.unique(hu_thresholds, return_counts=True)
max_count = max(counts)
if max_count > 1:
# hu_threshold = values[np.argmax(counts)]
hu_min_threshold = np.min(values[counts == max_count])
hu_max_threshold = np.max(values[counts == max_count])
hu_thresholds = hu_thresholds[hu_thresholds >= hu_min_threshold]
hu_thresholds = hu_thresholds[hu_thresholds <= hu_max_threshold]
hu_threshold = int(np.round(np.mean(hu_thresholds)))
else:
hu_threshold = int(np.round(np.mean(hu_thresholds)))
return hu_threshold
def exclude_exception_hu(hu_thresholds, ratio=2):
hu_thresholds = hu_thresholds.copy()
hu_average_value = np.mean(hu_thresholds)
hu_average_space = np.round((np.max(hu_thresholds) - np.min(hu_thresholds)) / (len(hu_thresholds) + 1), 5)
hu_max_space = max(ratio * hu_average_space, 5)
hu_min_threshold = hu_thresholds[0]
hu_max_threshold = hu_thresholds[-1]
values, counts = np.unique(hu_thresholds, return_counts=True)
hu_spaces = (values[1:] - values[:-1]) / counts[1:]
for i in range(len(hu_spaces)):
if hu_spaces[i] >= hu_max_space:
if values[i + 1] < hu_average_value:
hu_min_threshold = values[i + 1]
elif hu_average_value < values[i] < hu_max_threshold:
hu_max_threshold = values[i]
hu_thresholds = hu_thresholds[hu_thresholds >= hu_min_threshold]
hu_thresholds = hu_thresholds[hu_thresholds <= hu_max_threshold]
return hu_thresholds
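# Illustrative sketch, not part of the original module: the candidate thresholds below
# are made up. The -500 outlier lies above max_threshold and is dropped, the remaining
# candidates survive exclude_exception_hu, and their rounded mean is returned.
def _demo_get_hu_threshold():
    candidates = np.sort(np.array([-720, -715, -710, -708, -705, -500]))
    return get_hu_threshold(candidates, min_threshold=-760, max_threshold=-650)   # -> -712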
def set_nodule_outside_mask(data, mask, check_density=True, average_density=None, ratio=0.5, min_length=5):
"""
Mark the lung tissue outside the nodule: nodule voxels keep mask value 1, outside lung tissue is set to mask value 99
"""
if average_density is None:
# Get the nodule data
nodule_data = data[mask == 1]
if len(nodule_data) == 0:
return
# Compute the average density
average_density = int(np.round(np.mean(nodule_data)))
for z in range(len(mask)):
if np.sum(mask[z] == 1) == 0:
continue
_, contours, _ = cv2.findContours(mask[z], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
if len(contours) > 1:
contours = sorted(contours, key=cv2.contourArea, reverse=True)
for contour_index, contour in enumerate(contours):
if len(contours) > 1:
# Draw the contour to obtain all points it encloses
check_mask = np.zeros(mask[z].shape, np.uint8)
cv2.drawContours(check_mask, [contour], -1, 1, -1)
else:
check_mask = mask[z]
nodule_point_num = np.sum(check_mask == 1)
(_, _), radius = cv2.minEnclosingCircle(contour)
for increase_length in range(max(int(np.ceil(ratio * radius)), min_length)):
increase_points = calculate_increase_points(increase_length)
around_points = batch_get_around_points(check_mask, contour, increase_points, select_value=0)
around_points1 = [point for point in around_points
if check_mask[point[0], point[1]] == 0 and
(not check_density or data[z][point[0], point[1]] < average_density)]
if len(around_points1) > 0:
around_points1 = np.array(around_points1)
mask[z][around_points1[:, 0], around_points1[:, 1]] = 99
check_mask[around_points1[:, 0], around_points1[:, 1]] = 99
outside_point_num = np.sum(check_mask == 99)
if (outside_point_num > nodule_point_num) or \
(increase_length > min_length and outside_point_num > 0.7 * nodule_point_num):
break
def get_nodule_peak(nodule_data, kernel, hu_range=25):
"""
Get the nodule peak HU value and the minimum / maximum HU of the peak
"""
values, counts = np.unique(nodule_data, return_counts=True)
max_values = values[counts == max(counts)]
if len(max_values) > 1:
max_total_hu_value = []
for max_value in max_values:
temp_data = np.logical_and(nodule_data >= max_value - hu_range, nodule_data <= max_value + hu_range)
max_total_hu_value.append(np.sum(temp_data))
max_index = np.argmax(np.array(max_total_hu_value))
peak_hu_value = max_values[max_index]
else:
peak_index = np.argmax(counts)
peak_hu_value = values[peak_index]
min_hu_value = peak_hu_value
max_hu_value = peak_hu_value
counts = ndimage.convolve(counts, kernel, mode='nearest')
counts = counts.astype(np.float32)
peak_index = np.argmax(values == peak_hu_value)
for i in range(peak_index - 1, -1, -1):
if counts[i] >= counts[peak_index] / 2:
min_hu_value = values[i]
else:
break
for i in range(peak_index + 1, len(values), 1):
if counts[i] >= counts[peak_index] / 2:
max_hu_value = values[i]
else:
break
return peak_hu_value, min_hu_value, max_hu_value
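# Illustrative sketch, not part of the original module: the unimodal synthetic sample,
# the flat 11-tap smoothing kernel and the random seed are assumptions; the call returns
# the peak HU value plus the HU range whose smoothed counts stay above half the peak count.
def _demo_get_nodule_peak():
    rng = np.random.RandomState(0)
    nodule_data = rng.normal(-120, 25, 2000).astype(np.int32)
    kernel = np.ones(11, np.float32) / 11.0
    return get_nodule_peak(nodule_data, kernel)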
def get_nodule_peak_info(data, mask):
"""
Get nodule peak information (peak HU value, peak HU count, minimum HU, maximum HU, ratio)
"""
nodule_data = data[mask == 1]
total_point_num = len(nodule_data)
if total_point_num == 0:
return None
# Filter out invalid HU values
nodule_data = filter_data(nodule_data)
if len(nodule_data) == 0:
nodule_data = data[mask == 1]
kernel = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
kernel = kernel / np.sum(kernel)
# Compute the average density
average_density = int(np.round(np.mean(nodule_data)))
peak_hu_value, min_hu_value, max_hu_value = get_nodule_peak(nodule_data, kernel)
if average_density - peak_hu_value > 100 and \
np.abs(2 * peak_hu_value - min_hu_value - max_hu_value) > 50:
nodule_data_copy = nodule_data.copy()
nodule_data_copy = nodule_data_copy[nodule_data_copy >= average_density]
peak_hu_value2, min_hu_value2, max_hu_value2 = get_nodule_peak(nodule_data_copy, kernel)
nodule_data_copy = nodule_data.copy()
nodule_data_copy = nodule_data_copy[nodule_data_copy >= peak_hu_value]
nodule_data_copy = nodule_data_copy[nodule_data_copy <= peak_hu_value2]
values, counts = np.unique(nodule_data_copy, return_counts=True)
counts = counts.astype(np.float32)
counts = ndimage.convolve(counts, kernel, mode='nearest')
min_values = values[counts == min(counts)]
trough_hu_value = min_values[-1]
nodule_data_copy = nodule_data.copy()
nodule_data_copy = nodule_data_copy[nodule_data_copy >= trough_hu_value]
peak_hu_value2, min_hu_value2, max_hu_value2 = get_nodule_peak(nodule_data_copy, kernel)
if peak_hu_value2 - min_hu_value2 > 50:
peak_hu_value, min_hu_value, max_hu_value = peak_hu_value2, min_hu_value2, max_hu_value2
nodule_data_copy = nodule_data.copy()
nodule_data_copy = nodule_data_copy[nodule_data_copy >= min_hu_value]
nodule_data_copy = nodule_data_copy[nodule_data_copy <= max_hu_value]
current_point_num = len(nodule_data_copy)
ratio = np.round(current_point_num / total_point_num, 5)
peak_hu_count = np.sum(nodule_data == peak_hu_value)
peak_info = (peak_hu_value, peak_hu_count, min_hu_value, max_hu_value, ratio)
return peak_info
# -*- coding: utf-8 -*-
import random
import itertools
import numpy as np
def random_permute_key():
"""
Generates and randomly selects a permute key.
"""
return random.choice(list(generate_general_permute_keys()))
def generate_permute_keys():
"""
This function returns a set of "keys" that represent the 48 unique rotations &
reflections of a 3D matrix.
Each item of the set is a tuple:
((rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose)
"""
return set(itertools.product(
itertools.combinations_with_replacement(range(2), 2), range(2), range(2), range(2), range(2)))
# def generate_general_indexes():
# indexes = [[0, 1, 2, 3, 4, 5, 6, 7],
# [1, 2, 3, 0, 5, 6, 7, 4],
# [2, 3, 0, 1, 6, 7, 4, 5],
# [3, 0, 1, 2, 7, 4, 5, 6],
# [4, 5, 6, 7, 0, 1, 2, 3],
# [5, 6, 7, 4, 1, 2, 3, 0],
# [6, 7, 4, 5, 2, 3, 0, 1],
# [7, 4, 5, 6, 3, 0, 1, 2]]
# return indexes
def generate_general_indexes():
indexes = [[0, 2],
[1, 3],
[2, 0],
[3, 1],
[4, 6],
[5, 7],
[6, 4],
[7, 5]]
return indexes
def generate_general_permute_keys():
rotate_y = 0
transpose = 0
keys = [((rotate_y, 0), 0, 0, 0, transpose),
((rotate_y, 1), 0, 0, 0, transpose),
((rotate_y, 0), 0, 1, 1, transpose),
((rotate_y, 1), 0, 1, 1, transpose),
((rotate_y, 0), 1, 0, 0, transpose),
((rotate_y, 1), 1, 0, 0, transpose),
((rotate_y, 0), 1, 1, 1, transpose),
((rotate_y, 1), 1, 1, 1, transpose)]
return keys
def permute_data(data, key):
if data is None or key == ((0, 0), 0, 0, 0, 0):
return data
data = np.copy(data)
(rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose = key
if rotate_y != 0:
data = np.rot90(data, rotate_y, axes=(0, 2))
if rotate_z != 0:
data = np.rot90(data, rotate_z, axes=(1, 2))
if flip_z:
data = np.flip(data, axis=0)
if flip_y:
data = np.flip(data, axis=1)
if flip_x:
data = np.flip(data, axis=2)
if transpose:
add_axes = [i for i in range(3, len(data.shape))]
data = np.transpose(data, axes=[2, 1, 0] + add_axes)
return data
def reverse_permute_data(data, key):
if data is None or key == ((0, 0), 0, 0, 0, 0):
return data
data = np.copy(data)
key = reverse_permute_key(key)
(rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose = key
if transpose:
add_axes = [i for i in range(3, len(data.shape))]
data = np.transpose(data, axes=[2, 1, 0] + add_axes)
if flip_x:
data = np.flip(data, axis=2)
if flip_y:
data = np.flip(data, axis=1)
if flip_z:
data = np.flip(data, axis=0)
if rotate_z != 0:
data = np.rot90(data, rotate_z, axes=(1, 2))
if rotate_y != 0:
data = np.rot90(data, rotate_y, axes=(0, 2))
return data
def reverse_permute_key(key):
rotation = tuple([-rotate for rotate in key[0]])
return rotation, key[1], key[2], key[3], key[4]
def augment(data, targets=None, mask=None, weight=None):
key = random_permute_key()
(rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose = key
rotate_y = 0
transpose = 0
key = (rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose
aug_data = permute_data(data, key)
aug_targets = targets
if targets is not None:
for i in range(len(aug_targets)):
aug_targets[i] = permute_data(aug_targets[i], key)
transform_target(aug_targets[i], key)
aug_mask = permute_data(mask, key)
aug_weight = permute_data(weight, key)
return aug_data, aug_targets, aug_mask, aug_weight
def transform_target(target, key):
(rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose = key
if rotate_y:
z = target[..., 1].copy()
x = target[..., 3].copy()
target[..., 1] = -1.0 * x
target[..., 3] = z
if rotate_z:
y = target[..., 2].copy()
x = target[..., 3].copy()
target[..., 2] = -1.0 * x
target[..., 3] = y
if flip_z:
target[..., 1] = -1.0 * target[..., 1]
if flip_y:
target[..., 2] = -1.0 * target[..., 2]
if flip_x:
target[..., 3] = -1.0 * target[..., 3]
if transpose:
z = target[..., 1].copy()
x = target[..., 3].copy()
target[..., 1] = x
target[..., 3] = z
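# Illustrative sketch, not part of the original module: it assumes the last axis of a
# target holds (score, dz, dy, dx); flipping along x only negates the dx component.
def _demo_transform_target_flip_x():
    target = np.array([[1.0, 0.2, -0.3, 0.5]], np.float32)
    key = ((0, 0), 0, 0, 1, 0)   # flip_x only
    transform_target(target, key)
    return target   # -> [[1.0, 0.2, -0.3, -0.5]]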
def data_test():
data = np.arange(40).reshape((2, 2, 2, 1, 5))
print('data', data)
history = []
for key in list(generate_general_permute_keys()):
print('key', key)
new_data = permute_data(data, key)
# print('new_data', new_data)
for old_data in history:
if np.all(old_data == new_data):
print('new_data is exist!!!')
break
history.append(new_data)
reverse_data = reverse_permute_data(new_data, key)
if not np.all(data == reverse_data):
print('error:', reverse_data)
def data_test2():
data = np.arange(60).reshape((3, 2, 2, 1, 5))
for i in range(10000):
key = random_permute_key()
print('key', key)
new_data = permute_data(data, key)
if new_data.shape != data.shape:
print('error:', new_data.shape)
def data_test3():
data = np.arange(40).reshape((2, 2, 2, 1, 5))
print('data', data)
key = random_permute_key()
(rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose = key
rotate_y = 0
rotate_z = 1
flip_z = 0
flip_y = 0
flip_x = 0
transpose = 0
key = (rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose
new_data = permute_data(data, key)
transform_target(new_data, key)
print('new_data', new_data)
def data_test4():
data = np.arange(27).reshape((3, 3, 3))
print('data', data)
history = []
for key in list(generate_general_permute_keys()):
print('key', key)
new_data = permute_data(data, key)
print('new_data', new_data)
for old_data in history:
if np.all(old_data == new_data):
print('new_data is exist!!!')
break
history.append(new_data)
reverse_data = reverse_permute_data(new_data, key)
if not np.all(data == reverse_data):
print('error:', reverse_data)
if __name__ == '__main__':
data_test2()
# -*- coding: utf-8 -*-
import numpy as np
from collections import namedtuple
NoduleBox = namedtuple('NoduleBox',
['z', 'y', 'x', 'diameter', 'is_pos', 'probability'])
def nodule_raw2standard(nodule_boxes, raw_spacing, standard_spacing, start=(0, 0, 0)):
new_nodule_boxes = []
for nodule_box in nodule_boxes:
z, y, x, diameter, is_pos, probability = nodule_box
z = z * raw_spacing[0] / standard_spacing[0] - start[0]
y = y * raw_spacing[1] / standard_spacing[1] - start[1]
x = x * raw_spacing[2] / standard_spacing[2] - start[2]
new_nodule_boxes.append(
NoduleBox(z, y, x, diameter, is_pos, probability))
return new_nodule_boxes
def nodule_standard2raw(nodule_boxes, standard_spacing, raw_spacing, start=(0, 0, 0)):
new_nodule_boxes = []
for nodule_box in nodule_boxes:
z, y, x, diameter, is_pos, probability = nodule_box
z = (z + start[0]) * standard_spacing[0] / raw_spacing[0]
y = (y + start[1]) * standard_spacing[1] / raw_spacing[1]
x = (x + start[2]) * standard_spacing[2] / raw_spacing[2]
new_nodule_boxes.append(
NoduleBox(z, y, x, diameter, is_pos, probability))
return new_nodule_boxes
def iou2d(box_a, box_b):
box_a = np.asarray(box_a)
box_b = np.asarray(box_b)
a_start, a_end = box_a[0:2] - box_a[2:4] / 2, box_a[0:2] + box_a[2:4] / 2
b_start, b_end = box_b[0:2] - box_b[2:4] / 2, box_b[0:2] + box_b[2:4] / 2
y_overlap = max(0, min(a_end[0], b_end[0]) - max(a_start[0], b_start[0]))
x_overlap = max(0, min(a_end[1], b_end[1]) - max(a_start[1], b_start[1]))
intersection = y_overlap * x_overlap
union = box_a[2] * box_a[3] + box_b[2] * box_b[3] - intersection
return 1.0 * intersection / union
def iou3d(box_a, box_b, spacing):
# The radius must be converted to pixels along each of the z, y, x axes
r_a = 0.5 * box_a.diameter / spacing
r_b = 0.5 * box_b.diameter / spacing
# starting index in each dimension
z_a_s, y_a_s, x_a_s = box_a.z - r_a[0], box_a.y - r_a[1], box_a.x - r_a[2]
z_b_s, y_b_s, x_b_s = box_b.z - r_b[0], box_b.y - r_b[1], box_b.x - r_b[2]
# ending index in each dimension
z_a_e, y_a_e, x_a_e = box_a.z + r_a[0], box_a.y + r_a[1], box_a.x + r_a[2]
z_b_e, y_b_e, x_b_e = box_b.z + r_b[0], box_b.y + r_b[1], box_b.x + r_b[2]
z_overlap = max(0, min(z_a_e, z_b_e) - max(z_a_s, z_b_s))
y_overlap = max(0, min(y_a_e, y_b_e) - max(y_a_s, y_b_s))
x_overlap = max(0, min(x_a_e, x_b_e) - max(x_a_s, x_b_s))
intersection = z_overlap * y_overlap * x_overlap
union = 8 * r_a[0] * r_a[1] * r_a[2] + 8 * r_b[0] * r_b[1] * r_b[2] - intersection
return 1.0 * intersection / union
def nms(nodule_boxes, spacing, iou_thred):
if nodule_boxes is None or len(nodule_boxes) == 0:
return nodule_boxes
nodule_boxes_nms = []
for nodule_box in nodule_boxes:
overlap = False
for box in nodule_boxes_nms:
if iou3d(box, nodule_box, spacing) >= iou_thred:
overlap = True
break
if not overlap:
nodule_boxes_nms.append(nodule_box)
return nodule_boxes_nms
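# Illustrative sketch, not part of the original module: spacing, diameters and the IoU
# threshold are assumptions. Boxes are assumed to be sorted by probability, so the
# second box (overlapping the first) is suppressed while the distant third box is kept.
def _demo_nms():
    spacing = np.array([1.0, 0.7, 0.7])
    boxes = [NoduleBox(30, 100, 100, 8.0, 1, 0.95),
             NoduleBox(30, 101, 101, 8.0, 1, 0.90),
             NoduleBox(60, 200, 200, 6.0, 1, 0.80)]
    return nms(boxes, spacing, iou_thred=0.1)   # -> first and third boxes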
# -*- coding: utf-8 -*-
import os
import cv2
import numpy as np
from scipy import ndimage
from skimage import morphology
import torch
import torch.nn.functional as F
edge_enhance_kernel = np.array([[-1, -1, -1],
[-1, 8, -1],
[-1, -1, -1]])
def check_and_makedirs(output_file, is_file=False):
if is_file:
output_dir = os.path.dirname(output_file)
else:
output_dir = output_file
if output_dir is not None and output_dir != '' and not os.path.exists(output_dir):
os.makedirs(output_dir)
def sigmoid(x):
# return 1 / (1 + np.exp(-x))
return 0.5 * (1 + np.tanh(0.5 * x))
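# Illustrative check, not part of the original module: the tanh form used above is
# algebraically identical to the standard logistic function 1 / (1 + exp(-x)).
def _demo_sigmoid_identity():
    xs = np.linspace(-5.0, 5.0, 11)
    assert np.allclose(sigmoid(xs), 1.0 / (1.0 + np.exp(-xs)))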
def get_crop_start_end(start, end, crop_size):
center = (np.asarray(start) + np.asarray(end)) // 2
crop_start = center - np.asarray(crop_size) // 2
crop_start = np.maximum(crop_start, 0)
return crop_start, crop_start + crop_size
def maxip(input, level_num=3):
level_num = level_num // 2
output = np.zeros(input.shape, input.dtype)
for i in range(len(input)):
length = level_num if i >= level_num else i
output[i] = np.max(input[i-length:i+level_num+1], axis=0)
return output
def minip(input, level_num=3):
level_num = level_num // 2
output = np.zeros(input.shape, input.dtype)
for i in range(len(input)):
length = level_num if i >= level_num else i
output[i] = np.min(input[i-length:i+level_num+1], axis=0)
return output
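# Illustrative sketch, not part of the original module: with level_num=3 each output
# slice is the voxel-wise maximum over itself and its immediate neighboring slices.
def _demo_maxip():
    volume = np.arange(12).reshape((3, 2, 2))
    # output[0] covers slices 0-1, output[1] covers slices 0-2, output[2] covers slices 1-2
    return maxip(volume, level_num=3)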
def downsample_data(data, scale, mode='max'):
"""
Downsample the volume by the given scale using 3D max or average pooling
"""
if data is None or scale == (1, 1, 1):
return data
new_data = torch.from_numpy(data).float().unsqueeze(0).unsqueeze(0)
if mode == 'max':
new_data = F.max_pool3d(new_data, kernel_size=scale, stride=scale, ceil_mode=True)
else:
new_data = F.avg_pool3d(new_data, kernel_size=scale, stride=scale, ceil_mode=True)
new_data = new_data[0][0].numpy()
return new_data
def upsample_data(data, scale, mode='nearest'):
"""
Upsample the volume by the given scale factor
The modes available for resizing are: `nearest`, `linear` (3D-only),
`bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area`
"""
if data is None or scale == (1, 1, 1):
return data
new_data = torch.from_numpy(data).float().unsqueeze(0).unsqueeze(0)
new_data = F.interpolate(new_data, scale_factor=scale, mode=mode)
new_data = new_data[0][0].numpy()
return new_data
def downsample_mask(mask, scale, prob_thred=0.5):
if mask is None or scale == (1, 1, 1):
return mask
new_mask = downsample_data(mask, scale, mode='avg')
new_mask = new_mask > prob_thred
new_mask = new_mask.astype(np.uint8)
return new_mask
def upsample_mask(mask, scale):
if mask is None or scale == (1, 1, 1):
return mask
new_mask = upsample_data(mask, scale)
new_mask = new_mask.astype(np.uint8)
return new_mask
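# Illustrative sketch, not part of the original module: a small binary mask is
# average-pooled down by 2x per axis (a block is kept where its mean exceeds prob_thred)
# and then restored to the original resolution with nearest-neighbour upsampling.
def _demo_mask_scaling():
    mask = np.zeros((4, 4, 4), np.uint8)
    mask[1:3, 1:3, 1:3] = 1
    small = downsample_mask(mask, scale=(2, 2, 2), prob_thred=0.1)   # -> shape (2, 2, 2)
    return upsample_mask(small, scale=(2, 2, 2))                     # -> shape (4, 4, 4)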
def resample_data(data, spacing, new_spacing, mode='trilinear'):
if data is None or np.all(spacing == new_spacing):
return data, spacing
new_shape = np.round(data.shape * spacing / new_spacing)
new_spacing = spacing * data.shape / new_shape
scale = new_shape / data.shape
new_data = torch.from_numpy(data).float().unsqueeze(0).unsqueeze(0)
new_data = F.interpolate(new_data, scale_factor=scale, mode=mode, align_corners=False)
new_data = new_data[0][0].numpy()
return new_data, new_spacing
def resample_mask(mask, size, mode='trilinear', prob_thred=0.5):
if mask is None or np.all(mask.shape == size):
return mask
new_mask = torch.from_numpy(mask).float().unsqueeze(0).unsqueeze(0)
new_mask = F.interpolate(new_mask, size=size, mode=mode, align_corners=False)
new_mask = new_mask[0][0].numpy()
new_mask = new_mask > prob_thred
new_mask = new_mask.astype(np.uint8)
return new_mask
def extend_mask(mask, kernel_size=(1, 7, 7), padding=(0, 3, 3), min_value=0.01, max_value=1.0):
new_mask = torch.from_numpy(mask).float().unsqueeze(0).unsqueeze(0)
new_mask = F.avg_pool3d(new_mask, kernel_size=kernel_size, padding=padding, stride=1)
new_mask = new_mask.ge(min_value) * new_mask.le(max_value)
new_mask = new_mask[0][0].numpy()
return new_mask
def clip_data(data, min_value=-1000, max_value=2000, nan_value=-2000):
new_data = np.array(data)
new_data[np.isnan(new_data)] = nan_value
new_data = np.clip(new_data, min_value, max_value)
return new_data
def filter_data(data, min_value=-1000, max_value=2000, nan_value=-2000):
new_data = np.array(data)
new_data[np.isnan(new_data)] = nan_value
new_data = new_data[new_data >= min_value]
new_data = new_data[new_data <= max_value]
return new_data
def normalize_cls(data, min_value=-1000, max_value=600):
new_data = data.clone()
new_data[new_data < min_value] = min_value
new_data[new_data > max_value] = max_value
# normalize to [-1, 1]
new_data = 2.0 * (new_data - min_value) / (max_value - min_value) - 1
return new_data
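# Illustrative sketch, not part of the original module: HU values are clamped to the
# [-1000, 600] window and then mapped linearly onto [-1, 1].
def _demo_normalize_cls():
    hu = torch.tensor([-1200.0, -1000.0, -200.0, 600.0, 900.0])
    return normalize_cls(hu)   # -> tensor([-1., -1., 0., 1., 1.])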
def normalize_quantify_data(quantify_data):
quantify_data[:, 0] = normalize_cls(quantify_data[:, 0], 0, 5000)
quantify_data[:, 2] = normalize_cls(quantify_data[:, 2], 0, 1600)
quantify_data[:, 3] = normalize_cls(quantify_data[:, 3], 0, 10000)
for i in range(4, quantify_data.shape[1]):
quantify_data[:, i] = normalize_cls(quantify_data[:, i], 0, 1600)
return quantify_data
def get_quantify_data(data, mask, spacing):
nodule_data = filter_data(data[mask == 1])
mask_point_count = len(nodule_data)
if mask_point_count > 0:
# Volume of a single voxel
single_volume = spacing[0] * spacing[1] * spacing[2]
# Total volume
volume = np.round(mask_point_count * single_volume, 3)
# Average density; add 1000 to convert HU into density
average_density = np.mean(nodule_data) + 1000
average_density = np.maximum(np.round(average_density, 3), 0)
# Specific gravity
specific_gravity = np.round(1000 * average_density / volume, 3)
else:
volume = 0
average_density = 0
specific_gravity = 0
return volume, average_density, specific_gravity
def get_average_diameter(volume):
return np.round(pow((3 * volume / (4 * np.pi)), 1 / 3) * 2, 3)
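# Illustrative check, not part of the original module: a sphere of radius 3 has volume
# 4/3 * pi * 27 ~= 113.1, so the equivalent average diameter comes back as 6.0.
def _demo_get_average_diameter():
    volume = 4.0 / 3.0 * np.pi * 3.0 ** 3
    return get_average_diameter(volume)   # -> 6.0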
def get_3d_curve_data(data, mask, spacing, curve_step):
curve_data = np.zeros(curve_step, np.float32)
if np.sum(mask == 1) > 0:
# Volume of a single voxel
single_volume = spacing[0] * spacing[1] * spacing[2]
# Add 1000 to convert HU into density
nodule_data = data[mask == 1] + 1000
nodule_data = filter_data(nodule_data, min_value=0, max_value=curve_step - 1)
nodule_data = np.int0(nodule_data)
values, counts = np.unique(nodule_data, return_counts=True)
volumes = counts * single_volume
for value, volume in zip(values, volumes):
curve_data[value] = volume
return curve_data
def get_3d_volume_percentage(data, mask, max_value=600):
volume_percentage_list = []
if np.sum(mask == 1) > 0:
# Add 1000 to convert HU into density
nodule_data = data[mask == 1] + 1000
nodule_data = filter_data(nodule_data, min_value=0, max_value=max_value + 1000 - 1)
nodule_data = np.sort(nodule_data)
total_length = len(nodule_data)
for i in range(10, 4, -1):
length = max(int(0.1 * i * total_length) - 1, 0)
current_hu = nodule_data[length]
current_data = nodule_data[nodule_data <= current_hu]
volume_percentage_list.append(current_hu)
volume_percentage_list.append(np.round(np.mean(current_data), 3))
volume_percentage_list.append(np.round(current_data.std(), 3))
else:
for i in range(10, 4, -1):
volume_percentage_list.append(0)
volume_percentage_list.append(0)
volume_percentage_list.append(0)
return volume_percentage_list
def transform_input_data(data, n_channels):
if n_channels == 1:
data = data[np.newaxis]
elif n_channels == 4:
data_edge_enhance = np.zeros(data.shape, np.float32)
for z in range(data.shape[0]):
data_edge_enhance[z] = ndimage.convolve(data[z], edge_enhance_kernel, mode='nearest')
data1 = data
data2 = data
# Edge enhancement (weight 0.5)
data3 = data + 0.5 * data_edge_enhance
# Edge enhancement (weight 1.0)
data4 = data + 1.0 * data_edge_enhance
# Maximum intensity projection
# data5 = maxip(data)
# Minimum intensity projection
# data6 = minip(data)
data = np.array([data1, data2, data3, data4])
return data
def get_crop_data(data, crop_start, crop_size, mode='constant', constant_values=-1000):
"""
Get the cropped block of data, padding where the crop window falls outside the volume
"""
if data is None:
return data
data_shape = np.array(data.shape)
crop_data = data[
max(crop_start[0], 0):min(crop_start[0] + crop_size[0], data_shape[0]),
max(crop_start[1], 0):min(crop_start[1] + crop_size[1], data_shape[1]),
max(crop_start[2], 0):min(crop_start[2] + crop_size[2], data_shape[2])]
pad = np.zeros((3, 2), int)
pad[:, 0] = np.maximum(-crop_start, 0)
pad[:, 1] = np.maximum(crop_start + crop_size - data_shape, 0)
if not np.all(pad == 0):
if mode == 'edge':
crop_data = np.pad(crop_data, pad, mode=mode)
else:
crop_data = np.pad(crop_data, pad, mode=mode, constant_values=constant_values)
return crop_data
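# Illustrative sketch, not part of the original module: the crop window starts one voxel
# before the volume along z and runs one voxel past it along x, so the missing voxels
# are padded with the constant fill value (-1000 by default).
def _demo_get_crop_data():
    volume = np.arange(64).reshape((4, 4, 4))
    crop = get_crop_data(volume,
                         crop_start=np.array([-1, 0, 2]),
                         crop_size=np.array([3, 3, 3]))
    return crop   # shape (3, 3, 3)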
def get_anchor_coords(target_size, stride):
"""
Get the anchor center coordinates along the z, y, x axes
"""
offset = 0.5
z_center_anchors = np.arange(offset, offset + target_size[0]) * stride[0]
y_center_anchors = np.arange(offset, offset + target_size[1]) * stride[1]
x_center_anchors = np.arange(offset, offset + target_size[2]) * stride[2]
return z_center_anchors, y_center_anchors, x_center_anchors
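# Illustrative sketch, not part of the original module: for a 4x4x4 output grid with a
# stride of 8 voxels per axis, the anchor centers sit at 4, 12, 20 and 28 on every axis.
def _demo_get_anchor_coords():
    return get_anchor_coords(target_size=(4, 4, 4), stride=(8, 8, 8))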
def convex_hull_3d(mask):
for z in range(len(mask)):
if np.all(mask[z] == 0):
continue
mask[z] = convex_hull_2d(mask[z])
return mask
def convex_hull_2d(mask):
mask = morphology.convex_hull_object(mask, neighbors=8)
mask = mask.astype(np.uint8)
return mask
def continuous_2d(input_data):
output_data = input_data.copy()
for j in range(len(output_data)):
if j % 2 == 1:
output_data[j] = output_data[j][::-1]
return output_data
def var_2d(input_data, kernel_size=3, min_value=-100):
output_data = np.zeros(input_data.shape, np.float32)
output_data[...] = np.inf
kernel_size = kernel_size // 2
for i in range(kernel_size, output_data.shape[0] - kernel_size):
kernel_i = kernel_size if i >= kernel_size else i
for j in range(kernel_size, output_data.shape[1] - kernel_size):
if input_data[i, j] < min_value:
continue
kernel_j = kernel_size if j >= kernel_size else j
region_data = input_data[i - kernel_i:i + kernel_size + 1, j - kernel_j:j + kernel_size + 1]
output_data[i, j] = region_data.var()
return output_data
def std_2d(input_data, kernel_size=3, min_value=-100):
output_data = np.zeros(input_data.shape, np.float32)
output_data[...] = np.inf
kernel_size = kernel_size // 2
for i in range(kernel_size, output_data.shape[0] - kernel_size):
kernel_i = kernel_size if i >= kernel_size else i
for j in range(kernel_size, output_data.shape[1] - kernel_size):
if input_data[i, j] < min_value:
continue
kernel_j = kernel_size if j >= kernel_size else j
region_data = input_data[i - kernel_i:i + kernel_size + 1, j - kernel_j:j + kernel_size + 1]
output_data[i, j] = region_data.std()
return output_data
def get_diameter_pixel_length(diameter, spacing):
return np.int0(np.ceil(diameter / spacing))
def get_center_and_diameter(nodule_mask, spacing):
"""
Get the nodule center coordinates and diameter from the mask
"""
# Center coordinates
center = np.zeros(3, int)
# max_average_width_height = 0
# for z, mask in enumerate(nodule_mask):
# if np.sum(mask == 1) == 0:
# continue
# _, contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
# contours = sorted(contours, key=cv2.contourArea, reverse=True)
# # Minimum-area bounding rectangle
# rect_min = cv2.minAreaRect(contours[0])
# ((_, _), (width, height), angle) = rect_min
# # z coordinate of the slice with the largest area
# if max_average_width_height < (width + height) / 2:
# max_average_width_height = (width + height) / 2
# center[0] = z
# Number of mask points in each slice
z_point_count = np.sum(nodule_mask == 1, axis=(1, 2))
# Slice with the largest area
center[0] = np.argmax(z_point_count)
coords = np.asarray(np.where(nodule_mask[center[0]] == 1))
coord_start = coords.min(axis=1)
coord_end = coords.max(axis=1) + 1
# Center coordinates on the slice with the largest area
center[1:3] = (coord_start + coord_end) // 2
# Diameter, comparing only the y and x axes
diameter = np.round(np.max((coord_end - coord_start) * spacing[1:3]), 5)
return center, diameter
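# Illustrative sketch, not part of the original module: a 3-slice box mask with 1 mm
# spacing; the center lands on the first slice with the largest area and the diameter
# spans the longer in-plane side of the box.
def _demo_get_center_and_diameter():
    mask = np.zeros((5, 10, 10), np.uint8)
    mask[1:4, 2:8, 3:7] = 1
    return get_center_and_diameter(mask, spacing=np.array([1.0, 1.0, 1.0]))   # -> ([1, 5, 5], 6.0)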
def get_nodule_rect(nodule_box, spacing):
diameter_pixel_length = get_diameter_pixel_length(nodule_box.diameter, spacing)
center = np.array([nodule_box.z, nodule_box.y, nodule_box.x])
rect = np.zeros((3, 2), int)
rect[:, 0] = center - diameter_pixel_length // 2
rect[:, 1] = rect[:, 0] + diameter_pixel_length
return rect
def save_pred_npy(npy_path, file_prefix, pred):
check_and_makedirs(npy_path)
np.save(os.path.join(npy_path, file_prefix + '_pred.npy'),
pred.astype(np.uint8))
def load_pred_npy(npy_path, file_prefix):
pred = np.load(os.path.join(npy_path, file_prefix + '_pred.npy'))
return pred
def save_mask_npy(npy_path, file_prefix, mask):
check_and_makedirs(npy_path)
np.save(os.path.join(npy_path, file_prefix + '_mask.npy'),
mask.astype(np.uint8))
def load_mask_npy(npy_path, file_prefix):
mask = np.load(os.path.join(npy_path, file_prefix + '_mask.npy'))
return mask
def print_data_mask(data, mask, filename=None, group=False):
coords = np.asarray(np.where(mask > 0))
if len(coords[0]) > 0:
coord_start = coords.min(axis=1)
coord_end = coords.max(axis=1) + 1
print_mask = mask[coord_start[0]:coord_end[0], coord_start[1]:coord_end[1]]
print_data = data[coord_start[0]:coord_end[0], coord_start[1]:coord_end[1]]
if filename is not None:
with open(filename, 'w') as f:
for temp_data in print_data:
f.write(np.array2string(temp_data).replace('\n', '') + '\n')
for temp_mask in print_mask:
f.write(np.array2string(temp_mask).replace('\n', '') + '\n')
if group:
for temp_data, temp_mask in zip(print_data, print_mask):
temp_group = np.array([str(i) + '(' + str(j) + ')' for i, j in zip(temp_data, temp_mask)])
f.write(np.array2string(temp_group).replace('\n', '').replace('\'', '') + '\n')
else:
for temp_data in print_data:
print(np.array2string(temp_data).replace('\n', '') + '\n')
for temp_mask in print_mask:
print(np.array2string(temp_mask).replace('\n', '') + '\n')
if group:
for temp_data, temp_mask in zip(print_data, print_mask):
temp_group = np.array([str(i) + '(' + str(j) + ')' for i, j in zip(temp_data, temp_mask)])
print(np.array2string(temp_group).replace('\n', '').replace('\'', '') + '\n')
# -*- coding: utf-8 -*-
import os, sys
import pathlib
current_dir = pathlib.Path(__file__).parent.resolve()
while "cls_train" != current_dir.name:
current_dir = current_dir.parent
sys.path.append(current_dir.as_posix())
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import matplotlib.patches as patches
import seaborn as sns
import cv2
from data.data_process_utils.test_data_utils import check_and_makedirs
from data.data_process_utils.test_data_utils import get_crop_start_end, get_diameter_pixel_length, get_nodule_rect
from data.data_process_utils.test_data_utils import continuous_2d, var_2d
def plot_image_and_rect(data, rect=None,
z_start=0, stride=1, image_title=None,
cmap=plt.cm.gray, show=False, save=False, file_path=None):
vmin = -1350
vmax = 150
col = 1
row = max(data.shape[0], 2)
fig, plots = plt.subplots(row, col, figsize=(col * 5, row * 5))
for i in range(0, data.shape[0], stride):
z = z_start + i + 1
ax = plots[i]
ax.set_title('original z=' + str(z))
if rect is not None:
ax.set_title('ground truth z=' + str(z))
if rect[0][0] <= i < rect[0][1]:
rectangle = patches.Rectangle((rect[2][0], rect[1][0]),
rect[2][1] - rect[2][0],
rect[1][1] - rect[1][0],
linewidth=0.5, edgecolor='r', facecolor='none')
ax.add_patch(rectangle)
ax.imshow(data[i], cmap=cmap, vmin=vmin, vmax=vmax)
ax.axis('off')
if image_title is not None:
fig.suptitle(str(image_title), fontsize=18)
if save:
check_and_makedirs(file_path, is_file=True)
fig.savefig(file_path, dpi=90, bbox_inches='tight')
if not show:
plt.close()
if show:
plt.show()
def plot_two_image(data1, data2,
z_start=0, stride=1, image_title=None,
cmap=plt.cm.gray, show=False, save=False, file_path=None):
vmin = -1350
vmax = 150
col = 2
row = max(data1.shape[0], 2)
fig, plots = plt.subplots(row, col, figsize=(col * 5, row * 5))
for i in range(0, data1.shape[0], stride):
z = z_start + i + 1
ax = plots[i, 0]
ax.set_title('z=' + str(z))
ax.imshow(data1[i], cmap=cmap, vmin=vmin, vmax=vmax)
ax.axis('off')
ax = plots[i, 1]
ax.set_title('z=' + str(z))
ax.imshow(data2[i], cmap=cmap, vmin=vmin, vmax=vmax)
ax.axis('off')
if image_title is not None:
fig.suptitle(str(image_title), fontsize=18)
if save:
check_and_makedirs(file_path, is_file=True)
fig.savefig(file_path, dpi=90, bbox_inches='tight')
if not show:
plt.close()
if show:
plt.show()
def plot_image_and_mask(data, mask, truth=None, rect=None,
spacing=None, show_whole_hist=False,
z_start=0, stride=1, image_title=None,
cmap=plt.cm.gray, show=False, save=False, file_path=None):
vmin = -1350
vmax = 150
if np.sum(mask == 1) > 500:
total_nodule_data = data[mask == 1]
min_density = int(np.min(total_nodule_data))
max_density = int(np.max(total_nodule_data))
vmin = max(min_density - 100, -1350)
vmax = min(max_density + 100, 150)
if spacing is not None:
single_volume = spacing[0] * spacing[1] * spacing[2]
total_volume = int(np.sum(mask == 1) * single_volume)
else:
single_volume = 1
total_volume = 0
col = 3
if truth is not None or rect is not None:
col = col + 1
row = max(data.shape[0], 2)
if show_whole_hist:
row = row + 3
fig, plots = plt.subplots(row, col, figsize=(col * 5, row * 5))
for i in range(0, data.shape[0], stride):
z = z_start + i + 1
ax = plots[i, 0]
ax.set_title('original z=' + str(z))
ax.imshow(data[i], cmap=cmap, vmin=vmin, vmax=vmax)
ax.axis('off')
ax = plots[i, 1]
ax.set_title('contour z=' + str(z))
if np.any(mask[i] == 1):
if spacing is not None:
max_diameter = 0
_, contours, _ = cv2.findContours(mask[i], cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
for contour in contours:
rect_min = cv2.minAreaRect(contour)
((_, _), (width, height), angle) = rect_min
diameter = np.round(max(width, height) * spacing[1], 1)
if max_diameter < diameter:
max_diameter = diameter
# box = cv2.boxPoints(rect_min)
# rectangle = patches.Rectangle(box[1], width, height, angle,
# linewidth=1, edgecolor='b', facecolor='none')
# ax.add_patch(rectangle)
# (x, y), radius = cv2.minEnclosingCircle(contour)
# (x, y, radius) = np.int0((x, y, radius))
# circle = patches.Circle((x, y), radius, linewidth=1, edgecolor='b', facecolor='none')
# ax.add_patch(circle)
mask_point_count = np.sum(mask[i] == 1)
if mask_point_count > 0:
volume = int(mask_point_count * single_volume)
else:
volume = 0
ax.set_title('contour z=' + str(z) + ' (p' + str(mask_point_count) + '/d' + str(max_diameter) + '/v'
+ str(volume) + '/v' + str(total_volume) + ')')
if np.any(mask[i] >= 1):
ax.contour(mask[i] >= 1, [0.5], colors='g', linewidths=1, alpha=0.75)
if np.any(mask[i] == 1):
ax.contour(mask[i] == 1, [0.5], colors='r', linewidths=1, alpha=0.75)
ax.imshow(data[i], cmap=cmap, vmin=vmin, vmax=vmax)
ax.axis('off')
if truth is not None or rect is not None:
ax = plots[i, 2]
ax.set_title('ground truth z=' + str(z))
if truth is not None and np.any(truth[i] == 1):
ax.contour(truth[i] == 1, [0.5], colors='r', linewidths=1, alpha=0.75)
if rect is not None and rect[0][0] <= i < rect[0][1]:
rectangle = patches.Rectangle((rect[2][0], rect[1][0]),
rect[2][1] - rect[2][0],
rect[1][1] - rect[1][0],
linewidth=1, edgecolor='r', facecolor='none')
ax.add_patch(rectangle)
ax.imshow(data[i], cmap=cmap, vmin=vmin, vmax=vmax)
ax.axis('off')
ax = plots[i, col - 1]
nodule_data = data[i][mask[i] == 1]
if len(nodule_data) > 0:
z_average_density = int(np.round(np.mean(nodule_data)))
z_min_density = int(np.min(nodule_data))
z_max_density = int(np.max(nodule_data))
ax.set_title('HU histogram z=' + str(z) + ' (' + str(z_min_density) + '/' + str(z_average_density)
+ '/' + str(z_max_density) + ')')
# ax.set_xlabel('Hounsfield Units (HU)')
# ax.set_ylabel('Count')
# ax.hist(nodule_data.flatten(), bins=50, color='blue', alpha=0.75)
sns.set_palette('hls')
bins = [x for x in range(min(nodule_data), max(nodule_data) + 1, 1)]
sns.distplot(nodule_data.flatten(), bins=bins, color='blue',
kde=False, kde_kws={'bw': .2, 'color': 'blue'},
rug=False, ax=ax)
if show_whole_hist:
z_max = np.argmax(np.sum(mask == 1, axis=(1, 2)))
z = z_start + z_max + 1
new_mask = mask.copy()
new_mask[new_mask == 99] = 1
nodule_data_z_max = data[z_max][mask[z_max] == 1]
nodule_data_z_max_new = data[z_max][new_mask[z_max] == 1]
nodule_data_total = data[mask == 1]
nodule_data_total_new = data[new_mask == 1]
values, counts = np.unique(nodule_data_z_max_new, return_counts=True)
y_max_z_max = np.ceil(max(counts) / 100) * 100
values, counts = np.unique(nodule_data_total_new, return_counts=True)
y_max = np.ceil(max(counts) / 100) * 100
bins = [x for x in range(min(nodule_data_total_new), max(nodule_data_total_new) + 1, 1)]
ax = plots[row - 3, col - 2]
nodule_data = nodule_data_z_max
if len(nodule_data) > 0:
z_average_density = int(np.round(np.mean(nodule_data)))
z_min_density = int(np.min(nodule_data))
z_max_density = int(np.max(nodule_data))
ax.set_title('HU histogram z=' + str(z) + ' (' + str(z_min_density) + '/' + str(z_average_density)
+ '/' + str(z_max_density) + ')')
ax.set_ylim([0, y_max_z_max])
# ax.set_xlabel('Hounsfield Units (HU)')
# ax.set_ylabel('Count')
# ax.hist(nodule_data.flatten(), bins=50, color='blue', alpha=0.75)
sns.set_palette('hls')
sns.distplot(nodule_data.flatten(), bins=bins, color='blue',
kde=False, kde_kws={'bw': .2, 'color': 'blue'},
rug=False, ax=ax)
ax = plots[row - 3, col - 1]
nodule_data = nodule_data_total
if len(nodule_data) > 0:
z_average_density = int(np.round(np.mean(nodule_data)))
z_min_density = int(np.min(nodule_data))
z_max_density = int(np.max(nodule_data))
ax.set_title('HU histogram' + ' (' + str(z_min_density) + '/' + str(z_average_density)
+ '/' + str(z_max_density) + '/v' + str(total_volume) + ')')
ax.set_ylim([0, y_max])
# ax.set_xlabel('Hounsfield Units (HU)')
# ax.set_ylabel('Count')
# ax.hist(nodule_data.flatten(), bins=50, color='blue', alpha=0.75)
sns.set_palette('hls')
sns.distplot(nodule_data.flatten(), bins=bins, color='blue',
kde=False, kde_kws={'bw': .2, 'color': 'blue'},
rug=False, ax=ax)
ax = plots[row - 2, col - 2]
nodule_data = nodule_data_z_max_new
if len(nodule_data) > 0:
z_average_density = int(np.round(np.mean(nodule_data)))
z_min_density = int(np.min(nodule_data))
z_max_density = int(np.max(nodule_data))
ax.set_title('HU histogram z=' + str(z) + ' (' + str(z_min_density) + '/' + str(z_average_density)
+ '/' + str(z_max_density) + ')')
ax.set_ylim([0, y_max_z_max])
# ax.set_xlabel('Hounsfield Units (HU)')
# ax.set_ylabel('Count')
# ax.hist(nodule_data.flatten(), bins=50, color='blue', alpha=0.75)
sns.set_palette('hls')
sns.distplot(nodule_data_z_max_new.flatten(), bins=bins, color='green',
kde=False, kde_kws={'bw': .2, 'color': 'green'},
rug=False, ax=ax)
ax = plots[row - 2, col - 1]
nodule_data = nodule_data_total_new
if len(nodule_data) > 0:
z_average_density = int(np.round(np.mean(nodule_data)))
z_min_density = int(np.min(nodule_data))
z_max_density = int(np.max(nodule_data))
ax.set_title('HU histogram' + ' (' + str(z_min_density) + '/' + str(z_average_density)
+ '/' + str(z_max_density) + '/v' + str(total_volume) + ')')
ax.set_ylim([0, y_max])
# ax.set_xlabel('Hounsfield Units (HU)')
# ax.set_ylabel('Count')
# ax.hist(nodule_data.flatten(), bins=50, color='blue', alpha=0.75)
sns.set_palette('hls')
sns.distplot(nodule_data_total_new.flatten(), bins=bins, color='green',
kde=False, kde_kws={'bw': .2, 'color': 'green'},
rug=False, ax=ax)
ax = plots[row - 1, col - 2]
nodule_data = nodule_data_z_max_new
if len(nodule_data) > 0:
ax.set_title('HU histogram z=' + str(z))
ax.set_ylim([0, y_max_z_max])
# ax.set_xlabel('Hounsfield Units (HU)')
# ax.set_ylabel('Count')
# ax.hist(nodule_data.flatten(), bins=50, color='blue', alpha=0.75)
sns.set_palette('hls')
sns.distplot(nodule_data_z_max_new.flatten(), bins=bins, color='green',
kde=False, kde_kws={'bw': .2, 'color': 'green'},
rug=False, ax=ax)
sns.distplot(nodule_data_z_max.flatten(), bins=bins, color='blue',
kde=False, kde_kws={'bw': .2, 'color': 'blue'},
rug=False, ax=ax)
ax = plots[row - 1, col - 1]
nodule_data = nodule_data_total_new
if len(nodule_data) > 0:
ax.set_title('HU histogram')
ax.set_ylim([0, y_max])
# ax.set_xlabel('Hounsfield Units (HU)')
# ax.set_ylabel('Count')
# ax.hist(nodule_data.flatten(), bins=50, color='blue', alpha=0.75)
sns.set_palette('hls')
sns.distplot(nodule_data_total_new.flatten(), bins=bins, color='green',
kde=False, kde_kws={'bw': .2, 'color': 'green'},
rug=False, ax=ax)
sns.distplot(nodule_data_total.flatten(), bins=bins, color='blue',
kde=False, kde_kws={'bw': .2, 'color': 'blue'},
rug=False, ax=ax)
if image_title is not None:
fig.suptitle(str(image_title), fontsize=18)
if save:
check_and_makedirs(file_path, is_file=True)
fig.savefig(file_path, dpi=90, bbox_inches='tight')
if not show:
plt.close()
if show:
plt.show()
def plot_mask(data, mask, spacing=None, show_whole_hist=False, nodule_peak_info=None,
z_start=0, stride=1, image_title=None,
cmap=plt.cm.gray, show=False, save=False, file_path=None):
vmin = -1350
vmax = 150
if np.sum(mask == 1) > 500:
total_nodule_data = data[mask == 1]
min_density = int(np.min(total_nodule_data))
max_density = int(np.max(total_nodule_data))
vmin = max(min_density - 100, -1350)
vmax = min(max_density + 100, 150)
if spacing is not None:
single_volume = spacing[0] * spacing[1] * spacing[2]
total_volume = int(np.sum(mask == 1) * single_volume)
else:
single_volume = 1
total_volume = 0
col = 1
row = 4
fig, plots = plt.subplots(row, col, figsize=(col * 8, row * 5))
fig.subplots_adjust(left=0.5, top=0.5, hspace=1)
if show_whole_hist:
new_mask = mask.copy()
new_mask[new_mask == 99] = 1
nodule_data_total = data[mask == 1]
nodule_data_total_new = data[new_mask == 1]
values, counts = np.unique(nodule_data_total_new, return_counts=True)
y_max = np.ceil(1.3 * max(counts) / 100) * 100
if nodule_peak_info is not None:
y_max_ratio = nodule_peak_info[1] / y_max + 0.1
y_max_height = y_max * y_max_ratio + 20
temp_max_density = max(int(np.max(nodule_data_total_new)), -200)
bins = [x for x in range(-1000, temp_max_density + 1, 1)]
# ax = plots[0]
# nodule_data = nodule_data_total
# if len(nodule_data) > 0:
# z_average_density = int(np.round(np.mean(nodule_data)))
# z_min_density = int(np.min(nodule_data))
# z_max_density = int(np.max(nodule_data))
# ax.set_title('HU histogram' + ' (' + str(z_min_density) + '/' + str(z_average_density)
# + '/' + str(z_max_density) + '/v' + str(total_volume) + ')')
# ax.set_ylim([0, y_max])
# if nodule_peak_info is not None:
# ax.axvline(nodule_peak_info[0], ymax=y_max_ratio, color='red', ls='-', lw=0.5)
# ax.text(nodule_peak_info[0] - 30, y_max_height,
# str(nodule_peak_info[0]), fontsize=8, color='r')
# # ax.set_xlabel('Hounsfield Units (HU)')
# # ax.set_ylabel('Count')
# # ax.hist(nodule_data.flatten(), bins=bins, color='blue', alpha=0.75)
# sns.set_palette('hls')
# sns.distplot(nodule_data.flatten(), bins=bins, color='blue',
# hist=True, kde=False, kde_kws={'bw': .2, 'color': 'blue', 'shade': True},
# rug=False, norm_hist=False, ax=ax)
#
# ax = plots[1]
# nodule_data = nodule_data_total_new
# if len(nodule_data) > 0:
# ax.set_title('HU histogram')
# ax.set_ylim([0, y_max])
# if nodule_peak_info is not None:
# ax.axvline(nodule_peak_info[0], ymax=y_max_ratio, color='red', ls='-', lw=0.5)
# ax.text(nodule_peak_info[0] - 30, y_max_height,
# str(nodule_peak_info[0]), fontsize=8, color='r')
# # ax.set_xlabel('Hounsfield Units (HU)')
# # ax.set_ylabel('Count')
# # ax.hist(nodule_data.flatten(), bins=bins, color='blue', alpha=0.75)
# sns.set_palette('hls')
# sns.distplot(nodule_data_total_new.flatten(), bins=bins, color='green',
# hist=True, kde=False, kde_kws={'bw': .2, 'color': 'green', 'shade': True},
# rug=False, norm_hist=False, ax=ax)
# sns.distplot(nodule_data_total.flatten(), bins=bins, color='blue',
# hist=True, kde=False, kde_kws={'bw': .2, 'color': 'blue', 'shade': True},
# rug=False, norm_hist=False, ax=ax)
ax = plots[0]
nodule_data = nodule_data_total
if len(nodule_data) > 0:
z_average_density = int(np.round(np.mean(nodule_data)))
z_min_density = int(np.min(nodule_data))
z_max_density = int(np.max(nodule_data))
ax.set_title('HU Curve' + ' (' + str(z_min_density) + '/' + str(z_average_density)
+ '/' + str(z_max_density) + '/v' + str(total_volume) + ')')
ax.set_ylim([0, y_max])
ax.set_xlim([bins[0], bins[-1]])
if nodule_peak_info is not None:
ax.axvline(nodule_peak_info[0], ymax=y_max_ratio, color='red', ls='-', lw=0.5)
ax.text(nodule_peak_info[0] - 30, y_max_height,
str(nodule_peak_info[0]), fontsize=8, color='r')
values, counts = np.unique(nodule_data, return_counts=True)
ax.plot(values, counts, color='blue', linewidth=0.5)
ax = plots[1]
nodule_data = nodule_data_total_new
if len(nodule_data) > 0:
ax.set_title('HU Curve')
ax.set_ylim([0, y_max])
ax.set_xlim([bins[0], bins[-1]])
if nodule_peak_info is not None:
ax.axvline(nodule_peak_info[0], ymax=y_max_ratio, color='red', ls='-', lw=0.5)
ax.text(nodule_peak_info[0] - 30, y_max_height,
str(nodule_peak_info[0]), fontsize=8, color='r')
values, counts = np.unique(nodule_data, return_counts=True)
ax.plot(values, counts, color='green', linewidth=0.5)
values, counts = np.unique(nodule_data_total, return_counts=True)
ax.plot(values, counts, color='blue', linewidth=0.5)
ax = plots[2]
nodule_data = nodule_data_total
if len(nodule_data) > 0:
ax.set_title('HU Curve')
ax.set_ylim([0, y_max])
ax.set_xlim([bins[0], bins[-1]])
if nodule_peak_info is not None:
ax.axvline(nodule_peak_info[0], ymax=y_max_ratio, color='red', ls='-', lw=0.5)
ax.text(nodule_peak_info[0] - 30, y_max_height,
str(nodule_peak_info[0]), fontsize=8, color='r')
if nodule_peak_info[3] - nodule_peak_info[2] > 50:
ax.axvline(nodule_peak_info[2], color='gray', ls='--', lw=0.5)
ax.axvline(nodule_peak_info[3], color='gray', ls='--', lw=0.5)
ax.text(int((nodule_peak_info[2] + nodule_peak_info[3]) / 2) - 20, 10,
str(int(np.round(nodule_peak_info[4] * 100))) + '%', fontsize=8, color='gray')
values, counts = np.unique(nodule_data, return_counts=True)
ax.plot(values, counts, color='blue', linewidth=0.5)
ax = plots[3]
nodule_data = nodule_data_total_new
if len(nodule_data) > 0:
ax.set_title('HU Curve')
ax.set_ylim([0, y_max])
ax.set_xlim([bins[0], bins[-1]])
if nodule_peak_info is not None:
ax.axvline(nodule_peak_info[0], ymax=y_max_ratio, color='red', ls='-', lw=0.5)
ax.text(nodule_peak_info[0] - 30, y_max_height,
str(nodule_peak_info[0]), fontsize=8, color='r')
if nodule_peak_info[3] - nodule_peak_info[2] > 50:
ax.axvline(nodule_peak_info[2], color='gray', ls='--', lw=0.5)
ax.axvline(nodule_peak_info[3], color='gray', ls='--', lw=0.5)
ax.text(int((nodule_peak_info[2] + nodule_peak_info[3]) / 2) - 20, 10,
str(int(np.round(nodule_peak_info[4] * 100))) + '%', fontsize=8, color='gray')
values, counts = np.unique(nodule_data, return_counts=True)
ax.plot(values, counts, color='green', linewidth=0.5)
values, counts = np.unique(nodule_data_total, return_counts=True)
ax.plot(values, counts, color='blue', linewidth=0.5)
if image_title is not None:
fig.suptitle(str(image_title), fontsize=18)
if save:
check_and_makedirs(file_path, is_file=True)
fig.savefig(file_path, dpi=90, bbox_inches='tight')
if not show:
plt.close()
if show:
plt.show()
def plot_hu_curve(data,
z_start=0, stride=1, image_title=None,
cmap=plt.cm.gray, show=False, save=False, file_path=None):
vmin = -1350
vmax = 150
col = 2
row = max(data.shape[0], 2)
fig, plots = plt.subplots(row, col, figsize=(col * 5, row * 5), gridspec_kw={'width_ratios': [1, 8]})
for i in range(0, data.shape[0], stride):
z = z_start + i + 1
ax = plots[i, 0]
ax.set_title('original z=' + str(z))
ax.imshow(data[i], cmap=cmap, vmin=vmin, vmax=vmax)
ax.axis('off')
ax = plots[i, col-1]
nodule_data = continuous_2d(data[i])
nodule_data = nodule_data.flatten()
nodule_data = nodule_data[nodule_data > -200]
if len(nodule_data) > 0:
ax.set_title('HU Curve')
ax.set_xlim([0, len(nodule_data)])
x_values = [x for x in range(len(nodule_data))]
ax.plot(x_values, nodule_data, color='blue', linewidth=0.5)
if image_title is not None:
fig.suptitle(str(image_title), fontsize=18)
if save:
check_and_makedirs(file_path, is_file=True)
fig.savefig(file_path, dpi=90, bbox_inches='tight')
if not show:
plt.close()
if show:
plt.show()
def plot_hu_curve2(data,
z_start=0, stride=1, image_title=None,
cmap=plt.cm.gray, show=False, save=False, file_path=None):
vmin = -1350
vmax = 150
col = 12
row = max(data.shape[0], 2)
fig, plots = plt.subplots(row, col, figsize=(col * 5, row * 5),
gridspec_kw={'width_ratios': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 8]})
for i in range(0, data.shape[0], stride):
z = z_start + i + 1
ax = plots[i, 0]
ax.set_title('original z=' + str(z))
ax.imshow(data[i], cmap=cmap, vmin=vmin, vmax=vmax)
ax.axis('off')
        # Columns 1-10: sweep (kernel_size, variance threshold) pairs and mark low-variance pixels in red.
        for col_index, (kernel_size, var) in enumerate(
                [(3, 9), (3, 18), (5, 25), (5, 50), (7, 49),
                 (7, 98), (9, 81), (9, 162), (11, 121), (11, 242)], start=1):
            ax = plots[i, col_index]
            ax.set_title('z=' + str(z) + ' k=' + str(kernel_size) + ' var<=' + str(var))
            var_data = var_2d(data[i], kernel_size=kernel_size)
            change_coords = np.asarray(np.where(var_data <= var))
            ax.scatter(change_coords[1, :], change_coords[0, :], color='r', s=1, alpha=1)
            ax.imshow(data[i], cmap=cmap, vmin=vmin, vmax=vmax)
            ax.axis('off')
ax = plots[i, col-1]
nodule_data = continuous_2d(data[i])
nodule_data = nodule_data.flatten()
nodule_data = nodule_data[nodule_data > -200]
if len(nodule_data) > 0:
ax.set_title('HU Curve')
ax.set_xlim([0, len(nodule_data)])
x_values = [x for x in range(len(nodule_data))]
ax.plot(x_values, nodule_data, color='blue', linewidth=0.5)
if image_title is not None:
fig.suptitle(str(image_title), fontsize=18)
if save:
check_and_makedirs(file_path, is_file=True)
fig.savefig(file_path, dpi=90, bbox_inches='tight')
if not show:
plt.close()
if show:
plt.show()
def plot_hu_3d(data):
fig = plt.figure()
ax = Axes3D(fig)
xpos, ypos = np.meshgrid(np.linspace(0, data.shape[1] - 1, data.shape[1]),
np.linspace(0, data.shape[0] - 1, data.shape[0]))
ax.plot_surface(xpos, ypos, data, rstride=1, cstride=1, cmap=plt.get_cmap('rainbow'))
ax.contour(xpos, ypos, data, zdir='z', cmap=plt.get_cmap('rainbow'))
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('hu')
plt.show()
def plot_hu_hist(nodule_data):
plt.hist(nodule_data.flatten(), bins=[x for x in range(-1000, -200 + 1, 1)], color='c', alpha=0.75)
plt.xlabel('Hounsfield Units (HU)')
plt.ylabel('Count')
plt.show()
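# Hedged usage sketch (not part of the original module): plot_hu_3d expects a single
# 2D HU slice and plot_hu_hist bins values in [-1000, -200]; the synthetic Gaussian
# blob below is an assumption used purely for illustration.
def example_plot_hu_surface_and_hist():
    yy, xx = np.meshgrid(np.arange(64), np.arange(64), indexing='ij')
    # Synthetic "nodule": air background (-1000 HU) plus a soft-tissue bump up to -200 HU.
    hu_slice = -1000 + 800 * np.exp(-((yy - 32) ** 2 + (xx - 32) ** 2) / (2 * 8 ** 2))
    plot_hu_3d(hu_slice)    # 3D surface of HU values over the slice
    plot_hu_hist(hu_slice)  # histogram of HU values between -1000 and -200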
def plot_image_multi_mask(data, mask, mask1=None, mask2=None, mask3=None,
spacing=None,
z_start=0, stride=1, image_title=None,
cmap=plt.cm.gray, show=False, save=False, file_path=None):
vmin = -1350
vmax = 150
if np.sum(mask == 1) > 500:
total_nodule_data = data[mask == 1]
min_density = int(np.min(total_nodule_data))
max_density = int(np.max(total_nodule_data))
vmin = max(min_density - 100, -1350)
vmax = min(max_density + 100, 150)
if spacing is not None:
single_volume = spacing[0] * spacing[1] * spacing[2]
else:
single_volume = 1
col = 1
if mask1 is not None:
col = col + 1
if mask2 is not None:
col = col + 1
if mask3 is not None:
col = col + 1
row = max(data.shape[0], 2)
fig, plots = plt.subplots(row, col, figsize=(col * 5, row * 5))
for i in range(0, data.shape[0], stride):
z = z_start + i + 1
ax = plots[i, 0]
mask_point_count = np.sum(mask[i] == 1)
ax.set_title('original z=' + str(z) + ' (p' + str(mask_point_count) + ')')
if spacing is not None:
volume = int(mask_point_count * single_volume)
total_volume = int(np.sum(mask == 1) * single_volume)
ax.set_title('original z=' + str(z) + ' (p' + str(mask_point_count) + '/v'
+ str(volume) + '/v' + str(total_volume) + ')')
if np.any(mask[i] == 1):
ax.contour(mask[i] == 1, [0.5], colors='r', linewidths=1, alpha=0.75)
ax.imshow(data[i], cmap=cmap, vmin=vmin, vmax=vmax)
ax.axis('off')
if mask1 is not None:
ax = plots[i, 1]
mask_point_count = np.sum(mask1[i] == 1)
ax.set_title('contour z=' + str(z) + ' (p' + str(mask_point_count) + ')')
if spacing is not None:
volume = int(mask_point_count * single_volume)
total_volume = int(np.sum(mask1 == 1) * single_volume)
ax.set_title('contour z=' + str(z) + ' (p' + str(mask_point_count) + '/v'
+ str(volume) + '/v' + str(total_volume) + ')')
if np.any(mask1[i] == 1):
ax.contour(mask1[i] == 1, [0.5], colors='r', linewidths=1, alpha=0.75)
ax.imshow(data[i], cmap=cmap, vmin=vmin, vmax=vmax)
ax.axis('off')
if mask2 is not None:
ax = plots[i, 2]
mask_point_count = np.sum(mask2[i] == 1)
ax.set_title('extend z=' + str(z) + ' (p' + str(mask_point_count) + ')')
if spacing is not None:
volume = int(mask_point_count * single_volume)
total_volume = int(np.sum(mask2 == 1) * single_volume)
ax.set_title('extend z=' + str(z) + ' (p' + str(mask_point_count) + '/v'
+ str(volume) + '/v' + str(total_volume) + ')')
if np.any(mask2[i] >= 1):
ax.contour(mask2[i] >= 1, [0.5], colors='g', linewidths=1, alpha=0.75)
if np.any(mask2[i] == 1):
ax.contour(mask2[i] == 1, [0.5], colors='r', linewidths=1, alpha=0.75)
change_mask = mask[i] - mask2[i]
change_coords = np.asarray(np.where(change_mask == 1))
ax.scatter(change_coords[1, :], change_coords[0, :], color='b', s=1, alpha=1)
change_coords = np.asarray(np.where(change_mask == 255))
ax.scatter(change_coords[1, :], change_coords[0, :], color='g', s=1, alpha=1)
ax.imshow(data[i], cmap=cmap, vmin=vmin, vmax=vmax)
ax.axis('off')
if mask3 is not None:
ax = plots[i, 3]
mask_point_count = np.sum(mask3[i] == 1)
ax.set_title('shrink z=' + str(z) + ' (p' + str(mask_point_count) + ')')
if spacing is not None:
volume = int(mask_point_count * single_volume)
total_volume = int(np.sum(mask3 == 1) * single_volume)
ax.set_title('shrink z=' + str(z) + ' (p' + str(mask_point_count) + '/v'
+ str(volume) + '/v' + str(total_volume) + ')')
if np.any(mask3[i] == 1):
ax.contour(mask3[i] == 1, [0.5], colors='r', linewidths=1, alpha=0.75)
change_mask = mask[i] - mask3[i]
change_coords = np.asarray(np.where(change_mask == 1))
ax.scatter(change_coords[1, :], change_coords[0, :], color='b', s=1, alpha=1)
change_coords = np.asarray(np.where(change_mask == 255))
ax.scatter(change_coords[1, :], change_coords[0, :], color='g', s=1, alpha=1)
# filter_data = ndimage.median_filter(data[i], 3)
# filter_data = data[i].copy()
# filter_data = ndimage.gaussian_filter(filter_data, sigma=2)
ax.imshow(data[i], cmap=cmap, vmin=vmin, vmax=vmax)
ax.axis('off')
if image_title is not None:
fig.suptitle(str(image_title), fontsize=18)
if save:
check_and_makedirs(file_path, is_file=True)
fig.savefig(file_path, dpi=90, bbox_inches='tight')
if not show:
plt.close()
if show:
plt.show()
def generate_video_one_image(data, interval=200, file_path=None):
"""
    Given a CT volume, return an animation across its axial slices.
    data: [D, H, W] or [D, H, W, 3]
    interval: delay between frames in milliseconds, default 200
    file_path: path to save the animation if not None, default None
    return: matplotlib.animation.ArtistAnimation
"""
fig = plt.figure()
imgs = []
for i in range(len(data)):
img = plt.imshow(data[i], animated=True)
imgs.append([img])
anim = animation.ArtistAnimation(fig, imgs, interval=interval, blit=True, repeat_delay=1000)
if file_path:
Writer = animation.writers['ffmpeg']
writer = Writer(fps=30, metadata=dict(artist='Me'), bitrate=1800)
        anim.save(file_path, writer=writer)
return anim
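# Hedged usage sketch (assumption, not original code): animate a synthetic [D, H, W]
# volume; saving to a video file only works if ffmpeg is available to matplotlib, so
# the sketch keeps file_path=None and shows the animation instead.
def example_generate_video_one_image():
    volume = np.random.uniform(-1000, 400, size=(10, 64, 64)).astype(np.float32)
    anim = generate_video_one_image(volume, interval=200, file_path=None)
    plt.show()
    return anim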
def generate_video_two_image(data1, title1, data2, title2, interval=200, file_path=None):
vmin = -1350
vmax = 150
fig, axes = plt.subplots(1, 2)
imgs = []
for i in range(len(data1)):
ax = axes[0]
ax.set_title(title1)
img1 = ax.imshow(data1[i], cmap='gray', vmin=vmin, vmax=vmax, animated=True)
ax = axes[1]
ax.set_title(title2)
img2 = ax.imshow(data2[i], cmap='gray', vmin=vmin, vmax=vmax, animated=True)
imgs.append([img1, img2])
anim = animation.ArtistAnimation(fig, imgs, interval=interval, blit=True, repeat_delay=1000)
if file_path:
Writer = animation.writers['ffmpeg']
writer = Writer(fps=30, metadata=dict(artist='Me'), bitrate=1800)
        anim.save(file_path, writer=writer)
return anim
def generate_video_predict(data, mask, truth, interval=200, file_path=None):
vmin = -1350
vmax = 150
fig, axes = plt.subplots(1, 2)
imgs = []
for i in range(len(data)):
ax = axes[0]
ax.set_title('predict')
coords = np.asarray(np.where(mask[i] == 1))
predict_mask = ax.scatter(coords[1, :], coords[0, :], color='r', s=1, alpha=0.5, animated=True)
predict_img = ax.imshow(data[i], cmap='gray', vmin=vmin, vmax=vmax, animated=True)
ax = axes[1]
ax.set_title('ground truth')
coords = np.asarray(np.where(truth[i] == 1))
gt_mask = ax.scatter(coords[1, :], coords[0, :], color='c', s=1, alpha=0.5, animated=True)
gt_img = ax.imshow(data[i], cmap='gray', vmin=vmin, vmax=vmax, animated=True)
imgs.append([predict_img, predict_mask, gt_img, gt_mask])
anim = animation.ArtistAnimation(fig, imgs, interval=interval, blit=True, repeat_delay=1000)
if file_path:
Writer = animation.writers['ffmpeg']
writer = Writer(fps=30, metadata=dict(artist='Me'), bitrate=1800)
        anim.save(file_path, writer=writer)
return anim
def plot_shrink_extend_mask(data, one_nodule_mask, result_mask, extend_mask, shrink_mask, spacing, file_path):
coords = np.asarray(np.where(one_nodule_mask == 1))
if len(coords[0]) > 0:
coord_start = coords.min(axis=1)
coord_end = coords.max(axis=1) + 1
z_num = coord_end[0] - coord_start[0] + 4
image_height = max(coord_end[1] - coord_start[1], coord_end[2] - coord_start[2], 200)
crop_start, crop_end = get_crop_start_end(coord_start, coord_end, (z_num, image_height, image_height))
show_data = data[crop_start[0]:crop_end[0],
crop_start[1]:crop_end[1],
crop_start[2]:crop_end[2]]
show_mask = one_nodule_mask[crop_start[0]:crop_end[0],
crop_start[1]:crop_end[1],
crop_start[2]:crop_end[2]]
result_mask1 = result_mask[crop_start[0]:crop_end[0],
crop_start[1]:crop_end[1],
crop_start[2]:crop_end[2]]
extend_mask1 = extend_mask[crop_start[0]:crop_end[0],
crop_start[1]:crop_end[1],
crop_start[2]:crop_end[2]]
shrink_mask1 = shrink_mask[crop_start[0]:crop_end[0],
crop_start[1]:crop_end[1],
crop_start[2]:crop_end[2]]
plot_image_multi_mask(show_data,
show_mask,
result_mask1,
extend_mask1,
shrink_mask1,
spacing=spacing,
z_start=crop_start[0],
stride=1,
show=False,
save=True,
file_path=file_path)
def plot_nodule_box_mask(data, one_nodule_mask, nodule_box, spacing, file_path):
coords = np.asarray(np.where(one_nodule_mask == 1))
if len(coords[0]) > 0:
coord_start = coords.min(axis=1)
coord_end = coords.max(axis=1) + 1
z_num = coord_end[0] - coord_start[0] + 4
if nodule_box is not None:
z_pixel = get_diameter_pixel_length(nodule_box.diameter, spacing[0])
z_num = max(z_num, z_pixel + 4)
image_height = max(coord_end[1] - coord_start[1], coord_end[2] - coord_start[2], 200)
crop_start, crop_end = get_crop_start_end(coord_start, coord_end, (z_num, image_height, image_height))
show_data = data[crop_start[0]:crop_end[0],
crop_start[1]:crop_end[1],
crop_start[2]:crop_end[2]]
show_mask = one_nodule_mask[crop_start[0]:crop_end[0],
crop_start[1]:crop_end[1],
crop_start[2]:crop_end[2]]
if nodule_box is not None:
box = get_nodule_rect(nodule_box, spacing)
box[:, 0] = box[:, 0] - crop_start
box[:, 1] = box[:, 1] - crop_start
else:
            box = np.zeros((3, 2), int)
box[:, 0] = coord_start - crop_start
box[:, 1] = coord_end - crop_start
plot_image_and_mask(show_data,
show_mask,
rect=box,
spacing=spacing,
show_whole_hist=True,
z_start=crop_start[0],
stride=1,
show=False,
save=True,
file_path=file_path)
# -*- coding: utf-8 -*-
import os, sys
import pathlib
current_dir = pathlib.Path(__file__).parent.resolve()
while "cls_train" != current_dir.name:
current_dir = current_dir.parent
sys.path.append(current_dir.as_posix())
import time
import logging
import numpy as np
import SimpleITK as sitk
import pydicom
import traceback
from concurrent.futures import ThreadPoolExecutor
from multiprocessing.pool import Pool
from data.data_process_utils.test_box_utils import NoduleBox, nodule_raw2standard
from data.data_process_utils.test_data_utils import check_and_makedirs
from data.data_process_utils.test_data_utils import clip_data, downsample_data, upsample_mask, resample_data, resample_mask
import base64
import cv2
import ast
def lung_segment(data_npy):
def morphology_opening_2d(mask, data_npy, pool_size=5):
mask_npy = np.zeros(data_npy.shape, data_npy.dtype)
chest_mask_npy = np.zeros(data_npy.shape, data_npy.dtype)
def morphology_opening_2d_task(i, mask_2d):
mask_slice = sitk.BinaryMorphologicalOpening(mask_2d, 2)
mask_npy[i] = sitk.GetArrayFromImage(mask_slice)
chest_mask_npy[i] = sitk.GetArrayFromImage(sitk.BinaryFillhole(mask_slice))
with ThreadPoolExecutor(pool_size) as executor:
for i in range(data_npy.shape[0]):
executor.submit(morphology_opening_2d_task, i, mask[:, :, i])
return sitk.GetImageFromArray(mask_npy), sitk.GetImageFromArray(chest_mask_npy)
def morphology_closing_2d(mask, data_npy, pool_size=5):
mask_npy = np.zeros(data_npy.shape, data_npy.dtype)
def morphology_closing_2d_task(i, mask_2d):
mask_slice = sitk.BinaryMorphologicalClosing(mask_2d, 15)
mask_slice = sitk.BinaryDilate(mask_slice, 15)
mask_npy[i] = sitk.GetArrayFromImage(sitk.BinaryFillhole(mask_slice))
mask_npy[i] = np.bitwise_or(mask_npy[i], mask_npy[i][:, ::-1])
with ThreadPoolExecutor(pool_size) as executor:
for i in range(data_npy.shape[0]):
executor.submit(morphology_closing_2d_task, i, mask[:, :, i])
return sitk.GetImageFromArray(mask_npy)
    data_npy = data_npy.astype(int)
data_npy = clip_data(data_npy)
data = sitk.GetImageFromArray(data_npy)
mask = 1 - sitk.OtsuThreshold(data)
mask, chest_mask = morphology_opening_2d(mask, data_npy)
lung_mask = sitk.Subtract(chest_mask, mask)
# Remove areas not in the chest, when CT covers regions below the chest
eroded_mask = sitk.BinaryErode(lung_mask, 10)
seed_npy = sitk.GetArrayFromImage(eroded_mask)
seed_npy = np.array(seed_npy.nonzero())[[2, 1, 0]]
seeds = seed_npy.T.tolist()
lung_mask = sitk.ConfidenceConnected(lung_mask, seeds, multiplier=2.5)
lung_mask = morphology_closing_2d(lung_mask, data_npy)
return sitk.GetArrayFromImage(lung_mask)
def lung_segment_enhance(data):
z_max = max(int(np.ceil(data.shape[0] / 100)), 2)
if data.shape[1] <= 512 and data.shape[2] <= 512:
scale = (z_max, 2, 2)
else:
scale = (z_max, 4, 4)
new_data_shape = np.array(data.shape) // scale * scale
new_data = data[:new_data_shape[0], :new_data_shape[1], :new_data_shape[2]]
new_data = downsample_data(new_data, scale)
new_data = new_data.astype(data.dtype)
new_mask = lung_segment(new_data)
mask = upsample_mask(new_mask, scale)
    pad = np.zeros((3, 2), int)
pad[:, 1] = np.array(data.shape) - np.array(mask.shape)
mask = np.pad(mask, pad, mode='edge')
return mask
def get_lung_mask_and_box(data, uid, segment=False, segment_data=False, segment_margin=[0, 0, 0], min_points=10000):
    lung_box = np.zeros((3, 2), int)
lung_box[:, 1] = data.shape
found_lung = False
if segment:
try:
start_time = time.time()
lung_mask = lung_segment_enhance(data)
logging.info('{}, {}, Lung segment run time {:.2f}(s)'.format(
time.strftime("%Y-%m-%d %H:%M:%S"), uid, (time.time() - start_time)))
if np.sum(lung_mask == 1) > min_points:
found_lung = True
if found_lung and segment_data:
segment_margin = np.array(segment_margin)
coords = np.asarray(np.where(lung_mask == 1))
lung_box[:, 0] = np.maximum(coords.min(axis=1) - segment_margin, 0)
lung_box[:, 1] = np.minimum(coords.max(axis=1) + segment_margin + 1, data.shape)
data = data[lung_box[0, 0]:lung_box[0, 1],
lung_box[1, 0]:lung_box[1, 1],
lung_box[2, 0]:lung_box[2, 1]]
lung_mask = lung_mask[lung_box[0, 0]:lung_box[0, 1],
lung_box[1, 0]:lung_box[1, 1],
lung_box[2, 0]:lung_box[2, 1]]
except Exception as e:
traceback.print_exc()
if not found_lung:
lung_mask = np.ones(data.shape, np.uint8)
return data, lung_mask, lung_box
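# Hedged usage sketch (assumption): run the lung segmentation pipeline on a raw HU
# volume. With segment=True a lung mask is estimated via lung_segment_enhance; with
# segment_data=True the volume and mask are additionally cropped to the lung bounding
# box plus segment_margin.
def example_get_lung_mask_and_box(raw_hu_volume, uid='example-uid'):
    data, lung_mask, lung_box = get_lung_mask_and_box(
        raw_hu_volume, uid,
        segment=True,
        segment_data=True,
        segment_margin=[2, 10, 10])
    # lung_box is a (3, 2) array of [start, end) indices in (z, y, x) order.
    return data, lung_mask, lung_box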
def transform_file_type(input_file, output_file):
image = sitk.ReadImage(input_file)
sitk.WriteImage(image, output_file)
def load_single_dicom(input_file):
image = sitk.ReadImage(input_file)
slice = sitk.GetArrayFromImage(image)[0, :, :]
return slice
class CTSeries(object):
def __init__(self):
self._PatientID = None
self._SeriesInstanceUID = None
self._SOPInstanceUIDs = None
self._ReconstructionDiameter = None
self._Rows = None
self._Columns = None
self._AcquisitionDate = None
self._Manufacturer = None
self._InstitutionName = None
self._raw_data = None
self._raw_spacing = None
self._raw_origin = None
self._raw_direction = None
self._dicoms_is_loaded = False
self._lung_mask = None
self._lung_box = None
self._standard_data = None
self._standard_spacing = None
self._dicoms_is_preprocessed = False
self._raw_labels = None
self._standard_labels = None
self._label_is_loaded = False
self._standard_is_loaded = False
def load_dicoms(self, dicom_dir_path):
logging.info('{}, Loading dicoms from {}...'.format(time.strftime('%Y-%m-%d %H:%M:%S'), dicom_dir_path))
dicom_names = [f for f in os.listdir(dicom_dir_path) if '.xml' not in f]
dicom_paths = list(map(lambda x: os.path.join(dicom_dir_path, x), dicom_names))
dicoms = list(map(lambda x: pydicom.read_file(x, stop_before_pixels=True), dicom_paths))
try:
slice_locations = list(map(lambda x: float(x.ImagePositionPatient[2]), dicoms))
except AttributeError:
slice_locations = list(map(lambda x: float(x.SliceLocation), dicoms))
# sort slices by their z coordinates from large to small
if dicoms[0].get('PatientPosition') is None:
patient_position = 'HFS'
else:
patient_position = dicoms[0].PatientPosition
if patient_position in ['FFP', 'FFS']:
idx_z_sorted = np.argsort(slice_locations)[::-1]
else:
idx_z_sorted = np.argsort(slice_locations)[::-1]
dicom_paths = np.asarray(dicom_paths)[idx_z_sorted]
self._SeriesInstanceUID = dicoms[0].SeriesInstanceUID
self._SOPInstanceUIDs = np.array(list(map(lambda x: x.SOPInstanceUID, dicoms)))[idx_z_sorted]
try:
self._PatientID = dicoms[0].PatientID
self._ReconstructionDiameter = dicoms[0].ReconstructionDiameter
self._Rows = dicoms[0].Rows
self._Columns = dicoms[0].Columns
self._AcquisitionDate = dicoms[0].AcquisitionDate
self._Manufacturer = dicoms[0].Manufacturer
self._InstitutionName = dicoms[0].InstitutionName
except AttributeError:
self._PatientID = None
self._ReconstructionDiameter = None
self._Rows = None
self._Columns = None
self._AcquisitionDate = None
self._Manufacturer = None
self._InstitutionName = None
reader = sitk.ImageSeriesReader()
reader.SetFileNames(dicom_paths)
image_itk = reader.Execute()
# all in [z, y, x] order
self._raw_data = sitk.GetArrayFromImage(image_itk)
self._raw_spacing = np.array(list(reversed(image_itk.GetSpacing())))
self._raw_origin = np.array(list(reversed(image_itk.GetOrigin())))
self._raw_direction = image_itk.GetDirection()
self._dicoms_is_loaded = True
def load_dicoms_mp(self, dicom_dir_path):
logging.info('{}, Loading dicoms from {}...'.format(time.strftime('%Y-%m-%d %H:%M:%S'), dicom_dir_path))
dicom_names = [f for f in os.listdir(dicom_dir_path) if '.xml' not in f]
dicom_paths = list(map(lambda x: os.path.join(dicom_dir_path, x), dicom_names))
dicoms = list(map(lambda x: pydicom.read_file(x, stop_before_pixels=True), dicom_paths))
try:
slice_locations = list(map(lambda x: float(x.ImagePositionPatient[2]), dicoms))
except AttributeError:
slice_locations = list(map(lambda x: float(x.SliceLocation), dicoms))
# sort slices by their z coordinates from large to small
if dicoms[0].get('PatientPosition') is None:
patient_position = 'HFS'
else:
patient_position = dicoms[0].PatientPosition
if patient_position in ['FFP', 'FFS']:
idx_z_sorted = np.argsort(slice_locations)[::-1]
else:
idx_z_sorted = np.argsort(slice_locations)[::-1]
dicom_paths = np.asarray(dicom_paths)[idx_z_sorted]
self._SeriesInstanceUID = dicoms[0].SeriesInstanceUID
self._SOPInstanceUIDs = np.array(list(map(lambda x: x.SOPInstanceUID, dicoms)))[idx_z_sorted]
try:
self._PatientID = dicoms[0].PatientID
self._ReconstructionDiameter = dicoms[0].ReconstructionDiameter
self._Rows = dicoms[0].Rows
self._Columns = dicoms[0].Columns
self._AcquisitionDate = dicoms[0].AcquisitionDate
self._Manufacturer = dicoms[0].Manufacturer
self._InstitutionName = dicoms[0].InstitutionName
except AttributeError:
self._PatientID = None
self._ReconstructionDiameter = None
self._Rows = None
self._Columns = None
self._AcquisitionDate = None
self._Manufacturer = None
self._InstitutionName = None
pool = Pool(processes=20)
slice_list = pool.map(load_single_dicom, dicom_paths)
pool.close()
pool.join()
raw_data = np.zeros((len(dicom_paths), dicoms[0].Rows, dicoms[0].Columns))
for idx, slice_data in enumerate(slice_list):
raw_data[idx] = slice_data
image0 = sitk.ReadImage(dicom_paths[0])
dicom0 = pydicom.read_file(dicom_paths[0], stop_before_pixels=True)
# all in [z, y, x] order
self._raw_data = raw_data
self._raw_spacing = np.array(list(reversed(image0.GetSpacing())))
self._raw_origin = np.array(list(reversed(dicom0.ImagePositionPatient)))
self._raw_direction = image0.GetDirection()
self._dicoms_is_loaded = True
def load_single_file(self, file_path, uid=None):
logging.info('{}, Loading file from {}...'.format(time.strftime('%Y-%m-%d %H:%M:%S'), file_path))
self._SeriesInstanceUID = uid
if self._SeriesInstanceUID is None:
self._SeriesInstanceUID = os.path.splitext(os.path.basename(file_path))[0]
image_itk = sitk.ReadImage(file_path)
# all in [z, y, x] order
self._raw_data = sitk.GetArrayFromImage(image_itk)
self._raw_spacing = np.array(list(reversed(image_itk.GetSpacing())))
self._raw_origin = np.array(list(reversed(image_itk.GetOrigin())))
self._raw_direction = image_itk.GetDirection()
self._dicoms_is_loaded = True
def load_file(self, file_path, uid=None):
if os.path.splitext(os.path.basename(file_path))[-1] in ['.mhd']:
self.load_single_file(file_path, uid)
else:
self.load_dicoms(file_path)
def is_hrct(self):
"""
        Check whether this is a high-resolution (HRCT) scan.
"""
if self._Rows is not None and self._Rows >= 1024:
return True
return False
def is_ultra_hrct(self, max_reconstruction_diameter=250):
"""
        Check whether this is an ultra-high-resolution targeted scan (small reconstruction FOV).
"""
if self._Rows is not None and self._Rows >= 1024 and \
self._ReconstructionDiameter is not None and \
self._ReconstructionDiameter < max_reconstruction_diameter:
return True
return False
def get_patient_id(self):
return self._PatientID
def get_series_instance_uid(self):
return self._SeriesInstanceUID
def get_reconstruction_diameter(self):
return self._ReconstructionDiameter
def get_rows(self):
return self._Rows
def get_columns(self):
return self._Columns
def get_acquisition_date(self):
return self._AcquisitionDate
def get_manufacturer(self):
return self._Manufacturer
def get_institution_name(self):
return self._InstitutionName
def get_raw_data(self):
return self._raw_data
def get_raw_spacing(self):
return self._raw_spacing
def get_raw_origin(self):
return self._raw_origin
def get_raw_direction(self):
return self._raw_direction
def get_lung_mask(self):
return self._lung_mask
def get_lung_box(self):
return self._lung_box
def get_standard_data(self):
return self._standard_data
def get_standard_spacing(self):
return self._standard_spacing
def get_raw_labels(self):
return self._raw_labels
def get_standard_labels(self):
return self._standard_labels
def save_raw_data(self, output_file):
if self._dicoms_is_loaded:
image = sitk.GetImageFromArray(self._raw_data)
image.SetSpacing(np.array(list(reversed(self._raw_spacing))))
image.SetOrigin(np.array(list(reversed(self._raw_origin))))
image.SetDirection(self._raw_direction)
check_and_makedirs(output_file, is_file=True)
sitk.WriteImage(image, output_file)
def world_to_voxel_coord(self, world_coord):
voxel_coord = np.absolute(world_coord - self._raw_origin) / self._raw_spacing
return voxel_coord
def voxel_to_world_coord(self, voxel_coord):
world_coord = voxel_coord * self._raw_spacing + self._raw_origin
return world_coord
def preprocess(self, segment=False, scale=(1, 1, 1)):
logging.info('{}, {}, Preprocessing ...'.format(
time.strftime("%Y-%m-%d %H:%M:%S"), self._SeriesInstanceUID))
standard_data = downsample_data(self._raw_data, scale)
standard_spacing = self._raw_spacing * np.array(scale)
standard_data, lung_mask, lung_box = \
get_lung_mask_and_box(standard_data, self.get_series_instance_uid(), segment=segment)
self._lung_mask = lung_mask
self._lung_box = lung_box
self._standard_data = standard_data
self._standard_spacing = standard_spacing
self._dicoms_is_preprocessed = True
def preprocess_1024u(self, check_spacing=False, segment=False, scale=(1, 1, 1)):
logging.info('{}, {}, Preprocessing 1024u...'.format(
time.strftime("%Y-%m-%d %H:%M:%S"), self._SeriesInstanceUID))
if check_spacing and \
(float(self._raw_spacing[0]) > 1.5 or self._ReconstructionDiameter is None or
float(self._ReconstructionDiameter) > 190):
logging.info('{}, {}, Preprocessing 1024u resample data...'.format(
time.strftime("%Y-%m-%d %H:%M:%S"), self._SeriesInstanceUID))
            # If the z spacing is greater than 1.5mm, or the FOV is missing or larger than 190mm, resample to 1mm z spacing and a 180mm FOV (180/1024 = 0.17578125 mm/pixel)
new_spacing = np.array([1.0, 0.17578125, 0.17578125]) * np.array(scale)
standard_data, standard_spacing = resample_data(self._raw_data, self._raw_spacing, new_spacing)
else:
standard_data = downsample_data(self._raw_data, scale)
standard_spacing = self._raw_spacing * np.array(scale)
standard_data, lung_mask, lung_box = \
get_lung_mask_and_box(standard_data, self.get_series_instance_uid(), segment=segment)
self._lung_mask = lung_mask
self._lung_box = lung_box
self._standard_data = standard_data
self._standard_spacing = standard_spacing
self._dicoms_is_preprocessed = True
def preprocess_1024(self, check_spacing=False, segment=False, scale=(1, 1, 1)):
logging.info('{}, {}, Preprocessing 1024...'.format(
time.strftime("%Y-%m-%d %H:%M:%S"), self._SeriesInstanceUID))
if check_spacing and \
float(self._raw_spacing[0]) > 1.5:
logging.info('{}, {}, Preprocessing 1024 resample data...'.format(
time.strftime("%Y-%m-%d %H:%M:%S"), self._SeriesInstanceUID))
            # If the z spacing is greater than 1.5mm, resample the z spacing to 1mm
new_spacing = np.array([1.0, 0.3515625, 0.3515625]) * np.array(scale)
standard_data, standard_spacing = resample_data(self._raw_data, self._raw_spacing, new_spacing)
else:
standard_data = downsample_data(self._raw_data, scale)
standard_spacing = self._raw_spacing * np.array(scale)
standard_data, lung_mask, lung_box = \
get_lung_mask_and_box(standard_data, self.get_series_instance_uid(), segment=segment)
self._lung_mask = lung_mask
self._lung_box = lung_box
self._standard_data = standard_data
self._standard_spacing = standard_spacing
self._dicoms_is_preprocessed = True
def preprocess_512(self, check_spacing=False, segment=False, scale=(1, 1, 1)):
logging.info('{}, {}, Preprocessing 512...'.format(
time.strftime("%Y-%m-%d %H:%M:%S"), self._SeriesInstanceUID))
if check_spacing and \
float(self._raw_spacing[0]) > 1.5:
logging.info('{}, {}, Preprocessing 512 resample data...'.format(
time.strftime("%Y-%m-%d %H:%M:%S"), self._SeriesInstanceUID))
            # If the z spacing is greater than 1.5mm, resample the z spacing to 1mm
new_spacing = np.array([1.0, 0.703125, 0.703125]) * np.array(scale)
standard_data, standard_spacing = resample_data(self._raw_data, self._raw_spacing, new_spacing)
else:
standard_data = downsample_data(self._raw_data, scale)
standard_spacing = self._raw_spacing * np.array(scale)
standard_data, lung_mask, lung_box = \
get_lung_mask_and_box(standard_data, self.get_series_instance_uid(), segment=segment)
self._lung_mask = lung_mask
self._lung_box = lung_box
self._standard_data = standard_data
self._standard_spacing = standard_spacing
self._dicoms_is_preprocessed = True
def load_labels(self, label_path):
"""
Load labels from label_path.
label_path: path to the label file, which is a csv with 5 fields:
[z, y, x, diameter, is_pos]
"""
if not self._dicoms_is_loaded:
raise Exception('DICOM files have not been loaded yet')
if not self._dicoms_is_preprocessed:
raise Exception('DICOM files have not been preprocessed yet')
logging.info('{}, {}, Loading labels from {}...'.format(
time.strftime("%Y-%m-%d %H:%M:%S"), self._SeriesInstanceUID, label_path))
nodule_boxes = []
with open(label_path) as f:
for line in f:
z, y, x, diameter, is_pos = \
line.strip('\n').replace('"', '').split(',')[0:5]
nodule_boxes.append(
NoduleBox(float(z), float(y), float(x), float(diameter), float(is_pos), 1.0))
self._raw_labels = nodule_boxes
self._standard_labels = nodule_raw2standard(
self._raw_labels, self._raw_spacing, self._standard_spacing, start=self._lung_box[:, 0])
self._label_is_loaded = True
def load_nodule_boxes(self, nodule_boxes):
"""
Load labels from nodule_boxes.
"""
if not self._dicoms_is_loaded:
raise Exception('DICOM files have not been loaded yet')
if not self._dicoms_is_preprocessed:
raise Exception('DICOM files have not been preprocessed yet')
self._raw_labels = nodule_boxes
self._standard_labels = nodule_raw2standard(
self._raw_labels, self._raw_spacing, self._standard_spacing, start=self._lung_box[:, 0])
self._label_is_loaded = True
def save_standard_npy(self, npy_path, uid):
"""
Save *_standard data in numpy format.
npy_path: str, path to save.
uid: str, prefix for the files.
"""
if not self._dicoms_is_preprocessed:
raise Exception('DICOM files have not been preprocessed yet')
check_and_makedirs(npy_path)
np.save(os.path.join(npy_path, uid + '_standard_data.npy'),
self._standard_data.astype(np.float32))
np.save(os.path.join(npy_path, uid + '_standard_spacing.npy'),
self._standard_spacing.astype(np.float32))
        # Check whether any label index falls outside the standard data volume
standard_data_shape = np.array(self._standard_data.shape)
standard_labels = []
for i in range(len(self._standard_labels)):
nodule_box = self._standard_labels[i]
if not (0 <= nodule_box.z <= standard_data_shape[0] - 1 and
0 <= nodule_box.y <= standard_data_shape[1] - 1 and
0 <= nodule_box.x <= standard_data_shape[2] - 1):
logging.error('{}, {}, label index={} error.'.format(time.strftime("%Y-%m-%d %H:%M:%S"), uid, i))
standard_labels.append(np.array(self._standard_labels[i]))
standard_labels = np.array(standard_labels)
np.save(os.path.join(npy_path, uid + '_standard_labels.npy'),
standard_labels.astype(np.float32))
def load_standard_npy(self, npy_path, uid, mmap_mode='r'):
"""
Load *_standard data in numpy format.
npy_path: str, path to load.
uid: str, prefix for the files.
"""
if self._dicoms_is_preprocessed:
raise Exception('DICOM files have already been preprocessed')
self._SeriesInstanceUID = uid
self._standard_data = np.load(
os.path.join(npy_path, uid + '_standard_data.npy'), mmap_mode=mmap_mode)
self._standard_spacing = np.load(
os.path.join(npy_path, uid + '_standard_spacing.npy'))
self._standard_labels = []
standard_labels = np.load(
os.path.join(npy_path, uid + '_standard_labels.npy'))
for i in range(len(standard_labels)):
z, y, x, diameter, is_pos = standard_labels[i][0:5]
self._standard_labels.append(
NoduleBox(z, y, x, diameter, is_pos, 1.0))
self._standard_is_loaded = True
def save_standard_mask_npy(self, npy_path, uid, nodule_index, mask):
"""
        Save a nodule mask in standard (preprocessed) space as an npy file.
"""
check_and_makedirs(npy_path)
standard_mask = mask
if np.any(self._standard_spacing != self._raw_spacing):
standard_mask = resample_mask(mask, self._standard_data.shape)
lung_box = self._lung_box
standard_mask = standard_mask[lung_box[0, 0]:lung_box[0, 1],
lung_box[1, 0]:lung_box[1, 1],
lung_box[2, 0]:lung_box[2, 1]]
np.save(os.path.join(npy_path, uid + '_' + str(nodule_index) + '_standard_mask.npy'),
standard_mask.astype(np.uint8))
def load_standard_mask_npy(self, npy_path, uid, nodule_index, mmap_mode='r'):
"""
        Load a nodule mask saved by save_standard_mask_npy.
"""
standard_mask = np.load(
os.path.join(npy_path, uid + '_' + str(nodule_index) + '_standard_mask.npy'), mmap_mode=mmap_mode)
return standard_mask
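# Hedged usage sketch (the paths below are placeholders, not taken from this repo):
# the typical CTSeries flow is load -> preprocess -> load labels -> save standard npy.
def example_ctseries_pipeline(dicom_dir='/path/to/dicom_series',
                              label_csv='/path/to/labels.csv',
                              npy_out='/path/to/npy_out'):
    ct = CTSeries()
    ct.load_file(dicom_dir)           # a dicom folder, or a single .mhd file
    ct.preprocess_512(check_spacing=True, segment=True, scale=(1, 1, 1))
    ct.load_labels(label_csv)         # csv rows: z, y, x, diameter, is_pos
    ct.save_standard_npy(npy_out, ct.get_series_instance_uid())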
def base64_to_list(base64_str):
indexs = ''
list = []
img_np = None
if base64_str:
time_now = time.time()
img_data = base64.b64decode(base64_str[22:])
        nparr = np.frombuffer(img_data, np.uint8)
img_np = cv2.imdecode(nparr, 0)
img_np[img_np != 0] = 1
point_list = np.where(img_np != 0)
if len(point_list[0]) > 0:
y_list = point_list[0]
x_list = point_list[1]
indexs = '{'
for point_idx, x in enumerate(x_list):
raw_x = x
raw_y = y_list[point_idx]
map = {'x': raw_x, 'y': raw_y}
list.append(map)
indexs = indexs + '[%s,%s],' % (raw_x, raw_y)
indexs = indexs[:-1] + '}'
print('run time {:.5f}(s)'.format(time.time() - time_now))
return list, indexs, img_np
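# Hedged round-trip sketch (assumption): base64_to_list() strips a 22-character
# data-URL prefix such as 'data:image/png;base64,' before decoding, so this helper
# builds a compatible string from a tiny binary mask with OpenCV.
def example_base64_to_list():
    mask = np.zeros((8, 8), np.uint8)
    mask[2:5, 3:6] = 255  # a small 3x3 foreground patch
    ok, png_bytes = cv2.imencode('.png', mask)
    base64_str = 'data:image/png;base64,' + base64.b64encode(png_bytes.tobytes()).decode('ascii')
    points, indexs, img_np = base64_to_list(base64_str)
    return points, indexs, img_np  # points: [{'x': ..., 'y': ...}, ...]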
def meta_to_list(meta, img_np=np.zeros((1, 1))):
meta_dict = ast.literal_eval(meta)
x = meta_dict['x']
y = meta_dict['y']
w = meta_dict['w']
h = meta_dict['h']
delineation = meta_dict['delineation']
mask_str = ''
list = []
indexs = '{'
for e in delineation:
mask_str = mask_str + str((bin(((1 << 32) - 1) & int(e))[2:]).zfill(32))
for idx, s in enumerate(mask_str):
if s == '1':
x_index = x + int(idx % w)
y_index = y + int(idx / w)
if len(img_np)>1:
img_np[y_index][x_index] = 1
map = {'x': x_index, 'y': y_index}
list.append(map)
indexs = indexs + '[%s,%s],' % (x_index, y_index)
return list, indexs, img_np
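# Hedged sketch of the delineation format (assumption): each integer in 'delineation'
# is expanded to its 32-bit binary string (most-significant bit first) and the bits
# are laid out row-major inside a w x h box anchored at (x, y); a set bit marks one
# mask pixel.
def example_meta_to_list():
    # bin(3) zero-filled to 32 bits sets only bit indices 30 and 31.
    meta = "{'x': 10, 'y': 20, 'w': 32, 'h': 1, 'delineation': [3]}"
    points, indexs, _ = meta_to_list(meta)
    # points == [{'x': 40, 'y': 20}, {'x': 41, 'y': 20}]
    return points, indexs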
import torch
from torch.utils.data import Dataset
import sys
import os
import numpy as np
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
from cls_utils.data import load_data_csv, get_data_and_label
from cls_utils.utils import normalize, hu_value_to_uint8, hu_normalize
class SubjectDataset(Dataset):
def __init__(self, np_path, csv_path, is_train=True, is_2d=False, augment=False, permute=False):
self._npy_path = np_path
self._csv_path = csv_path
#self._cfg = cfg
self._is_train = is_train
self._is_2d = is_2d
self.augment = augment
self.permute = permute
self._preprocess()
def _preprocess(self):
self.subject_ids = load_data_csv(self._csv_path)
def __len__(self):
return self.subject_ids.shape[0]
def _get_data(self, i):
idx = i % len(self.subject_ids)
        # Get the npy file path and label for the given index
subject_id = self.subject_ids.iloc[idx, 0]
label = int(self.subject_ids.iloc[idx, 1])
data = get_data_and_label(self._npy_path, subject_id, self._is_2d, self.augment, self.permute)
#print(data.shape)
return data, np.asarray([label])
def __getitem__(self, idx):
data, label = self._get_data(idx)
if data is None or label is None:
data, label = self._get_data(0)
data = hu_normalize(data)
#data = normalize(hu_value_to_uint8(data))
return torch.from_numpy(data.copy()), torch.tensor(label).type(torch.float32)
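# Hedged usage sketch (paths are placeholders): SubjectDataset reads subject ids and
# labels from a csv and yields (volume_tensor, label_tensor) pairs, so it plugs
# directly into a torch DataLoader.
def example_subject_dataloader(np_path='/path/to/npy_data', csv_path='/path/to/train.csv'):
    from torch.utils.data import DataLoader
    dataset = SubjectDataset(np_path, csv_path, is_train=True, is_2d=False,
                             augment=False, permute=False)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
    for data, label in loader:
        print(data.shape, label.shape)  # e.g. [4, D, H, W] and [4, 1]
        break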
import copy
import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from functools import lru_cache
def custom_collate_fn_2d(batch):
label_id, npy_data_2d, label, z_index, rotate_count = zip(*batch)
label_id = [idx_label_id for sub_label_id in label_id for idx_label_id in sub_label_id]
z_index = [idx_z_index for sub_z_index in z_index for idx_z_index in sub_z_index]
rotate_count = [idx_rotate_count for sub_rotate_count in rotate_count for idx_rotate_count in sub_rotate_count]
npy_data_2d = torch.cat(npy_data_2d, dim=0)
label = torch.cat(label, dim=0)
return label_id, npy_data_2d, label, z_index, rotate_count
def custom_collate_fn_2d_error(batch):
pass
def custom_collate_fn_3d(batch):
label_id, npy_data_3d, label, z_index, rotate_count = zip(*batch)
label_id = [idx_label_id for sub_label_id in label_id for idx_label_id in sub_label_id]
z_index = [idx_z_index for sub_z_index in z_index for idx_z_index in sub_z_index]
rotate_count = [idx_rotate_count for sub_rotate_count in rotate_count for idx_rotate_count in sub_rotate_count]
npy_data_3d = torch.cat(npy_data_3d, dim=0)
label = torch.cat(label, dim=0)
return label_id, npy_data_3d, label, z_index, rotate_count
def custom_collate_fn_3d_error(batch):
pass
def custom_collate_fn_2d3d(batch):
label_id_2d, npy_data_2d, label_2d, z_index_2d, rotate_count_2d, label_id_3d, npy_data_3d, label_3d, z_index_3d, rotate_count_3d = zip(*batch)
label_id_2d = [idx_label_id for sub_label_id in label_id_2d for idx_label_id in sub_label_id]
label_id_3d = [idx_label_id for sub_label_id in label_id_3d for idx_label_id in sub_label_id]
z_index_2d = [idx_z_index for sub_z_index in z_index_2d for idx_z_index in sub_z_index]
z_index_3d = [idx_z_index for sub_z_index in z_index_3d for idx_z_index in sub_z_index]
rotate_count_2d = [idx_rotate_count for sub_rotate_count in rotate_count_2d for idx_rotate_count in sub_rotate_count]
rotate_count_3d = [idx_rotate_count for sub_rotate_count in rotate_count_3d for idx_rotate_count in sub_rotate_count]
npy_data_2d = torch.cat(npy_data_2d, dim=0)
npy_data_3d = torch.cat(npy_data_3d, dim=0)
label_2d = torch.cat(label_2d, dim=0)
label_3d = torch.cat(label_3d, dim=0)
return label_id_2d, npy_data_2d, label_2d, z_index_2d, rotate_count_2d, label_id_3d, npy_data_3d, label_3d, z_index_3d, rotate_count_3d
def custom_collate_fn_2d3d_error(batch):
pass
def custom_collate_fn_s3d(batch):
label_id, npy_data, label, z_index, rotate_count = zip(*batch)
label_id = [idx_label_id for sub_label_id in label_id for idx_label_id in sub_label_id]
z_index = [idx_z_index for sub_z_index in z_index for idx_z_index in sub_z_index]
rotate_count = [idx_rotate_count for sub_rotate_count in rotate_count for idx_rotate_count in sub_rotate_count]
npy_data = torch.cat(npy_data, dim=0)
label = torch.cat(label, dim=0)
return label_id, npy_data, label, z_index, rotate_count
def custom_collate_fn_s3d_error(batch):
pass
def custom_collate_fn_resnet3d(batch):
label_id, npy_data, label, z_index, rotate_count = zip(*batch)
label_id = [idx_label_id for sub_label_id in label_id for idx_label_id in sub_label_id]
z_index = [idx_z_index for sub_z_index in z_index for idx_z_index in sub_z_index]
rotate_count = [idx_rotate_count for sub_rotate_count in rotate_count for idx_rotate_count in sub_rotate_count]
npy_data = torch.cat(npy_data, dim=0)
label = torch.cat(label, dim=0)
return label_id, npy_data, label, z_index, rotate_count
def custom_collate_fn_resnet3d_error(batch):
pass
def custom_collate_fn_d2d(batch):
label_id, npy_data_2d, label, z_index, rotate_count = zip(*batch)
label_id = [idx_label_id for sub_label_id in label_id for idx_label_id in sub_label_id]
z_index = [idx_z_index for sub_z_index in z_index for idx_z_index in sub_z_index]
rotate_count = [idx_rotate_count for sub_rotate_count in rotate_count for idx_rotate_count in sub_rotate_count]
npy_data_2d = torch.cat(npy_data_2d, dim=0)
label = torch.cat(label, dim=0)
return label_id, npy_data_2d, label, z_index, rotate_count
def custom_collate_fn_d2d_error(batch):
pass
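# Hedged sketch with synthetic tensors (not real training data): each dataset item is
# a 5-tuple whose first and last elements are lists and whose tensors carry a leading
# sample dimension, so the collate functions flatten the lists and concatenate the
# tensors along dim 0.
def example_custom_collate_fn_2d():
    item_a = (['node1_1'], torch.zeros(1, 1, 256, 256), torch.tensor([0.0]), [24], [0])
    item_b = (['node2_7'], torch.ones(1, 1, 256, 256), torch.tensor([1.0]), [25], [1])
    label_id, npy_data_2d, label, z_index, rotate_count = custom_collate_fn_2d([item_a, item_b])
    # label_id == ['node1_1', 'node2_7'], npy_data_2d.shape == (2, 1, 256, 256),
    # label.shape == (2,), z_index == [24, 25], rotate_count == [0, 1]
    return label_id, npy_data_2d, label, z_index, rotate_count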
class ClassificationDataset2d(Dataset):
def __init__(self, csv_file, node_column="node", label_id_column="label_id", npy_file_column="npy_file", class_column="label", z_index_column = "z_index", rotate_column = "rotate_count", data_info="train_dataset"):
self.df = pd.read_csv(csv_file, header=0, encoding="utf-8")
print(f"{data_info} lens: {len(self.df)}")
self.node_column = node_column
self.label_id_column = label_id_column
self.npy_file_column = npy_file_column
self.class_column = class_column
self.z_index_column = z_index_column
self.rotate_column = rotate_column
self.lens = len(self.df)
def __len__(self):
return self.lens
@lru_cache(maxsize=10)
def load_npy_data(self, npy_data_path):
data = np.load(npy_data_path) # [256, 256]
data = data[np.newaxis, np.newaxis, :, :] # [1, 1,256, 256]
return data
def __getitem__(self, idx):
npy_file = self.df.loc[idx, self.npy_file_column]
node = self.df.loc[idx, self.node_column]
label_id = self.df.loc[idx, self.label_id_column]
label = self.df.loc[idx, self.class_column]
z_index = self.df.loc[idx, self.z_index_column]
rotate_count = self.df.loc[idx, self.rotate_column]
npy_data = self.load_npy_data(npy_file)
label_id_str = f"{node}_{label_id}"
data_2d = torch.from_numpy(npy_data).to(torch.float32)
label = torch.tensor([label], dtype=torch.float32)
return [label_id_str], data_2d, label, [z_index], [rotate_count]
class ClassificationDatasetError2d(Dataset):
pass
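# Hedged usage sketch (the csv path is a placeholder): pair ClassificationDataset2d
# with its matching collate function so the per-item [1, 1, 256, 256] tensors become a
# [B, 1, 256, 256] batch.
def example_classification_2d_loader(csv_file='/path/to/train_2d.csv'):
    from torch.utils.data import DataLoader
    dataset = ClassificationDataset2d(csv_file, data_info="train_dataset")
    loader = DataLoader(dataset, batch_size=8, shuffle=True,
                        collate_fn=custom_collate_fn_2d)
    for label_id, data_2d, label, z_index, rotate_count in loader:
        print(len(label_id), data_2d.shape, label.shape)
        break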
class ClassificationDataset3d(Dataset):
def __init__(self, csv_file, node_column="node", label_id_column="label_id", npy_file_column="npy_file", class_column="label", z_index_column = "z_index", rotate_column = "rotate_count", data_info="train_dataset"):
self.df = pd.read_csv(csv_file, header=0, encoding="utf-8")
print(f"{data_info} lens: {len(self.df)}")
self.node_column = node_column
self.label_id_column = label_id_column
self.npy_file_column = npy_file_column
self.class_column = class_column
self.z_index_column = z_index_column
self.rotate_column = rotate_column
self.lens = len(self.df)
def __len__(self):
return self.lens
@lru_cache(maxsize=10)
def load_npy_data(self, npy_data_path):
data = np.load(npy_data_path) # [48, 256, 256]
print(f"npy_data_path: {npy_data_path}, data.shape: {data.shape}")
data = data[np.newaxis, np.newaxis, :, :, :] # [1, 1, 48, 256, 256]
return data
def __getitem__(self, idx):
npy_file = self.df.loc[idx, self.npy_file_column]
node = self.df.loc[idx, self.node_column]
label_id = self.df.loc[idx, self.label_id_column]
label = self.df.loc[idx, self.class_column]
z_index = self.df.loc[idx, self.z_index_column]
rotate_count = self.df.loc[idx, self.rotate_column]
npy_data = self.load_npy_data(npy_file)
label_id_str = f"{node}_{label_id}"
data_3d = torch.from_numpy(npy_data).to(torch.float32)
label = torch.tensor([label], dtype=torch.float32)
return [label_id_str], data_3d, label, [z_index], [rotate_count]
class ClassificationDatasetError3d(Dataset):
pass
class ClassificationDataset2d3d(Dataset):
def __init__(self, data_2d_csv_file, data_3d_csv_file, label_id_column="label_id", npy_file_column="npy_file", class_column="label", node_column="node", z_index_column = "z_index", rotate_column = "rotate_count", data_info="train_dataset"):
self.df_2d = pd.read_csv(data_2d_csv_file, header=0, encoding="utf-8")
self.df_3d = pd.read_csv(data_3d_csv_file, header=0, encoding="utf-8")
print(f"{data_info} lens: {len(self.df_2d)}, {len(self.df_3d)}")
self.label_id_column = label_id_column
self.npy_file_column = npy_file_column
self.node_column = node_column
self.class_column = class_column
self.z_index_column = z_index_column
self.rotate_column = rotate_column
self.lens_data_2d = len(self.df_2d)
self.lens_data_3d = len(self.df_3d)
self.lens = max(self.lens_data_2d, self.lens_data_3d)
def __len__(self):
return self.lens
@lru_cache(maxsize=10)
def load_npy_data(self, npy_data_path):
print(f"npy_data_path: {npy_data_path}")
return np.load(npy_data_path)
def __getitem__(self, idx):
index_data_2d = idx if idx < self.lens_data_2d else idx % self.lens_data_2d
index_data_3d = idx if idx < self.lens_data_3d else idx % self.lens_data_3d
npy_data_path_2d = self.df_2d.loc[index_data_2d, self.npy_file_column]
npy_data_path_3d = self.df_3d.loc[index_data_3d, self.npy_file_column]
node_2d = self.df_2d.loc[index_data_2d, self.node_column]
node_3d = self.df_3d.loc[index_data_3d, self.node_column]
label_id_2d = self.df_2d.loc[index_data_2d, self.label_id_column]
label_id_3d = self.df_3d.loc[index_data_3d, self.label_id_column]
label_2d = self.df_2d.loc[index_data_2d, self.class_column]
label_3d = self.df_3d.loc[index_data_3d, self.class_column]
data_2d_z_index = self.df_2d.loc[index_data_2d, self.z_index_column]
data_2d_rotate_count = self.df_2d.loc[index_data_2d, self.rotate_column]
data_3d_z_index = self.df_3d.loc[index_data_3d, self.z_index_column]
data_3d_rotate_count = self.df_3d.loc[index_data_3d, self.rotate_column]
npy_data_2d = self.load_npy_data(npy_data_path_2d)
npy_data_3d = self.load_npy_data(npy_data_path_3d)
npy_data_2d = npy_data_2d[np.newaxis, np.newaxis, :, :] # [1, 1, 256, 256]
npy_data_3d = npy_data_3d[np.newaxis, np.newaxis, :, :, :] # [1, 1, 48, 256, 256]
label_id_2d_str = f"{node_2d}_{label_id_2d}"
label_id_3d_str = f"{node_3d}_{label_id_3d}"
data_2d = torch.from_numpy(npy_data_2d).to(torch.float32)
data_3d = torch.from_numpy(npy_data_3d).to(torch.float32)
label_2d = torch.tensor([label_2d], dtype=torch.float32)
label_3d = torch.tensor([label_3d], dtype=torch.float32)
return [label_id_2d_str], data_2d, label_2d, [data_2d_z_index], [data_2d_rotate_count], [label_id_3d_str], data_3d, label_3d, [data_3d_z_index], [data_3d_rotate_count]
class ClassificationDatasetError2d3d(Dataset):
pass
class ClassificationDatasetS3d(Dataset):
def __init__(self, csv_file, node_column="node", label_id_column="label_id", npy_file_column="npy_file", class_column="label", z_index_column = "z_index", rotate_column = "rotate_count", data_info="train_dataset"):
self.df = pd.read_csv(csv_file, header=0, encoding="utf-8")
print(f"{data_info} lens: {len(self.df)}")
self.node_column = node_column
self.label_id_column = label_id_column
self.npy_file_column = npy_file_column
self.class_column = class_column
self.z_index_column = z_index_column
self.rotate_column = rotate_column
self.lens = len(self.df)
def __len__(self):
return self.lens
@lru_cache(maxsize=10)
def load_npy_data(self, npy_data_path):
npy_data = np.load(npy_data_path)
print(f"npy_data_path: {npy_data_path}, npy_data.shape: {npy_data.shape}")
npy_data = npy_data[np.newaxis, np.newaxis, :, :, :] # [1, 1, 48, 256, 256]
return npy_data
def __getitem__(self, idx):
npy_file = self.df.loc[idx, self.npy_file_column]
node = self.df.loc[idx, self.node_column]
label_id = self.df.loc[idx, self.label_id_column]
label = self.df.loc[idx, self.class_column]
z_index = self.df.loc[idx, self.z_index_column]
rotate_count = self.df.loc[idx, self.rotate_column]
npy_data = self.load_npy_data(npy_file)
label_id_str = f"{node}_{label_id}"
data_3d = torch.from_numpy(npy_data).to(torch.float32)
label = torch.tensor([label], dtype=torch.float32)
return [label_id_str], data_3d, label, [z_index], [rotate_count]
class ClassificationDatasetErrorS3d(Dataset):
pass
class ClassificationDatasetResnet3d(Dataset):
def __init__(self, csv_file, node_column="node", label_id_column="label_id", npy_file_column="npy_file", class_column="label", z_index_column = "z_index", rotate_column = "rotate_count", data_info="train_dataset"):
self.df = pd.read_csv(csv_file, header=0, encoding="utf-8")
print(f"{data_info} lens: {len(self.df)}")
self.node_column = node_column
self.label_id_column = label_id_column
self.npy_file_column = npy_file_column
self.class_column = class_column
self.z_index_column = z_index_column
self.rotate_column = rotate_column
self.lens = len(self.df)
def __len__(self):
return self.lens
@lru_cache(maxsize=10)
def load_npy_data(self, npy_data_path):
npy_data = np.load(npy_data_path)
npy_data = npy_data[np.newaxis, np.newaxis, :, :, :] # [1, 1, 48, 256, 256]
return npy_data
def __getitem__(self, idx):
npy_file = self.df.loc[idx, self.npy_file_column]
node = self.df.loc[idx, self.node_column]
label_id = self.df.loc[idx, self.label_id_column]
label = self.df.loc[idx, self.class_column]
z_index = self.df.loc[idx, self.z_index_column]
rotate_count = self.df.loc[idx, self.rotate_column]
npy_data = self.load_npy_data(npy_file)
label_id_str = f"{node}_{label_id}"
data_3d = torch.from_numpy(npy_data).to(torch.float32)
label = torch.tensor([label], dtype=torch.float32)
return [label_id_str], data_3d, label, [z_index], [rotate_count]
class ClassificationDatasetErrorResnet3d(Dataset):
pass
class ClassificationDatasetD2d(Dataset):
def __init__(self, csv_file, node_column="node", label_id_column="label_id", npy_file_column="npy_file", class_column="label", z_index_column = "z_index", rotate_column = "rotate_count", data_info="train_dataset"):
self.df = pd.read_csv(csv_file, header=0, encoding="utf-8")
print(f"{data_info} lens: {len(self.df)}")
self.node_column = node_column
self.label_id_column = label_id_column
self.npy_file_column = npy_file_column
self.class_column = class_column
self.z_index_column = z_index_column
self.rotate_column = rotate_column
self.lens = len(self.df)
def __len__(self):
return self.lens
@lru_cache(maxsize=10)
def load_npy_data(self, npy_data_path):
npy_data = np.load(npy_data_path) # [3, 256, 256]
print(f"npy_data_path: {npy_data_path}, npy_data.shape: {npy_data.shape}")
npy_data = npy_data[np.newaxis, :, :, :] # [1, 3, 256, 256]
return npy_data
def __getitem__(self, idx):
npy_file = self.df.loc[idx, self.npy_file_column]
node_time = self.df.loc[idx, self.node_column]
label_id = self.df.loc[idx, self.label_id_column]
label = self.df.loc[idx, self.class_column]
z_index = self.df.loc[idx, self.z_index_column]
rotate_count = self.df.loc[idx, self.rotate_column]
npy_data = self.load_npy_data(npy_file)
label_id_str = f"{node_time}_{label_id}"
data_2d = torch.from_numpy(npy_data).to(torch.float32)
label = torch.tensor([label], dtype=torch.float32)
return [label_id_str], data_2d, label, [z_index], [rotate_count]
class ClassificationDatasetErrorD2d(Dataset):
pass
def cls_report_dict_to_string(report_dict):
lines = []
# print(f"cls_report_dict_to_string, report_dict: {report_dict}")
    # Per-class metric rows
for label, metrics in report_dict.items():
if isinstance(metrics, dict):
line = f"{label:<15} {metrics['precision']:<10.2f} {metrics['recall']:<10.2f} {metrics['f1-score']:<10.2f} {metrics['support']:<10.2f}\n"
lines.append(line)
    # Total support across all classes
total_support = sum(metrics['support'] for metrics in report_dict.values() if isinstance(metrics, dict))
    # Overall accuracy row
accuracy_line = f"{'accuracy':<15} {'':<10} {'':<10} {report_dict['accuracy']:<10.2f} {total_support:<10.2f}\n"
lines.append(accuracy_line)
    # macro avg and weighted avg rows
for key in ['macro avg', 'weighted avg']:
metrics = report_dict[key]
line = f"{key:<15} {metrics['precision']:<10.2f} {metrics['recall']:<10.2f} {metrics['f1-score']:<10.2f} {metrics['support']:<10.2f}\n"
lines.append(line)
line_first = f"{'':<15} {'precision':<10} {'recall':<10} {'f1-score':<10} {'support':<10}\n"
lines.insert(0, line_first)
return ''.join(lines)
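# Hedged usage sketch (assumes scikit-learn is available; this module does not import
# it): classification_report(..., output_dict=True) produces the nested dict layout
# that cls_report_dict_to_string() pretty-prints.
def example_cls_report_to_string():
    from sklearn.metrics import classification_report
    y_true = [0, 0, 1, 1, 1, 0]
    y_pred = [0, 1, 1, 1, 0, 0]
    report_dict = classification_report(y_true, y_pred, output_dict=True)
    return cls_report_dict_to_string(report_dict)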
from sqlalchemy import create_engine, and_
from sqlalchemy.orm import sessionmaker, scoped_session
import sys
import os
import numpy as np
import argparse
import threading
import pandas as pd
import re
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
from data.domain import *
from cls_utils.sitk_utils import CTSeries
from cls_utils.data_utils import crop_ct_data, get_crop_data_padding, get_crop_data_2d
from cls_utils.utils import hu_value_to_uint8, normalize, base64_to_list
from cls_utils.data import save_supplement_data_csv, save_data_to_npy, load_npy_to_data, create_cls_train_csv, \
load_all_dicom_file, load_json, create_cls_train_all_csv, create_cls_train_csv_3d, \
replace_label_ids, add_label_ids, create_cls_train_last_3d
MYSQL_SERVER = 'mysql+pymysql://lung:lung1qaz2wsx@127.0.0.1:3306/ct_file?charset=utf8'
# cfg = load_json("/home/lung/ai-project/cls_train/config/train.json")
# cfg = load_json("/df_lung/ai-project/cls_train/config/train.json")
cfg = load_json("/df_lung/ai-project/cls_train/config/train_20241112.json")
"""
连接数据库,返回一个session
"""
def conect_mysql():
engine = create_engine(MYSQL_SERVER, pool_recycle=3600)
#onnection = engine.connect()
db_session = sessionmaker(bind=engine)
session = scoped_session(db_session)
return session
def select_series(session, select_node_time=0, start_label_id=None):
"""
Describe:
在打标签的数据中根据对应的series_instance_uid查找该文件对应的文件夹名
Returns:
folder_name: 当前标签数据所在dicom文件对应的文件夹
node_time: 当前标签数据对应的label
box_info: 标签数据的平面坐标信息
select_box: shape=(n, 3, 2) 数据存储为[[z_min, z_max], [y_min, y_max], [x_min, x_max]]
"""
    # Join the two tables directly for the query.
    # If select_node_time == 0, fetch every row that has a class label; otherwise fetch only rows with node_time == select_node_time.
if select_node_time == 0:
userlabel_and_dicomseries = session.query(UserLabel, DicomSeries).filter(
and_(UserLabel.node_time != None,
UserLabel.deleted_time == None,
UserLabel.node_time != select_node_time,
UserLabel.series_id == DicomSeries.id))
else:
userlabel_and_dicomseries = session.query(UserLabel, DicomSeries).filter(
and_(UserLabel.node_time != None,
UserLabel.deleted_time == None,
UserLabel.node_time == select_node_time,
UserLabel.series_id == DicomSeries.id))
userlabel_and_dicomseries_list = list(map(lambda x: x, userlabel_and_dicomseries))
label_ids = []
box_infos = []
z_index_ranges = []
node_times = []
folder_names = []
select_boxs = []
patient_ids = []
series_instance_uids = []
for userlabel_dicomseries in userlabel_and_dicomseries_list:
label_ids.append(userlabel_dicomseries[0].id)
box_infos.append(list(map(lambda x: int(x), userlabel_dicomseries[0].box_info.strip('[]').split(','))))
z_index_ranges.append(list(map(lambda x: int(x.split(':')[0]), userlabel_dicomseries[0].area.strip('{}').split(','))))
node_times.append(userlabel_dicomseries[0].node_time)
folder_names.append(userlabel_dicomseries[1].folder_name)
patient_ids.append(userlabel_dicomseries[1].patient_id)
series_instance_uids.append(userlabel_dicomseries[1].series_instance_uid)
select_boxs.append([[min(z_index_ranges[-1]), max(z_index_ranges[-1])], [box_infos[-1][2], box_infos[-1][3]], [box_infos[-1][0], box_infos[-1][1]]])
select_boxs = np.array(select_boxs, np.float32)
return folder_names, node_times, select_boxs, box_infos, label_ids, patient_ids, series_instance_uids
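# Hedged parsing sketch (the sample strings are made up, but follow the layout that
# select_series() expects): box_info is '[x_min,x_max,y_min,y_max]' and area is a
# '{z:...,z:...}' mapping whose keys are slice indices, so each select_box comes out
# as [[z_min, z_max], [y_min, y_max], [x_min, x_max]].
def example_parse_label_row(box_info_str='[120,180,200,260]', area_str='{35:0,36:0,37:0}'):
    box_info = list(map(int, box_info_str.strip('[]').split(',')))
    z_indexes = [int(part.split(':')[0]) for part in area_str.strip('{}').split(',')]
    select_box = [[min(z_indexes), max(z_indexes)],
                  [box_info[2], box_info[3]],
                  [box_info[0], box_info[1]]]
    return select_box  # [[35, 37], [200, 260], [120, 180]]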
# Query all rows in the user_label_delineation table for the given label_id
def select_all_contours_by_labelId(session, label_id):
delineations = session.query(UserLabelDelineation).filter(
and_(UserLabelDelineation.label_id == label_id,)).order_by(UserLabelDelineation.z_index)
contours = list(map(lambda x: x.contour, delineations))
return contours
# Fetch all data for the given node_time from the database, together with the corresponding label_ids
def select_series_by_node_time(node_time=0):
session = conect_mysql()
folder_names, _, select_boxs, _, label_ids, patient_ids, series_instance_uids= select_series(session, node_time)
return folder_names, select_boxs, label_ids, patient_ids, series_instance_uids
# Read a single record from the database, mainly for prediction
def select_signal_series(label_id):
session = conect_mysql()
userlabel_and_dicomseries = session.query(UserLabel, DicomSeries).filter(
and_(UserLabel.id == label_id,
UserLabel.series_id == DicomSeries.id))
userlabel_and_dicomseries = list(map(lambda x: x, userlabel_and_dicomseries))
box_infos = list(map(lambda x: int(x), userlabel_and_dicomseries[0][0].box_info.strip('[]').split(',')))
z_index_range = list(map(lambda x: int(x.split(':')[0]), userlabel_and_dicomseries[0][0].area.strip('{}').split(',')))
z_min = min(z_index_range)
z_max = max(z_index_range)
folder_name = userlabel_and_dicomseries[0][1].folder_name
series_instance_uid = userlabel_and_dicomseries[0][1].series_instance_uid
select_box = [[z_min, z_max],
[box_infos[2], box_infos[3]],
[box_infos[0], box_infos[1]]]
patient_id = userlabel_and_dicomseries[0][1].patient_id
return folder_name, select_box, patient_id, series_instance_uid
def read_series_dicom(dicom_folder=''):
ct = CTSeries()
ct.load_dicoms(dicom_folder)
return ct
# Process a single sample. If mode is None, every slice of the nodule is taken in turn as the center slice and padded with background above and below; otherwise the data is cropped directly from the original volume.
# If is_2d is True, only the current slice is cropped.
def process_single_ct(ct_data=None, select_box=None, node_time=None, label_id=None, mode=None, is_2d=False):
#dicom_folder = os.path.join(cfg['dicom_folder'], folder_name)
    # Use folder_name and label_id to check whether the file already exists; skip creation if it does
npy_output_path = os.path.join(cfg['train_data_path'], cfg['npy_folder'], 'cls_' + str(node_time), str(label_id)+'.npy')
csv_output_path = os.path.join(cfg['train_data_path'], cfg['csv_path'], cfg['subject_all_csv'])
if os.path.exists(npy_output_path):
print(f"npy_output_path exists: {npy_output_path}")
return
#ct_data = read_series_dicom(dicom_folder=dicom_folder)
    # If mask-based processing is used, ct_data has already been processed accordingly
if mode is None:
if is_2d:
original_data = get_crop_data_2d(data=ct_data, select_box=select_box, crop_size=cfg['train_crop_size_2d'])
else:
original_data = get_crop_data_padding(ct_data=ct_data, select_box=select_box, crop_size=cfg['train_crop_size'])
else:
        # Expand around select_box using the original data: take extra slices above and below when they exist, and only pad with -1000 slices otherwise
original_data = crop_ct_data(ct_data=ct_data, select_box=select_box, crop_size=cfg['train_crop_size'])
    # Save the data
save_data_to_npy(original_data=original_data, output_file=npy_output_path)
content = ['cls_' + str(node_time)+'/'+str(label_id)+'.npy', str(node_time)]
    # Append this record to the overall dataset csv
save_supplement_data_csv(content=content, output_file=csv_output_path)
# Take every slice of the nodule as the center slice and pad it to (48, 256, 256). With seg=True the segmentation result is used: only the nodule is kept and all other pixels are set to the background value.
# If mode is None, every segmented slice of the nodule is treated as the center slice.
# If is_2d is True, the nodule region of each slice is saved as a [256, 256] array.
def process_single_ct_all_series(contours=None, folder_name=None, select_box=None, node_time=None,
label_id=None, mode=None, seg=False, is_2d=False, threshold=0):
z_min = int(select_box[0][0])
z_max = int(select_box[0][1])
profile_id = str(label_id) + '_'
print("label_id: ", label_id)
dicom_folder = os.path.join(cfg['dicom_folder'], folder_name)
ct_data = read_series_dicom(dicom_folder=dicom_folder)
if contours is not None:
data = ct_data.get_raw_image()
img_np = np.zeros((data.shape[0], data.shape[1], data.shape[2]))
for i in range(z_max-z_min+1):
_, _, img = base64_to_list(contours[i])
img_np[z_min+i] = img
data[img_np == 0] = -1000
ct_data.set_raw_image = data
    # Extract the nodule directly from the segmented image data
if mode is not None:
process_single_ct(ct_data=ct_data, select_box=select_box, node_time=node_time, label_id=label_id, mode=mode, is_2d=False)
else:
        # Exclude a proportion of slices at both ends of the nodule, controlled by threshold
eliminate_num = int((threshold * (z_max - z_min + 1)) / 2)
        # Treat every slice containing the nodule as the center slice
for z_index in range(z_min + eliminate_num, z_max - eliminate_num + 1):
label_id = profile_id + str(z_index)
select_box[0] = [z_index, z_index]
            print('file name:', label_id)
            # Process a single select_box
if is_2d:
data_2d = data[z_index]
process_single_ct(ct_data=data_2d, select_box=select_box, node_time=node_time, label_id=label_id, mode=mode, is_2d=True)
else:
process_single_ct(ct_data=ct_data, select_box=select_box, node_time=node_time, label_id=label_id, mode=mode)
def get_all_contours_by_labelId(label_id=None):
session = conect_mysql()
contours = select_all_contours_by_labelId(session, label_id=label_id)
return contours
def process_cts(seg=False, node_time=0, start_label_id=None, mode=None, is_2d=False, threshold=0):
"""
    mode=None means every slice is taken as the center slice and the other layers are filled with background
    seg=True means the segmentation mask is used to keep only the nodule and fill everything else with the background value -1000
    is_2d=True means the data is saved with size [256, 256]
    threshold is the proportion of existing nodule slices to discard
"""
session = conect_mysql()
folder_names, node_times, select_boxs, _, label_ids, patient_ids, series_instance_uids = select_series(session, select_node_time=node_time)
for i in range(len(folder_names)):
        # Only generate data for label_ids after the specified start_label_id
# if label_ids[i] > start_label_id:
#if label_ids[i] > start_label_id:
if label_ids[i] in [5695, "5695"]:
            # Fetch the mask (contour) info for this label_id from the database
contours = select_all_contours_by_labelId(session, label_ids[i]) if seg else None
folder_name = str(patient_ids[i])+'-'+str(series_instance_uids[i])
process_single_ct_all_series(contours=contours, folder_name=folder_name, select_box=select_boxs[i],
node_time=node_times[i], label_id=label_ids[i], mode='3d', seg=True, is_2d=is_2d, threshold=threshold)
# Fetch the dataset information from the database and save it to the corresponding files
def run():
"""
连接数据库,从数据库中获取到所有的要作为数据集的信息,
包含了所有的文件名:folder_name、节点类别:node_time、节点所在的空间坐标:select_boxs、当前节点在数据库中所对应的id:label_id
"""
session = conect_mysql()
folder_names, node_times, select_boxs, box_infos, label_ids = select_series(session)
    # Quick check that a single sample is processed correctly
    #process_single_ct_all_series(folder_name=folder_names[0], select_box=select_boxs[0], node_time=node_times[0], label_id=label_ids[0])
    # Take each slice separately, pad it and generate the corresponding npy file
for i in range(len(folder_names)):
        # Generate training data only for entries after the given label_id
        if label_ids[i] > 517:
            # Fetch the mask information from the database
#process_single_ct_all_series(folder_name=folder_names[i], select_box=select_boxs[i], node_time=node_times[i], label_id=label_ids[i])
process_single_ct(folder_name=folder_names[i], select_box=select_boxs[i], node_time=node_times[i], label_id=label_ids[i], mode=1)
    # Every entry has been processed, extracted and saved to the corresponding file
    print('done')
# Generate the train csv for the training data
def creat_train_csv():
"""
AAH1: 2011
AIS1: 2021
MIA1: 2031
IAC1: 2041
炎症: 1010
高密度炎症1: 1016
增生: 1020
高密度IAC1: 2046
[AAH1, MIA1, IAC1]: 2_134
[炎症, 高密度炎症1, 增生]: 1_112
"""
#[2041, 3001, 4001, 5001, 6001, 7006, 2060, 2061, 2062]
# node_times_1 = [[2011, 2021, 2041]]
node_times_2 = [[2041]]
node_times_1 = [[2041]]
csv_path = os.path.join(cfg['train_data_path'], cfg['csv_path'])
#pretrain_csv_path = '/home/lung/project/ai-project/cls_train/data/train_data/plus_3d_0818/subject_all_csv/test/cls_1_5001-6001_1/train.csv'
for time_1 in node_times_1:
for time_2 in node_times_2:
#create_cls_train_csv(node_times=[time_1, time_2], node2='', csv_path=csv_path, csv_name=cfg['subject_all_csv'], tabel_id='01_2', node1_end=False, node2_end=False)
#create_cls_train_all_csv(node_times=[time_1, time_2], csv_path=csv_path, csv_name=cfg['subject_all_csv'], tabel_id='10_3')
# create_cls_train_csv_3d(node_times=[time_1, time_2], csv_path=csv_path, csv_name=cfg['subject_all_csv'], tabel_id='08')
create_cls_train_csv_3d(node_times=[time_1, time_2], csv_path=csv_path, csv_name=cfg['subject_all_csv'], tabel_id='08')
#create_cls_train_last_3d(node_times=[time_1, time_2], csv_path=csv_path, csv_name=cfg['subject_all_csv'], tabel_id='test', pretrain_csv_path=pretrain_csv_path)
#node_times = [[1080], [2048]]
#csv_path = os.path.join(cfg['train_data_path'], cfg['csv_path'])
#create_cls_train_csv(node_times=node_times, node2='', csv_path=csv_path, csv_name=cfg['subject_all_csv'], tabel_id='06', node1_end=False, node2_end=False)
# Extract the label_ids that were misclassified from a log file
def extract_error_label(log_path, positive=True):
with open(log_path, 'r') as f:
log_contents = f.readlines()
label_ids = []
    # Scan each line and find the averaged result line of each nodule
for line in log_contents:
match = re.search(r"label_id: (.*?), result: \[(.*?)\]\n", line)
if match:
label_id, result = match.group(1), match.group(2)
result = float(result)
if positive and result < 0.5:
label_ids.append(label_id)
elif positive is False and result > 0.5:
label_ids.append(label_id)
return label_ids
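# Illustrative sketch (added for clarity, not part of the original pipeline): extract_error_label
# expects validation-log lines of the form "label_id: <id>, result: [<score>]", as implied by the
# regular expression above; positive samples scoring below 0.5 (or negatives above 0.5) are reported.
def _demo_extract_error_label():
    import tempfile, os
    lines = [
        "label_id: 101, result: [0.73]\n",
        "label_id: 102, result: [0.31]\n",
    ]
    with tempfile.NamedTemporaryFile('w', suffix='.log', delete=False) as f:
        f.writelines(lines)
        tmp_path = f.name
    # For positive samples only label_id 102 (score < 0.5) is treated as an error
    print(extract_error_label(tmp_path, positive=True))
    os.remove(tmp_path)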
def test_read_npy():
    # Read npy files, inspect their value ranges and print the statistics
    #npy_path = 'D:\\vscode\\plus代码\\ct_plus_seg_python\\cls_train\\data\\npy_data\\cls_1010\\379.npy'
    npy_path = './cls_train/data/npy_data/cls_1010'
    # Iterate over all npy files in the directory and print the maximum value of each
    all_dicom_file = load_all_dicom_file(npy_path, prefix='*', postfix='npy')
    # Load every npy file and report its maximum and minimum values
for file_path in all_dicom_file:
original_data = load_npy_to_data(file_path)
print(original_data.shape)
#print(original_data[0])
image = hu_value_to_uint8(original_data=original_data)
data = normalize(image=image)
#print(image)
#print(data)
print(np.max(original_data), np.max(image), np.max(data))
print(np.min(original_data), np.min(image), np.min(data))
def test_single_dicom():
    '''
    Reproduce the single failing case
    '''
folder_name = 'CT00100632-1.2.840.113704.1.111.2748.1295244861.11'
select_box = [[79, 98],
[669, 784],
[452, 561]]
node_time = 2040
label_id = 359
select_box = np.array(select_box, np.float32)
process_single_ct(folder_name=folder_name, select_box=select_box, node_time=node_time, label_id=label_id)
def test_count_csv():
csv_path = '/home/lung/project/ai-project/cls_train/data/train_data/plus_0815/subject_all_csv/02/cls_1010_2021/cls_1010_2021_1/train.csv'
df = pd.read_csv(csv_path, names=['id', 'label'])
select_rows = df[df['label'] == 1].drop_duplicates()
print(len(select_rows))
if __name__ == "__main__":
    # Generate the data with multiple threads
"""node_times = [2031]
threads = []
for node_time in node_times:
thread = threading.Thread(target=process_cts, args=(True, node_time, 427))
#process_cts(seg=True, node_time=node_time, start_label_id=427, is_2d=False)
threads.append(thread)
for thread in threads:
thread.start()
for thread in threads:
thread.join()"""
process_cts(seg=True, node_time=[2041], start_label_id=427, is_2d=False)
# creat_train_csv()
# test_count_csv()
"""log_path = '/home/lung/ai-project/cls_train/log/validation/cls_1_2031/20240720/log_validation_1.log'
positive = True
label_ids = extract_error_label(log_path, positive=positive)
print(len(label_ids))
csv_path = os.path.join(cfg['train_data_path'], cfg['csv_path'])
# #replace_label_ids(label_ids=label_ids, csv_path=csv_path, tabel_id='test')
add_label_ids(label_ids=label_ids, csv_path=csv_path, positive=positive, cls_name='cls_1_2031_02')"""
import sys, os
import pathlib
current_dir = pathlib.Path(__file__).parent.resolve()
while "cls_train" != current_dir.name:
current_dir = current_dir.parent
sys.path.append(current_dir.as_posix())
from cls_utils.log_utils import get_logger
from sqlalchemy import create_engine, and_
from sqlalchemy.orm import sessionmaker, scoped_session
import numpy as np
import argparse
import threading
from tqdm import tqdm
import pandas as pd
from datetime import datetime
import json
import re
from pathlib import Path
import scipy
import SimpleITK as sitk
from joblib import Parallel, delayed
from scipy.ndimage import rotate as scipy_rotate
import torch
import torchio as tio
import torch.nn.functional as F
from multiprocessing import Process
from torchvision import transforms as T
import cupy as cp
import radiomics
from cupyx.scipy.ndimage import rotate as cupy_rotate
from data.domain import DicomStudy, PatientInfo, UserLabel, UserLabelDelineation, DicomSeries
from data.data_process_utils.test_sitk_utils import CTSeries, base64_to_list, meta_to_list
from PIL import Image
from sklearn.preprocessing import StandardScaler
from collections import defaultdict
logger = get_logger(log_file="/df_lung/ai-project/cls_train/log/data/get_db_data_to_feat.log")
# from cls_utils.sitk_utils import CTSeries
# from cls_utils.data_utils import crop_ct_data, get_crop_data_padding, get_crop_data_2d
# from cls_utils.utils import hu_value_to_uint8, normalize, base64_to_list
# from cls_utils.data import save_supplement_data_csv, save_data_to_npy, load_npy_to_data, create_cls_train_csv, \
# load_all_dicom_file, load_json, create_cls_train_all_csv, create_cls_train_csv_3d, \
# replace_label_ids, add_label_ids, create_cls_train_last_3d
MYSQL_SERVER = 'mysql+pymysql://lung:lung1qaz2wsx@127.0.0.1:3306/ct_file?charset=utf8'
"""
连接数据库,返回一个session
"""
def conect_mysql():
engine = create_engine(MYSQL_SERVER, pool_recycle=3600)
    #connection = engine.connect()
db_session = sessionmaker(bind=engine)
session = scoped_session(db_session)
return session
def get_cts(dicom_path=None):
cts = CTSeries()
cts.load_dicoms(dicom_path)
return cts
def generate_node_all_label_id_df(node_time=None):
    '''
    Query conditions:
        1. Visible in the system
            dicom_file_study.status != 5
            patient_info.status != 1
        2. The annotation status is normal
            user_label.status != 1
    Join conditions:
        user_label.study_id = dicom_file_study.id
        # user_label.pid == Null
        dicom_file_study.patient_info_id = patient_info.id
    Query steps:
        1. Query all user_label rows first, then filter
        2. Filter using dicom_file_study and patient_info
        3. Filter conditions:
            user_label.study_id = dicom_file_study.id
            dicom_file_study.patient_info_id = patient_info.id
            dicom_file_study.status != 5
            patient_info.status != 1
            user_label.status != 1
            user_label.deleted_time == None
            user_label.node_time == node_time
    Returns:
        label_ids: all label_id values
    '''
if node_time is None:
return None
session = conect_mysql()
logger.info(f"start query")
query = session.query(
UserLabel.node_time,
UserLabel.id,
PatientInfo.patient_id,
UserLabel.study_id,
UserLabel.series_id,
DicomStudy.study_uid,
DicomStudy.folder_name,
DicomSeries.series_instance_uid
).join(
DicomStudy, UserLabel.study_id == DicomStudy.id
).join(
PatientInfo, DicomStudy.patient_info_id == PatientInfo.id
).join(
DicomSeries, UserLabel.series_id == DicomSeries.id
).filter(
and_(
DicomStudy.status != 5,
PatientInfo.status != 1,
UserLabel.status != 1,
UserLabel.deleted_time == None,
UserLabel.node_time == node_time
)
)
result = query.all()
node_times = [row[0] for row in result]
label_ids = [row[1] for row in result]
patient_ids = [row[2] for row in result]
study_ids = [row[3] for row in result]
series_ids = [row[4] for row in result]
study_uids = [row[5] for row in result]
folder_names = [row[6] for row in result]
series_instance_uids = [row[7] for row in result]
session.close()
df = pd.DataFrame({'node_time': node_times, 'label_id': label_ids, 'patient_id': patient_ids, 'study_id': study_ids, 'series_id': series_ids, 'study_uid': study_uids, 'folder_name': folder_names, 'series_instance_uid': series_instance_uids})
df["patient_id"] = df["patient_id"].astype(str)
df["study_id"] = df["study_id"].astype(str)
df["series_id"] = df["series_id"].astype(str)
df["study_uid"] = df["study_uid"].astype(str)
df["folder_name"] = df["folder_name"].astype(str)
df["series_instance_uid"] = df["series_instance_uid"].astype(str)
return df
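# Illustrative usage sketch (added for clarity; requires the MySQL database configured above and
# node_time 2041 is only an example value): querying one node_time returns a DataFrame with one
# row per user_label, joined with its study, patient and series information.
def _demo_generate_node_all_label_id_df():
    df = generate_node_all_label_id_df(node_time=2041)
    if df is not None:
        print(df[["label_id", "patient_id", "series_instance_uid"]].head())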
def select_single_label_id(label_id=None):
session = conect_mysql()
label = session.query(UserLabel).filter(
and_(UserLabel.id == label_id)).first()
if label is None:
return None, f"{label_id}, 标注数据不存在"
node_time = label.node_time
bundle = session.query(DicomSeries).filter(
and_(DicomSeries.id == label.series_id)).first()
if bundle is None:
return None, f"{label_id}, 关联的dicom数据不存在"
delineations = session.query(UserLabelDelineation).filter(
and_(UserLabelDelineation.label_id == label_id, UserLabelDelineation.status == 0)).order_by(
UserLabelDelineation.z_index.asc()).all()
session.close()
return (label, bundle, delineations), "success"
def generate_single_series_raw_data_3d_by_label_id(label_id=None, dicom_folder=""):
(label, bundle, delineations), result = select_single_label_id(label_id=label_id)
patient_id = bundle.patient_id
series_instance_uid = bundle.series_instance_uid
data, selected_box, node_time, raw_cood_3d, z_index_list_3d = None, None, None, None, None
if result != "success":
return data, selected_box, node_time, patient_id, series_instance_uid, raw_cood_3d, z_index_list_3d
dicom_path = f"{dicom_folder}/{patient_id}-{series_instance_uid}"
cts = get_cts(dicom_path)
data = cts.get_raw_data()
spacing = cts.get_raw_spacing()
mask = np.zeros((len(data), len(data[1]), len(data[1][1])), np.uint8)
node_time = label.node_time
z_count = 0
for delineation in delineations:
if (delineation.contour is None or len(delineation.contour) == 0) and delineation.meta is None:
continue
indexlist, indexs, img_np = base64_to_list(delineation.contour)
if delineation.contour is None and delineation.meta:
indexlist, indexs, img_np = meta_to_list(delineation.meta, mask[0].copy())
mask[delineation.z_index] = img_np
z_count += 1
if mask is not None and np.sum(mask == 1) > 0:
coords = np.asarray(np.where(mask == 1))
selected_box = np.zeros((3, 2), np.float32)
selected_box[:, 0] = coords.min(axis=1)
selected_box[:, 1] = coords.max(axis=1) + 1
select_box_z_count = selected_box[0][1] - selected_box[0][0]
logger.info(f"selected_box: {selected_box}, select_box_z_count: {select_box_z_count}, z_count: {z_count}, data: {data.shape}, cood: {mask.shape}")
raw_cood_3d = mask
z_index_list_3d = list(range(int(selected_box[0][0]), int(selected_box[0][0]) + z_count))
return data, selected_box, node_time, patient_id, series_instance_uid, raw_cood_3d, z_index_list_3d
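# Illustrative sketch (added for clarity): the selected_box above is derived from the binary mask
# with np.where; the minimum voxel index per axis is the box start and the maximum + 1 the box end.
def _demo_selected_box_from_mask():
    mask = np.zeros((4, 8, 8), np.uint8)
    mask[1:3, 2:5, 3:6] = 1
    coords = np.asarray(np.where(mask == 1))
    selected_box = np.zeros((3, 2), np.float32)
    selected_box[:, 0] = coords.min(axis=1)
    selected_box[:, 1] = coords.max(axis=1) + 1
    print(selected_box)  # [[1. 3.] [2. 5.] [3. 6.]]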
def generate_raw_data_3d_npy_data_by_single_label_id(label_id=None, dicom_folder="", save_path="", is_save_flag=False):
patient_id, series_instance_uid, raw_data_3d_npy_file, raw_cood_3d_npy_file, z_index_list_3d = None, None, None, None, None
if label_id is None:
return patient_id, series_instance_uid, raw_data_3d_npy_file, raw_cood_3d_npy_file, z_index_list_3d
data, selected_box, node_time, patient_id, series_instance_uid, raw_cood_3d, z_index_list_3d = generate_single_series_raw_data_3d_by_label_id(label_id=label_id, dicom_folder=dicom_folder)
if not os.path.exists(save_path):
Path(save_path).mkdir(parents=True, exist_ok=True)
raw_data_3d_npy_file = f"{save_path}/{node_time}_{label_id}_raw_data_3d.npy"
raw_cood_3d_npy_file = f"{save_path}/{node_time}_{label_id}_raw_cood_3d.npy"
if is_save_flag:
np.save(raw_data_3d_npy_file, data)
np.save(raw_cood_3d_npy_file, raw_cood_3d)
logger.info(f"save 3d npy data -> {raw_data_3d_npy_file}\n{raw_cood_3d_npy_file}")
return patient_id, series_instance_uid, raw_data_3d_npy_file, raw_cood_3d_npy_file, z_index_list_3d
def generate_raw_data_3d_npy_data_by_all_label_id_df(csv_file=None, node_raw_data_3d_npy_file=None, dicom_folder="/opt/lung/ai", save_path="", is_save_flag=False):
node_df = pd.read_csv(csv_file)
data_3d_node_list = []
data_3d_label_id_list = []
raw_data_3d_npy_file_list = []
raw_cood_3d_npy_file_list = []
data_3d_z_index_list = []
data_3d_patient_id_list = []
data_3d_series_instance_uid_list = []
for idx in tqdm(range(len(node_df))):
node_time = node_df.loc[idx, 'node_time']
label_id = node_df.loc[idx, 'label_id']
idx_patient_id = node_df.loc[idx, 'patient_id']
idx_series_instance_uid = node_df.loc[idx, 'series_instance_uid']
patient_id, series_instance_uid, raw_data_3d_npy_file, raw_cood_3d_npy_file, z_index_list_3d = generate_raw_data_3d_npy_data_by_single_label_id(
label_id=label_id,
dicom_folder=dicom_folder,
save_path=save_path,
is_save_flag=is_save_flag
)
if patient_id is None and series_instance_uid is None:
raise Exception(f"node_df, idx: {idx}, is None, {node_df[idx:idx+1]}")
data_3d_node_list += [node_time]
data_3d_label_id_list += [label_id]
raw_data_3d_npy_file_list += [raw_data_3d_npy_file]
raw_cood_3d_npy_file_list += [raw_cood_3d_npy_file]
data_3d_z_index_list += [[z_index_list_3d]]
data_3d_patient_id_list += [patient_id]
data_3d_series_instance_uid_list += [series_instance_uid]
raw_data_3d_df = pd.DataFrame({
'node': data_3d_node_list,
'label_id': data_3d_label_id_list,
'patient_id': data_3d_patient_id_list,
'raw_data_3d_npy_file': raw_data_3d_npy_file_list,
'raw_cood_3d_npy_file': raw_cood_3d_npy_file_list,
'z_index': data_3d_z_index_list,
'series_instance_uid': data_3d_series_instance_uid_list
})
if is_save_flag:
raw_data_3d_df.to_csv(node_raw_data_3d_npy_file, index=False, encoding="utf-8")
logger.info(f"generate raw data_3d, save to {node_raw_data_3d_npy_file}")
return
def get_feature_excutor():
config = {
'binWidth': 25,
'resampledPixelSpacing': None,
'normalize': True,
'normalizeScale': 1.0,
'featureClass': {
'shape': True,
'firstorder': True,
'glcm': True,
'glrlm': True,
'glszm': True,
'ngtdm': True,
'gldm': True
}
}
extractor = radiomics.featureextractor.RadiomicsFeatureExtractor(**config)
extractor.settings['normalize'] = True
extractor.settings['normalizeScale'] = 1.0
return extractor
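# Illustrative usage sketch (added for clarity; assumes pyradiomics is installed and the array
# shapes are example values only): the extractor defined above is applied to a SimpleITK
# image/mask pair, the same pattern used in process_features below.
def _demo_radiomics_feature_extraction():
    image = sitk.GetImageFromArray(np.random.randint(-1000, 400, size=(16, 32, 32)).astype(np.int16))
    mask_np = np.zeros((16, 32, 32), np.uint8)
    mask_np[6:10, 12:20, 12:20] = 1
    mask = sitk.GetImageFromArray(mask_np)
    extractor = get_feature_excutor()
    features = extractor.execute(image, mask)
    print(len(features))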
def process_features(args):
csv_file = args[0]
node_feature_file = args[1]
is_save_flag = args[2]
df = pd.read_csv(csv_file, header=0, encoding = 'utf-8')
from collections import OrderedDict
def parse_feat(feature):
result = OrderedDict()
for idx_feature, idx_score in feature.items():
if isinstance(idx_score, np.ndarray):
idx_score = idx_score.tolist()
elif isinstance(idx_score, (np.float16, np.float32, np.float64)):
idx_score = float(idx_score)
elif isinstance(idx_score, (np.int32, np.int64)):
idx_score = int(idx_score)
result[idx_feature] = idx_score
return result
node_list = []
label_id_list = []
patient_id_list = []
feature_list = []
extractor = get_feature_excutor()
for idx in range(len(df)):
idx_node = df.loc[idx, 'node']
idx_label_id = df.loc[idx, 'label_id']
idx_patient_id = df.loc[idx, 'patient_id']
idx_data_3d_npy_file = df.loc[idx, 'raw_data_3d_npy_file']
idx_cood_3d_npy_file = df.loc[idx, 'raw_cood_3d_npy_file']
idx_data_3d = np.load(idx_data_3d_npy_file)
idx_cood = np.load(idx_cood_3d_npy_file)
idx_data_3d_sitk = sitk.GetImageFromArray(idx_data_3d)
idx_cood_sitk = sitk.GetImageFromArray(idx_cood.astype(np.uint8))
idx_feature = extractor.execute(idx_data_3d_sitk, idx_cood_sitk)
idx_feature = parse_feat(idx_feature)
idx_feature = json.dumps(idx_feature)
node_list += [idx_node]
label_id_list += [idx_label_id]
patient_id_list += [idx_patient_id]
feature_list += [idx_feature]
node_feature_df = pd.DataFrame(
{
'node': node_list,
'label_id': label_id_list,
'patient_id': patient_id_list,
'feature': feature_list
}
)
if is_save_flag:
node_feature_df.to_csv(node_feature_file, index=False, encoding="utf-8")
logger.info(f"process node feature, save to node_feature_file: {node_feature_file}")
def get_data_3d_feature(node_csv_file_list=None, is_save_flag=False):
process_args_list = []
for node_time, csv_file in node_csv_file_list:
idx_node_feature_file = f"{csv_file.replace('.csv', f'_feature')}.csv"
process_args_list.append(
[
csv_file,
idx_node_feature_file,
is_save_flag
]
)
process_count = len(process_args_list)
process_list = []
for idx in range(process_count):
process_args = process_args_list[idx]
idx_process = Process(target=process_features, args=(process_args,))
idx_process.start()
process_list.append(idx_process)
for idx_process in process_list:
idx_process.join()
return
def get_node_time_all_label_ids_df(node_time=None, csv_data_dir=""):
if node_time is None:
return None
df = generate_node_all_label_id_df(node_time=node_time)
task_info = datetime.now().strftime("%Y%m%d_%H%M%S")
csv_file = f"{csv_data_dir}/{node_time}/{node_time}_{task_info}_rotate_10.csv"
Path(csv_file).parent.mkdir(parents=True, exist_ok=True)
df.to_csv(csv_file, index=False, encoding="utf-8")
logger.info(f"save csv data -> {csv_file}")
return csv_file
def generate_feature_train_npy_csv_file(node_csv_file_list=None, csv_data_dir="", train_csv_dir="", train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, is_pad_df=True, is_save_csv=False, seed=100004):
node_csv_file_dict = {
idx[0]: idx[1]
for idx in node_csv_file_list
}
node_list = list(node_csv_file_dict.keys())
from sklearn.model_selection import train_test_split
def pad_df(df, max_len):
if len(df) == max_len:
return df
elif len(df) > max_len:
return df[:max_len]
else:
pad_df_list = [df]
lens = len(df)
while lens < max_len:
pad_df_list.append(df)
lens += len(df)
pad_df = pd.concat(pad_df_list, ignore_index=True)
return pad_df[:max_len]
def get_expand_feature(df=None, is_norm_flag=False):
expanded_data = defaultdict(list)
for index, row in df.iterrows():
try:
feature_dict = json.loads(row['feature'])
except json.JSONDecodeError:
print(f"Failed to decode JSON at index {index}")
continue
for key, value in feature_dict.items():
if isinstance(value, list):
for i, item in enumerate(value):
col_name = f"{key}_{i}"
expanded_data[col_name].append(item)
else:
expanded_data[key].append(value)
expanded_df = pd.DataFrame(expanded_data)
expanded_df.index = df.index
df = pd.concat([df.drop(columns=['feature']), expanded_df], axis=1)
filter_col_list = [
'node', 'label_id', 'patient_id', 'class',
'diagnostics_Versions_PyRadiomics',
'diagnostics_Versions_Numpy',
'diagnostics_Versions_SimpleITK',
'diagnostics_Versions_PyWavelet',
'diagnostics_Versions_Python',
'diagnostics_Configuration_Settings',
'diagnostics_Configuration_EnabledImageTypes',
'diagnostics_Image-original_Hash',
'diagnostics_Image-original_Dimensionality',
'diagnostics_Mask-original_Hash',
]
feature_list = df.columns.tolist()
filter_col_list = list(set(feature_list)-set(filter_col_list))
if is_norm_flag:
scaler = StandardScaler()
selected_features = df[filter_col_list]
normalized_features = scaler.fit_transform(selected_features)
normalized_features_df = pd.DataFrame(normalized_features, columns=[f"{col}_zscore" for col in selected_features.columns])
df = pd.concat([df, normalized_features_df], axis=1)
return df
    # Iterate over positive/negative node pairs and build the corresponding training data
train_csv_file_list = []
print(f"node_list: {node_list}")
for idx_pos_node in node_list:
for idx_neg_node in node_list:
if idx_pos_node == idx_neg_node:
continue
idx_train_file = f"{idx_neg_node}_{idx_pos_node}_data_3d_feature_train.csv"
idx_val_file = f"{idx_neg_node}_{idx_pos_node}_data_3d_feature_val.csv"
idx_test_file = f"{idx_neg_node}_{idx_pos_node}_data_3d_feature_test.csv"
idx_train_file = os.path.join(train_csv_dir, idx_train_file)
idx_val_file = os.path.join(train_csv_dir, idx_val_file)
idx_test_file = os.path.join(train_csv_dir, idx_test_file)
idx_pos_df = pd.read_csv(node_csv_file_dict[idx_pos_node])
idx_neg_df = pd.read_csv(node_csv_file_dict[idx_neg_node])
idx_pos_df['class'] = 1
idx_neg_df['class'] = 0
lens_pos = len(idx_pos_df)
lens_neg = len(idx_neg_df)
if lens_pos < 3 or lens_neg < 3:
logger.info(f"{idx_pos_node}_{idx_neg_node} count < 3, skip")
continue
idx_df = pd.concat([idx_pos_df, idx_neg_df], ignore_index=True)
idx_df = get_expand_feature(df=idx_df, is_norm_flag=True)
idx_pos_df = idx_df[idx_df['class']==1]
idx_neg_df = idx_df[idx_df['class']==0]
idx_pos_df = idx_pos_df.reset_index(drop=True)
idx_neg_df = idx_neg_df.reset_index(drop=True)
idx_pos_train_df, idx_pos_test_val_df = train_test_split(idx_pos_df, test_size=1-train_ratio, random_state=seed)
idx_pos_val_df, idx_pos_test_df = train_test_split(idx_pos_test_val_df, test_size=test_ratio/(val_ratio+test_ratio), random_state=seed)
idx_neg_train_df, idx_neg_test_val_df = train_test_split(idx_neg_df, test_size=1-train_ratio, random_state=seed)
idx_neg_val_df, idx_neg_test_df = train_test_split(idx_neg_test_val_df, test_size=test_ratio/(val_ratio+test_ratio), random_state=seed)
lens_pos_train = len(idx_pos_train_df)
lens_neg_train = len(idx_neg_train_df)
logger.info(f"generate_feature_train, before pos: {idx_pos_node}, {lens_pos_train}, {len(idx_pos_val_df)}, {len(idx_pos_test_df)}, neg: {idx_neg_node}, {lens_neg_train}, {len(idx_neg_val_df)}, {len(idx_neg_test_df)}")
if is_pad_df:
if lens_pos_train < lens_neg_train:
idx_pos_train_df = pad_df(idx_pos_train_df, lens_neg_train)
elif lens_pos_train > lens_neg_train:
idx_neg_train_df = pad_df(idx_neg_train_df, lens_pos_train)
logger.info(f"generate_feature_train, after pos: {idx_pos_node}, {len(idx_pos_train_df)}, {len(idx_pos_val_df)}, {len(idx_pos_test_df)}, neg: {idx_neg_node}, {len(idx_neg_train_df)}, {len(idx_neg_val_df)}, {len(idx_neg_test_df)}")
idx_train_df = pd.concat([idx_pos_train_df, idx_neg_train_df], ignore_index=True)
idx_val_df = pd.concat([idx_pos_val_df, idx_neg_val_df], ignore_index=True)
idx_test_df = pd.concat([idx_pos_test_df, idx_neg_test_df], ignore_index=True)
if is_save_csv:
idx_train_df.to_csv(idx_train_file, index=False, encoding='utf-8')
idx_val_df.to_csv(idx_val_file, index=False, encoding='utf-8')
idx_test_df.to_csv(idx_test_file, index=False, encoding='utf-8')
train_csv_file_list.append((idx_train_file, idx_val_file, idx_test_file))
logger.info(f"generate_feature_train, save to : {idx_train_file}\n{idx_val_file}\n{idx_test_file}")
print(f"训练集信息: {train_csv_file_list}")
return
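# Illustrative sketch (added for clarity): pad_df above balances the training split by repeating
# the smaller class's rows until both classes reach the same number of training samples. The helper
# below mirrors that logic on a toy DataFrame; it is not used by the pipeline.
def _demo_pad_df_oversampling():
    small = pd.DataFrame({"label_id": [1, 2], "class": [1, 1]})
    target_len = 5
    repeated = pd.concat([small] * ((target_len // len(small)) + 1), ignore_index=True)
    print(repeated[:target_len])  # label_id sequence: 1, 2, 1, 2, 1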
csv_data_dir = "/df_lung/cls_train_data/csv_data"
npy_data_dir = "/df_lung/cls_train_data/npy_data"
def get_train_data_info_csv(node_time_list=[]):
for node_time in node_time_list:
csv_file = get_node_time_all_label_ids_df(node_time=node_time, csv_data_dir=csv_data_dir)
logger.info(f"{node_time}: {csv_file}\n")
def process_npy(args):
generate_raw_data_3d_npy_data_by_all_label_id_df(
csv_file=args[0],
node_raw_data_3d_npy_file=args[1],
dicom_folder=args[2],
save_path=args[3],
is_save_flag=args[4]
)
logger.info(f"process_raw_data_3d_npy finished: csv_file: {args[0]}")
def get_npy_data(node_csv_file_list=[], is_save_flag=False):
    # Sanity check: each csv file name must contain its node_time prefix
    if any(f"{node_time}_" not in csv_file for node_time, csv_file in node_csv_file_list):
        raise ValueError(f"node_csv_file_list: {node_csv_file_list}")
process_args_list = []
for node_time, csv_file in node_csv_file_list:
save_path = f"{npy_data_dir}/{node_time}"
Path(save_path).mkdir(parents=True, exist_ok=True)
idx_node_raw_data_3d_npy_file = f"{csv_file.replace('.csv', f'_feat_raw_data_3d_npy')}.csv"
idx_node_raw_data_3d_npy_file = os.path.join(save_path, idx_node_raw_data_3d_npy_file)
process_args_list.append(
[
csv_file,
idx_node_raw_data_3d_npy_file,
"/opt/lung/ai",
save_path,
is_save_flag
]
)
process_count = len(process_args_list)
process_list = []
for idx in range(process_count):
process_args = process_args_list[idx]
idx_process = Process(target=process_npy, args=(process_args,))
idx_process.start()
process_list.append(idx_process)
for idx_process in process_list:
idx_process.join()
return
# # Generate the csv data
# node_time_list = [2046, 2047, 2048, 2060, 2061, 2062, 3001, 4001, 5001, 6001, 1016]
# get_train_data_info_csv(node_time_list=node_time_list)
'''
2021: /df_lung/cls_train_data/csv_data/2021/2021_20241204_094025_rotate_10.csv
2031: /df_lung/cls_train_data/csv_data/2031/2031_20241204_094025_rotate_10.csv
2041: /df_lung/cls_train_data/csv_data/2041/2041_20241204_094026_rotate_10.csv
1010: /df_lung/cls_train_data/csv_data/1010/1010_20241204_093726_rotate_10.csv
1020: /df_lung/cls_train_data/csv_data/1020/1020_20241204_093726_rotate_10.csv
2011: /df_lung/cls_train_data/csv_data/2011/2011_20241204_093726_rotate_10.csv
2046: /df_lung/cls_train_data/csv_data/2046/2046_20241211_155642_rotate_10.csv
2047: /df_lung/cls_train_data/csv_data/2047/2047_20241211_155642_rotate_10.csv
2048: /df_lung/cls_train_data/csv_data/2048/2048_20241211_155642_rotate_10.csv
2060: /df_lung/cls_train_data/csv_data/2060/2060_20241211_155643_rotate_10.csv
2061: /df_lung/cls_train_data/csv_data/2061/2061_20241211_155643_rotate_10.csv
2062: /df_lung/cls_train_data/csv_data/2062/2062_20241211_155643_rotate_10.csv
3001: /df_lung/cls_train_data/csv_data/3001/3001_20241211_155643_rotate_10.csv
4001: /df_lung/cls_train_data/csv_data/4001/4001_20241211_155643_rotate_10.csv
5001: /df_lung/cls_train_data/csv_data/5001/5001_20241211_155643_rotate_10.csv
6001: /df_lung/cls_train_data/csv_data/6001/6001_20241211_155643_rotate_10.csv
1016: /df_lung/cls_train_data/csv_data/1016/1016_20241211_155643_rotate_10.csv
'''
# # Generate the npy data
# node_csv_file_list = [
# (2021, "/df_lung/cls_train_data/csv_data/2021/2021_20241204_094025_rotate_10.csv"),
# (2031, "/df_lung/cls_train_data/csv_data/2031/2031_20241204_094025_rotate_10.csv"),
# (2041, "/df_lung/cls_train_data/csv_data/2041/2041_20241204_094026_rotate_10.csv"),
# (1010, "/df_lung/cls_train_data/csv_data/1010/1010_20241204_093726_rotate_10.csv"),
# (1020, "/df_lung/cls_train_data/csv_data/1020/1020_20241204_093726_rotate_10.csv"),
# (2011, "/df_lung/cls_train_data/csv_data/2011/2011_20241204_093726_rotate_10.csv"),
# (2046, "/df_lung/cls_train_data/csv_data/2046/2046_20241211_155642_rotate_10.csv"),
# (2047, "/df_lung/cls_train_data/csv_data/2047/2047_20241211_155642_rotate_10.csv"),
# (2048, "/df_lung/cls_train_data/csv_data/2048/2048_20241211_155642_rotate_10.csv"),
# (2060, "/df_lung/cls_train_data/csv_data/2060/2060_20241211_155643_rotate_10.csv"),
# (2061, "/df_lung/cls_train_data/csv_data/2061/2061_20241211_155643_rotate_10.csv"),
# (2062, "/df_lung/cls_train_data/csv_data/2062/2062_20241211_155643_rotate_10.csv"),
# (3001, "/df_lung/cls_train_data/csv_data/3001/3001_20241211_155643_rotate_10.csv"),
# (4001, "/df_lung/cls_train_data/csv_data/4001/4001_20241211_155643_rotate_10.csv"),
# (5001, "/df_lung/cls_train_data/csv_data/5001/5001_20241211_155643_rotate_10.csv"),
# (6001, "/df_lung/cls_train_data/csv_data/6001/6001_20241211_155643_rotate_10.csv"),
# (1016, "/df_lung/cls_train_data/csv_data/1016/1016_20241211_155643_rotate_10.csv")
# ]
# is_save_flag = True
# get_npy_data(node_csv_file_list=node_csv_file_list, is_save_flag=is_save_flag)
# # Radiomics features
# node_csv_file_list = [
# (2021, "/df_lung/cls_train_data/csv_data/2021/2021_20241204_094025_rotate_10_feat_raw_data_3d_npy.csv"),
# (2031, "/df_lung/cls_train_data/csv_data/2031/2031_20241204_094025_rotate_10_feat_raw_data_3d_npy.csv"),
# (2041, "/df_lung/cls_train_data/csv_data/2041/2041_20241204_094026_rotate_10_feat_raw_data_3d_npy.csv"),
# (1010, "/df_lung/cls_train_data/csv_data/1010/1010_20241204_093726_rotate_10_feat_raw_data_3d_npy.csv"),
# (1020, "/df_lung/cls_train_data/csv_data/1020/1020_20241204_093726_rotate_10_feat_raw_data_3d_npy.csv"),
# (2011, "/df_lung/cls_train_data/csv_data/2011/2011_20241204_093726_rotate_10_feat_raw_data_3d_npy.csv"),
# (2046, "/df_lung/cls_train_data/csv_data/2046/2046_20241211_155642_rotate_10_feat_raw_data_3d_npy.csv"),
# (2047, "/df_lung/cls_train_data/csv_data/2047/2047_20241211_155642_rotate_10_feat_raw_data_3d_npy.csv"),
# (2048, "/df_lung/cls_train_data/csv_data/2048/2048_20241211_155642_rotate_10_feat_raw_data_3d_npy.csv"),
# (2060, "/df_lung/cls_train_data/csv_data/2060/2060_20241211_155643_rotate_10_feat_raw_data_3d_npy.csv"),
# (2061, "/df_lung/cls_train_data/csv_data/2061/2061_20241211_155643_rotate_10_feat_raw_data_3d_npy.csv"),
# (2062, "/df_lung/cls_train_data/csv_data/2062/2062_20241211_155643_rotate_10_feat_raw_data_3d_npy.csv"),
# (3001, "/df_lung/cls_train_data/csv_data/3001/3001_20241211_155643_rotate_10_feat_raw_data_3d_npy.csv"),
# (4001, "/df_lung/cls_train_data/csv_data/4001/4001_20241211_155643_rotate_10_feat_raw_data_3d_npy.csv"),
# (5001, "/df_lung/cls_train_data/csv_data/5001/5001_20241211_155643_rotate_10_feat_raw_data_3d_npy.csv"),
# (6001, "/df_lung/cls_train_data/csv_data/6001/6001_20241211_155643_rotate_10_feat_raw_data_3d_npy.csv"),
# (1016, "/df_lung/cls_train_data/csv_data/1016/1016_20241211_155643_rotate_10_feat_raw_data_3d_npy.csv")
# ]
# is_save_flag = True
# get_data_3d_feature(node_csv_file_list=node_csv_file_list, is_save_flag=is_save_flag)
# Generate the training data
node_csv_file_list = [
(2021, "/df_lung/cls_train_data/csv_data/2021/2021_20241204_094025_rotate_10_feat_raw_data_3d_npy_feature.csv"),
(2031, "/df_lung/cls_train_data/csv_data/2031/2031_20241204_094025_rotate_10_feat_raw_data_3d_npy_feature.csv"),
(2041, "/df_lung/cls_train_data/csv_data/2041/2041_20241204_094026_rotate_10_feat_raw_data_3d_npy_feature.csv"),
(1010, "/df_lung/cls_train_data/csv_data/1010/1010_20241204_093726_rotate_10_feat_raw_data_3d_npy_feature.csv"),
(1020, "/df_lung/cls_train_data/csv_data/1020/1020_20241204_093726_rotate_10_feat_raw_data_3d_npy_feature.csv"),
(2011, "/df_lung/cls_train_data/csv_data/2011/2011_20241204_093726_rotate_10_feat_raw_data_3d_npy_feature.csv"),
(2046, "/df_lung/cls_train_data/csv_data/2046/2046_20241211_155642_rotate_10_feat_raw_data_3d_npy_feature.csv"),
(2047, "/df_lung/cls_train_data/csv_data/2047/2047_20241211_155642_rotate_10_feat_raw_data_3d_npy_feature.csv"),
(2048, "/df_lung/cls_train_data/csv_data/2048/2048_20241211_155642_rotate_10_feat_raw_data_3d_npy_feature.csv"),
(2060, "/df_lung/cls_train_data/csv_data/2060/2060_20241211_155643_rotate_10_feat_raw_data_3d_npy_feature.csv"),
(2061, "/df_lung/cls_train_data/csv_data/2061/2061_20241211_155643_rotate_10_feat_raw_data_3d_npy_feature.csv"),
(2062, "/df_lung/cls_train_data/csv_data/2062/2062_20241211_155643_rotate_10_feat_raw_data_3d_npy_feature.csv"),
(3001, "/df_lung/cls_train_data/csv_data/3001/3001_20241211_155643_rotate_10_feat_raw_data_3d_npy_feature.csv"),
(4001, "/df_lung/cls_train_data/csv_data/4001/4001_20241211_155643_rotate_10_feat_raw_data_3d_npy_feature.csv"),
(5001, "/df_lung/cls_train_data/csv_data/5001/5001_20241211_155643_rotate_10_feat_raw_data_3d_npy_feature.csv"),
(6001, "/df_lung/cls_train_data/csv_data/6001/6001_20241211_155643_rotate_10_feat_raw_data_3d_npy_feature.csv"),
(1016, "/df_lung/cls_train_data/csv_data/1016/1016_20241211_155643_rotate_10_feat_raw_data_3d_npy_feature.csv")
]
csv_data_dir = "/df_lung/cls_train_data/csv_data"
train_csv_dir = "/df_lung/cls_train_data/train_csv_data"
is_pad_df = True
is_save_csv = True
seed = 100004
generate_feature_train_npy_csv_file(
node_csv_file_list = node_csv_file_list,
csv_data_dir=csv_data_dir,
train_csv_dir=train_csv_dir,
is_pad_df=is_pad_df,
is_save_csv=is_save_csv,
seed=seed
)
import sys, os
import pathlib
current_dir = pathlib.Path(__file__).parent.resolve()
while "cls_train" != current_dir.name:
current_dir = current_dir.parent
sys.path.append(current_dir.as_posix())
from cls_utils.log_utils import get_logger
from sqlalchemy import create_engine, and_
from sqlalchemy.orm import sessionmaker, scoped_session
import numpy as np
import argparse
import threading
from tqdm import tqdm
import pandas as pd
from datetime import datetime
import json
import re
from pathlib import Path
import scipy
from joblib import Parallel, delayed
from scipy.ndimage import rotate as scipy_rotate
import torch
import torchio as tio
import torch.nn.functional as F
from multiprocessing import Process
from torchvision import transforms as T
import cupy as cp
from cupyx.scipy.ndimage import rotate as cupy_rotate
from data.domain import DicomStudy, PatientInfo, UserLabel, UserLabelDelineation, DicomSeries
from data.data_process_utils.test_sitk_utils import CTSeries, base64_to_list, meta_to_list
from PIL import Image
logger = get_logger(log_file="/df_lung/ai-project/cls_train/log/data/get_db_data_to_npy.log")
# from cls_utils.sitk_utils import CTSeries
# from cls_utils.data_utils import crop_ct_data, get_crop_data_padding, get_crop_data_2d
# from cls_utils.utils import hu_value_to_uint8, normalize, base64_to_list
# from cls_utils.data import save_supplement_data_csv, save_data_to_npy, load_npy_to_data, create_cls_train_csv, \
# load_all_dicom_file, load_json, create_cls_train_all_csv, create_cls_train_csv_3d, \
# replace_label_ids, add_label_ids, create_cls_train_last_3d
MYSQL_SERVER = 'mysql+pymysql://lung:lung1qaz2wsx@127.0.0.1:3306/ct_file?charset=utf8'
"""
连接数据库,返回一个session
"""
def conect_mysql():
engine = create_engine(MYSQL_SERVER, pool_recycle=3600)
    #connection = engine.connect()
db_session = sessionmaker(bind=engine)
session = scoped_session(db_session)
return session
def get_cts(dicom_path=None):
cts = CTSeries()
cts.load_dicoms(dicom_path)
return cts
def rotate_dicom_scipy(data, num_rotations=10):
angle = 360 / num_rotations
rotated_data = []
for i in range(num_rotations):
rotated = scipy_rotate(data, angle * (i + 1), axes=(1, 2), reshape=False, order=3)
rotated_data.append(rotated)
return np.stack(rotated_data)
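# Illustrative sketch (added for clarity): rotating a small random volume num_rotations times in
# the (H, W) plane keeps the volume shape and stacks the results along a new leading axis.
def _demo_rotate_dicom_scipy():
    volume = np.random.rand(4, 16, 16).astype(np.float32)
    rotated = rotate_dicom_scipy(volume, num_rotations=4)
    print(rotated.shape)  # (4, 4, 16, 16): (num_rotations, D, H, W)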
def rotate_dicom_torch(data, num_rotations=10):
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
angle = 360 / num_rotations
    data_tensor = torch.tensor(data, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)  # reshape to a 5D tensor [N, C, D, H, W]
    # Build the rotation matrix and the sampling grid
def build_rotation_grid(shape, angle_deg):
angle_rad = torch.tensor(angle_deg * np.pi / 180, dtype=torch.float32, device=device)
cos_a = torch.cos(angle_rad)
sin_a = torch.sin(angle_rad)
        # 3D rotation matrix (rotation in the H-W plane)
rotation_matrix = torch.tensor([
[cos_a, -sin_a, 0, 0],
[sin_a, cos_a, 0, 0],
[0, 0, 1, 0],
], dtype=torch.float32, device=device)
d, h, w = shape
z, y, x = torch.meshgrid(
torch.linspace(-1, 1, d, device=device),
torch.linspace(-1, 1, h, device=device),
torch.linspace(-1, 1, w, device=device),
indexing='ij'
)
coords = torch.stack([x.flatten(), y.flatten(), z.flatten(), torch.ones_like(x.flatten())], dim=0)
rotated_coords = torch.matmul(rotation_matrix, coords).view(3, d, h, w)
return rotated_coords.permute(1, 2, 3, 0)
rotated_data = []
for i in range(num_rotations):
current_angle = angle * (i + 1)
grid = build_rotation_grid(data_tensor.shape[2:], current_angle)
rotated_volume = F.grid_sample(data_tensor, grid.unsqueeze(0), mode='bilinear', padding_mode='zeros', align_corners=True)
rotated_data.append(rotated_volume.squeeze(0).squeeze(0))
torch.cuda.empty_cache()
result = torch.stack(rotated_data).cpu().numpy()
del rotated_data, grid, rotated_volume, data_tensor, device, angle, current_angle, build_rotation_grid
torch.cuda.empty_cache()
return result
def rotate_dicom_cupy(data, num_rotations=10):
    '''
    Rotate the volume on the GPU with CuPy; the stacked result is copied back to the CPU
    and the CuPy memory pool is released before returning.
    '''
data_gpu = cp.asarray(data)
angle = 360 / num_rotations
rotated_data = []
for i in range(num_rotations):
rotated = cupy_rotate(data_gpu, angle * (i + 1), axes=(1, 2), reshape=False, order=1, prefilter=False)
rotated_data.append(rotated)
result = cp.stack(rotated_data).get()
del rotated_data, rotated, data_gpu, data
cp.get_default_memory_pool().free_all_blocks()
return result
def rotate_dicom_data(dicom_data=None, rotate_count=10, logger=None):
    rotate_data = None
    try:
        # NOTE: rotate_count is forwarded to every backend (the original only passed it to scipy)
        rotate_data = rotate_dicom_cupy(dicom_data, num_rotations=rotate_count)
        if logger is not None:
            logger.info(f"rotate_dicom_cupy success")
    except Exception:
        if logger is not None:
            logger.error(f"rotate_dicom_cupy error, dicom_data: {dicom_data[0][:10]}")
    if rotate_data is None:
        try:
            rotate_data = rotate_dicom_torch(dicom_data, num_rotations=rotate_count)
            if logger is not None:
                logger.info(f"rotate_dicom_torch success")
        except Exception:
            if logger is not None:
                logger.error(f"next rotate_dicom_torch error, dicom_data: {dicom_data[0][:10]}")
    if rotate_data is None:
        try:
            rotate_data = rotate_dicom_scipy(dicom_data, rotate_count)
            if logger is not None:
                logger.info(f"rotate_dicom_scipy success")
        except Exception:
            if logger is not None:
                logger.error(f"next rotate_dicom_scipy error, dicom_data: {dicom_data[0][:10]}")
    return rotate_data
def get_crop_start_end_index(center, lens):
center = int(center)
index_dict = {}
index_list = [center]
start = center - 1
while len(index_list) <= lens // 2:
index_list = [start] + index_list
start -= 1
start = center+1
while len(index_list) < lens:
index_list.append(start)
start += 1
for idx, idx_index in enumerate(index_list):
index_dict[int(idx_index)] = idx
return index_list, index_dict
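# Illustrative sketch (added for clarity): for a crop of length `lens` centred on `center`,
# index_list holds the original indices covered by the crop and index_dict maps each original
# index to its position inside the crop; the center always lands at position lens // 2.
def _demo_get_crop_start_end_index():
    index_list, index_dict = get_crop_start_end_index(center=10, lens=5)
    print(index_list)      # [8, 9, 10, 11, 12]
    print(index_dict[10])  # 2 == 5 // 2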
def extend_crop_data_3d(data, select_box, crop_size=(48, 400, 400), fill_value=-1000, z_expand_flag=True, y_expand_flag=True, x_expand_flag=True, return_updated_box_flag=True):
"""
根据expand_flag,扩展、填充,同时返回更新后的 select_box。
Args:
data: 原始3D数据,形状 (D, H, W)。
select_box: 原始边界框
crop_size: 扩展尺寸。
fill_value: 超出边界部分的填充值。
Returns:
padded_data: 扩展后的数据
new_select_box: 更新后的边界框。
"""
select_box = select_box.astype(np.int32)
    assert len(data.shape) == 3, "data must be 3D"
    assert len(select_box) == 3, "select_box must contain ranges for 3 axes"
    assert len(crop_size) == 3, "crop_size must specify 3 dimensions"
z_center = (select_box[0][0]+select_box[0][1]) // 2
y_center = (select_box[1][0]+select_box[1][1]) // 2
x_center = (select_box[2][0]+select_box[2][1]) // 2
crop_z_size, crop_y_size, crop_x_size = crop_size[0], crop_size[1], crop_size[2]
z_index_list = list(range(select_box[0][0],select_box[0][1]))
y_index_list = list(range(select_box[1][0],select_box[1][1]))
x_index_list = list(range(select_box[2][0],select_box[2][1]))
z_index_dict = dict(zip(z_index_list, z_index_list))
y_index_dict = dict(zip(y_index_list, y_index_list))
x_index_dict = dict(zip(x_index_list, x_index_list))
if z_expand_flag:
z_index_list, z_index_dict = get_crop_start_end_index(z_center, crop_z_size)
if y_expand_flag:
y_index_list, y_index_dict = get_crop_start_end_index(y_center, crop_y_size)
if x_expand_flag:
x_index_list, x_index_dict = get_crop_start_end_index(x_center, crop_x_size)
z_valid_start_index, z_valid_end_index = max(0, z_index_list[0]), min(data.shape[0]-1, z_index_list[-1])
y_valid_start_index, y_valid_end_index = max(0, y_index_list[0]), min(data.shape[1]-1, y_index_list[-1])
x_valid_start_index, x_valid_end_index = max(0, x_index_list[0]), min(data.shape[2]-1, x_index_list[-1])
z_padded_start_index, z_padded_end_index = z_index_dict[z_valid_start_index], z_index_dict[z_valid_end_index]
y_padded_start_index, y_padded_end_index = y_index_dict[y_valid_start_index], y_index_dict[y_valid_end_index]
x_padded_start_index, x_padded_end_index = x_index_dict[x_valid_start_index], x_index_dict[x_valid_end_index]
padded_data = np.full(crop_size, fill_value, dtype=data.dtype)
padded_data[
z_padded_start_index:z_padded_end_index + 1,
y_padded_start_index:y_padded_end_index + 1,
x_padded_start_index:x_padded_end_index + 1
] = data[
z_valid_start_index:z_valid_end_index + 1,
y_valid_start_index:y_valid_end_index + 1,
x_valid_start_index:x_valid_end_index + 1
]
    assert np.all(padded_data[crop_size[0]//2, crop_size[1]//2, crop_size[2]//2] == data[z_center, y_center, x_center]), "center voxel of the crop does not match the center voxel of the original data"
updated_box_info_dict = None
if return_updated_box_flag:
updated_box_info_dict = {
"z_index_list": z_index_list,
"y_index_list": y_index_list,
"x_index_list": x_index_list,
"z_index_dict": z_index_dict,
"y_index_dict": y_index_dict,
"x_index_dict": x_index_dict
}
return padded_data, updated_box_info_dict
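# Illustrative sketch (added for clarity; shapes and box values are examples only): cropping a
# small volume around a select_box returns an array of exactly crop_size, with the box center
# mapped to the center of the crop and any out-of-range voxels filled with fill_value.
def _demo_extend_crop_data_3d():
    volume = np.random.randint(-1000, 400, size=(10, 20, 20)).astype(np.int16)
    box = np.array([[3, 7], [5, 15], [5, 15]], np.float32)
    padded, box_info = extend_crop_data_3d(volume, box, crop_size=(8, 16, 16), fill_value=-1000)
    print(padded.shape)                  # (8, 16, 16)
    print(box_info["z_index_list"][:3])  # first original z indices covered by the crop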
def data_rotation_3d(data, select_box, crop_size=(48, 256, 256), expand=40, num_rotations=10):
"""
增强3D数据
Args:
data: 原始3D图像,形状 (D, H, W)。
select_box: 初始肺结节的边界框。
crop_size: 裁剪尺寸。
expand: 扩展范围。
num_rotations: 旋转次数。
Returns:
final_data: 增强后的数据。
"""
first_crop_lens = int(np.sqrt(crop_size[1]**2 + crop_size[2]**2))
first_crop_expand_lens = first_crop_lens + expand + 1 if (first_crop_lens + expand)%2 == 0 else first_crop_lens + expand
first_crop_size = (crop_size[0], first_crop_expand_lens, first_crop_expand_lens)
logger.info(f"原始裁剪区域: {crop_size}, 扩展后裁剪区域: {first_crop_size}")
first_extended_data, first_updated_box_info_dict = extend_crop_data_3d(
data,
select_box,
crop_size=first_crop_size,
fill_value=-1000,
z_expand_flag=True,
y_expand_flag=True,
x_expand_flag=True,
return_updated_box_flag=True
)
    rotated_data = rotate_dicom_data(dicom_data=first_extended_data, rotate_count=num_rotations, logger=logger)
logger.info(f"rotated_data: {rotated_data.shape}")
y_index_lens, x_index_lens = rotated_data.shape[2], rotated_data.shape[3]
y_center, x_center = y_index_lens // 2, x_index_lens // 2
y_start, y_end = y_center - crop_size[1] // 2, y_center + crop_size[1] // 2
x_start, x_end = x_center - crop_size[2] // 2, x_center + crop_size[2] // 2
result_data = rotated_data[:, :, y_start:y_end, x_start:x_end]
return result_data, first_updated_box_info_dict
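# Illustrative usage sketch (commented out; added for clarity, values are examples only): the first
# crop is enlarged to the diagonal of the target crop plus `expand`, rotated num_rotations times
# (GPU backends first, scipy fallback), then center-cropped back to crop_size in the (H, W) plane.
#
# volume = np.random.randint(-1000, 400, size=(60, 512, 512)).astype(np.int16)
# box = np.array([[20, 30], [200, 260], [200, 260]], np.float32)
# rotated, box_info = data_rotation_3d(volume, box, crop_size=(48, 256, 256), expand=40, num_rotations=10)
# # rotated.shape == (10, 48, 256, 256)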
def fill_data_default(input_data, fill_data=0.0001):
return np.full_like(input_data, fill_data)
def min_max_normalize_2d(input_data, min_value=-1000, max_value=800, fill_value=-1000, fill_data=0.0001):
    '''
    Min-max normalization for 2D data (clipped to [min_value, max_value], scaled to [0, 1]).
    '''
if np.all(input_data==fill_value):
return fill_data_default(input_data, fill_data)
corrected_data = np.copy(input_data)
corrected_data[corrected_data < min_value] = min_value
corrected_data[corrected_data > max_value] = max_value
corrected_data = corrected_data.astype(np.float32)
actual_min = min_value
actual_max = max_value
if actual_min == actual_max:
return fill_data_default(input_data, fill_data)
data = (corrected_data - actual_min) / (actual_max - actual_min)
return data
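# Illustrative sketch (added for clarity): HU values are clipped to [min_value, max_value] and
# scaled to [0, 1]; a slice consisting only of the fill value is replaced by the constant fill_data.
def _demo_min_max_normalize_2d():
    hu_slice = np.array([[-2000.0, -1000.0], [-100.0, 900.0]])
    print(min_max_normalize_2d(hu_slice))                  # [[0. 0.] [0.5 1.]]
    print(min_max_normalize_2d(np.full((2, 2), -1000.0)))  # constant 0.0001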
def min_max_normalize_2d_expand(input_data, min_value=-1000, max_value=800, fill_value=-1000, fill_data=0.0001):
    '''
    Min-max normalization for 2D data, rescaled to [-1, 1].
    '''
if np.all(input_data==fill_value):
return fill_data_default(input_data, fill_data)
corrected_data = np.copy(input_data)
corrected_data[corrected_data < min_value] = min_value
corrected_data[corrected_data > max_value] = max_value
corrected_data = corrected_data.astype(np.float32)
actual_min = min_value
actual_max = max_value
if actual_min == actual_max:
return fill_data_default(input_data, fill_data)
data = (corrected_data - actual_min) / (actual_max - actual_min)
return 2 * data - 1
def z_score_to_normalize_2d(input_data, fill_value=-1000, fill_data=0.0001):
    '''
    Z-score normalization for 2D data (via torchio ZNormalization).
    '''
if np.all(input_data==fill_value):
return fill_data_default(input_data, fill_data)
std = np.std(input_data)
if std == 0:
return fill_data_default(input_data, fill_data)
pt_data = torch.from_numpy(input_data).to(torch.float32)
image = tio.ScalarImage(tensor=pt_data.unsqueeze(0).unsqueeze(0))
znorm = tio.ZNormalization()
data = znorm(image).tensor.squeeze(0).squeeze(0).numpy()
return data
def z_score_T_normalize_2d(input_data, fill_value=-1000, fill_data=0.0001):
    '''
    Min-max then z-score normalization for 2D data, replicated to 3 channels for pretrained 2D backbones.
    '''
if np.all(input_data==fill_value):
return fill_data_default(input_data, fill_data)
data = min_max_normalize_2d(input_data)
mean = np.mean(data)
std = np.std(data)
if std == 0:
return fill_data_default(input_data, fill_data)
normalize = T.Normalize(mean=[mean], std=[std])
data = torch.from_numpy(data).to(torch.float32).unsqueeze(0)
transform = T.Compose([
normalize
])
data = transform(data)
data = data.repeat(3, 1, 1).numpy()
return data
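# Illustrative sketch (added for clarity; requires torchvision): after min-max scaling and
# per-slice z-scoring, the slice is repeated to 3 channels so it can feed a pretrained 2D backbone.
def _demo_z_score_T_normalize_2d():
    hu_slice = np.random.randint(-1000, 400, size=(32, 32)).astype(np.float32)
    out = z_score_T_normalize_2d(hu_slice)
    print(out.shape)  # (3, 32, 32)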
def normalize_net_2d(input_data, min_value=-1000, max_value=800):
    '''
    Normalization for the 2D network: z-score each (rotation, slice) plane.
    '''
x1,x2,x3,x4 = input_data.shape
data = np.zeros_like(input_data)
for idx in range(x1):
for jdx in range(x2):
data[idx, jdx] = z_score_to_normalize_2d(input_data[idx, jdx])
return data
def normalize_net_3d(input_data, min_value=-1000, max_value=800):
    '''
    Normalization for the 3D network: z-score each (rotation, slice) plane.
    '''
x1,x2,x3,x4 = input_data.shape
data = np.zeros_like(input_data)
for idx in range(x1):
for jdx in range(x2):
data[idx, jdx] = z_score_to_normalize_2d(input_data[idx, jdx])
return data
def d2d_normalize(input_data, min_value=-1000, max_value=800):
    '''
    Normalization for the pretrained 2D model: each plane becomes a 3-channel z-scored image.
    '''
x1,x2,x3,x4 = input_data.shape
x5 = 3
data = np.zeros((x1, x2, x5, x3, x4))
for idx in range(x1):
for jdx in range(x2):
data[idx, jdx] = z_score_T_normalize_2d(input_data[idx, jdx])
return data
def s3d_normalize_3d(input_data, min_value=-1000, max_value=800):
    '''
    Normalization for the pretrained 3D model: z-score each (rotation, slice) plane.
    '''
x1,x2,x3,x4 = input_data.shape
data = np.zeros_like(input_data)
for idx in range(x1):
for jdx in range(x2):
data[idx, jdx] = z_score_to_normalize_2d(input_data[idx, jdx])
return data
def generate_node_all_label_id_df(node_time=None):
    '''
    Query conditions:
        1. Visible in the system
            dicom_file_study.status != 5
            patient_info.status != 1
        2. The annotation status is normal
            user_label.status != 1
    Join conditions:
        user_label.study_id = dicom_file_study.id
        # user_label.pid == Null
        dicom_file_study.patient_info_id = patient_info.id
    Query steps:
        1. Query all user_label rows first, then filter
        2. Filter using dicom_file_study and patient_info
        3. Filter conditions:
            user_label.study_id = dicom_file_study.id
            dicom_file_study.patient_info_id = patient_info.id
            dicom_file_study.status != 5
            patient_info.status != 1
            user_label.status != 1
            user_label.deleted_time == None
            user_label.node_time == node_time
    Returns:
        label_ids: all label_id values
    '''
if node_time is None:
return None
session = conect_mysql()
logger.info(f"start query")
query = session.query(
UserLabel.node_time,
UserLabel.id,
PatientInfo.patient_id,
UserLabel.study_id,
UserLabel.series_id,
DicomStudy.study_uid,
DicomStudy.folder_name,
DicomSeries.series_instance_uid
).join(
DicomStudy, UserLabel.study_id == DicomStudy.id
).join(
PatientInfo, DicomStudy.patient_info_id == PatientInfo.id
).join(
DicomSeries, UserLabel.series_id == DicomSeries.id
).filter(
and_(
DicomStudy.status != 5,
PatientInfo.status != 1,
UserLabel.status != 1,
UserLabel.deleted_time == None,
UserLabel.node_time == node_time
)
)
result = query.all()
node_times = [row[0] for row in result]
label_ids = [row[1] for row in result]
patient_ids = [row[2] for row in result]
study_ids = [row[3] for row in result]
series_ids = [row[4] for row in result]
study_uids = [row[5] for row in result]
folder_names = [row[6] for row in result]
series_instance_uids = [row[7] for row in result]
session.close()
df = pd.DataFrame({'node_time': node_times, 'label_id': label_ids, 'patient_id': patient_ids, 'study_id': study_ids, 'series_id': series_ids, 'study_uid': study_uids, 'folder_name': folder_names, 'series_instance_uid': series_instance_uids})
df["patient_id"] = df["patient_id"].astype(str)
df["study_id"] = df["study_id"].astype(str)
df["series_id"] = df["series_id"].astype(str)
df["study_uid"] = df["study_uid"].astype(str)
df["folder_name"] = df["folder_name"].astype(str)
df["series_instance_uid"] = df["series_instance_uid"].astype(str)
return df
def select_single_label_id(label_id=None):
session = conect_mysql()
label = session.query(UserLabel).filter(
and_(UserLabel.id == label_id)).first()
if label is None:
return None, f"{label_id}, 标注数据不存在"
node_time = label.node_time
bundle = session.query(DicomSeries).filter(
and_(DicomSeries.id == label.series_id)).first()
if bundle is None:
return None, f"{label_id}, 关联的dicom数据不存在"
delineations = session.query(UserLabelDelineation).filter(
and_(UserLabelDelineation.label_id == label_id, UserLabelDelineation.status == 0)).order_by(
UserLabelDelineation.z_index.asc()).all()
session.close()
return (label, bundle, delineations), "success"
def generate_single_series_data_by_label_id(label_id=None, dicom_folder="", crop_size_3d=[48, 256, 256], expand=40, rotate_count=10, return_2d_data_flag=False):
(label, bundle, delineations), result = select_single_label_id(label_id=label_id)
patient_id = bundle.patient_id
series_instance_uid = bundle.series_instance_uid
data, selected_box, rotated_data_3d, node_time, rotated_data_2d, update_select_box_info_dict = None, None, None, None, None, None
if result != "success":
return data, selected_box, rotated_data_3d, node_time, rotated_data_2d, update_select_box_info_dict, patient_id, series_instance_uid
dicom_path = f"{dicom_folder}/{patient_id}-{series_instance_uid}"
cts = get_cts(dicom_path)
data = cts.get_raw_data()
spacing = cts.get_raw_spacing()
mask = np.zeros((len(data), len(data[1]), len(data[1][1])), np.uint8)
node_time = label.node_time
z_count = 0
for delineation in delineations:
if (delineation.contour is None or len(delineation.contour) == 0) and delineation.meta is None:
continue
indexlist, indexs, img_np = base64_to_list(delineation.contour)
if delineation.contour is None and delineation.meta:
indexlist, indexs, img_np = meta_to_list(delineation.meta, mask[0].copy())
mask[delineation.z_index] = img_np
z_count += 1
if mask is not None and np.sum(mask == 1) > 0:
coords = np.asarray(np.where(mask == 1))
selected_box = np.zeros((3, 2), np.float32)
selected_box[:, 0] = coords.min(axis=1)
selected_box[:, 1] = coords.max(axis=1) + 1
if selected_box[0][1] - selected_box[0][0] != z_count:
logger.info(f"z轴长度不一致, selected_box: {selected_box}, z_count: {z_count}")
selected_box[0][1] = selected_box[0][0] + z_count
if selected_box[0][1] - selected_box[0][0] > crop_size_3d[0]:
logger.info(f"z轴长度超过crop_size_3d, selected_box: {selected_box}, crop_size_3d: {crop_size_3d}")
selected_box[0][1] = selected_box[0][0] + crop_size_3d[0]
logger.info(f"selected_box: {selected_box}")
rotated_data_3d, update_select_box_info_dict = data_rotation_3d(data=data, select_box=selected_box, crop_size=crop_size_3d, expand=expand, num_rotations=rotate_count)
z_index_list_3d = update_select_box_info_dict["z_index_list"]
z_index_list_2d = []
if return_2d_data_flag:
rotated_data_2d = np.transpose(rotated_data_3d, (1, 0, 2, 3))
z_index_dict = update_select_box_info_dict["z_index_dict"]
z_index_list = update_select_box_info_dict["z_index_list"]
z_min, z_max = selected_box[0][0], min(selected_box[0][1], z_index_list[-1])
if z_min > z_max:
raise Exception(f"generate 2d data, z_min: {z_min} > z_max: {z_max}")
update_z_index_list = list(range(z_index_dict[z_min], z_index_dict[z_max]))
rotated_data_2d = rotated_data_2d[update_z_index_list]
z_index_list_2d = [idx_z for idx_z in range(int(z_min), int(z_max))]
return data, selected_box, rotated_data_3d, node_time, rotated_data_2d, update_select_box_info_dict, patient_id, series_instance_uid, z_index_list_3d, z_index_list_2d
def generate_npy_data_by_single_label_id(label_id=None, generate_3d_npy_data_flag=True, generate_2d_npy_data_flag=True, dicom_folder="", crop_size_3d=None, crop_size_2d=None, rotate_count=10, expand=40, save_path="", regular_class_3d=None, regular_class_2d=None):
if label_id is None:
return "label_id is None"
data, selected_box, rotated_data_3d, node_time, rotated_data_2d, update_select_box_info_dict, patient_id, series_instance_uid, z_index_list_3d, z_index_list_2d = generate_single_series_data_by_label_id(label_id=label_id, dicom_folder=dicom_folder, crop_size_3d=crop_size_3d, rotate_count=rotate_count, expand=expand, return_2d_data_flag=generate_2d_npy_data_flag)
if not os.path.exists(save_path):
Path(save_path).mkdir(parents=True, exist_ok=True)
npy_data_3d_file_list = []
npy_data_3d_z_index_list = []
npy_data_3d_rotate_count_list = []
npy_data_2d_file_list = []
npy_data_2d_z_index_list = []
npy_data_2d_rotate_count_list = []
if generate_3d_npy_data_flag and regular_class_3d:
regular_str = ""
if regular_class_3d is normalize_net_3d:
regular_str = "normalize_net_3d"
elif regular_class_3d is s3d_normalize_3d:
regular_str = "s3d_normalize_3d"
logger.info(f"data_3d shape: {rotated_data_3d.shape}, {regular_str}")
regular_data_3d = regular_class_3d(rotated_data_3d)
for idx_rotate_count in range(rotate_count):
idx_data_3d = regular_data_3d[idx_rotate_count, :, :, :]
idx_npy_data_3d_file = f"{save_path}/{node_time}_{label_id}_3d_rotate_10_{crop_size_3d[0]}_{crop_size_3d[1]}_{crop_size_3d[2]}_current_rotate_{idx_rotate_count+1}.npy"
np.save(idx_npy_data_3d_file, idx_data_3d)
logger.info(f"save 3d npy data -> {idx_npy_data_3d_file}, current rotate: {idx_rotate_count+1}")
npy_data_3d_file_list.append(idx_npy_data_3d_file)
npy_data_3d_z_index_list.append(z_index_list_3d)
npy_data_3d_rotate_count_list.append(idx_rotate_count+1)
if generate_2d_npy_data_flag and regular_class_2d:
regular_str = ""
if regular_class_2d is normalize_net_2d:
regular_str = "normalize_net_2d"
elif regular_class_2d is d2d_normalize:
regular_str = "d2d_normalize"
logger.info(f"data_2d shape: {rotated_data_2d.shape}, {regular_str}")
regular_data_2d = regular_class_2d(rotated_data_2d)
for idx_z_index_count, idx_z_index_2d in enumerate(z_index_list_2d):
for idx_rotate_count in range(rotate_count):
idx_data_2d = regular_data_2d[idx_z_index_count, idx_rotate_count, :, :]
idx_npy_data_2d_file = f"{save_path}/{node_time}_{label_id}_2d_rotate_10_{crop_size_2d[0]}_{crop_size_2d[1]}_z_{idx_z_index_2d}_current_rotate_{idx_rotate_count+1}.npy"
np.save(idx_npy_data_2d_file, idx_data_2d)
logger.info(f"save 2d npy data -> {idx_npy_data_2d_file}, current z_index: {idx_z_index_2d}, current rotate: {idx_rotate_count+1}")
npy_data_2d_file_list.append(idx_npy_data_2d_file)
npy_data_2d_z_index_list.append(idx_z_index_2d)
npy_data_2d_rotate_count_list.append(idx_rotate_count+1)
return patient_id, series_instance_uid, update_select_box_info_dict, npy_data_3d_file_list, npy_data_2d_file_list, npy_data_3d_z_index_list, npy_data_3d_rotate_count_list, npy_data_2d_z_index_list, npy_data_2d_rotate_count_list
def generate_npy_data_by_all_label_id_df(csv_file=None, npy_data_3d_file=None, npy_data_2d_file=None, dicom_folder="/opt/lung/ai", generate_3d_npy_data_flag=None, generate_2d_npy_data_flag=None, crop_size_3d=None, crop_size_2d=None, rotate_count=10, expand=40, regular_class_3d=None, regular_class_2d=None, save_path=""):
node_df = pd.read_csv(csv_file)
count = 0
data_3d_node_list = []
data_3d_label_id_list = []
data_3d_file_list = []
data_3d_z_index_list = []
data_3d_patient_id_list = []
data_3d_series_instance_uid_list = []
data_3d_rotate_count_list = []
data_2d_node_list = []
data_2d_label_id_list = []
data_2d_file_list = []
data_2d_z_index_list = []
data_2d_patient_id_list = []
data_2d_rotate_count_list = []
data_2d_series_instance_uid_list = []
for idx in tqdm(range(len(node_df))):
node_time = node_df.loc[idx, 'node_time']
label_id = node_df.loc[idx, 'label_id']
idx_patient_id = node_df.loc[idx, 'patient_id']
idx_series_instance_uid = node_df.loc[idx, 'series_instance_uid']
patient_id, series_instance_uid, update_select_box_info_dict, npy_data_3d_file_list, npy_data_2d_file_list, npy_data_3d_z_index_list, npy_data_3d_rotate_count_list, npy_data_2d_z_index_list, npy_data_2d_rotate_count_list = generate_npy_data_by_single_label_id(
label_id=label_id,
dicom_folder=dicom_folder,
generate_3d_npy_data_flag=generate_3d_npy_data_flag,
generate_2d_npy_data_flag=generate_2d_npy_data_flag,
crop_size_3d=crop_size_3d,
crop_size_2d=crop_size_2d,
rotate_count=rotate_count,
expand=expand,
save_path=save_path,
regular_class_3d=regular_class_3d,
regular_class_2d=regular_class_2d
)
if generate_3d_npy_data_flag:
assert len(npy_data_3d_file_list) == len(npy_data_3d_z_index_list) == len(npy_data_3d_rotate_count_list)
data_3d_node_list += [node_time] * len(npy_data_3d_file_list)
data_3d_label_id_list += [label_id] * len(npy_data_3d_file_list)
data_3d_file_list += npy_data_3d_file_list
data_3d_z_index_list += npy_data_3d_z_index_list
data_3d_rotate_count_list += npy_data_3d_rotate_count_list
data_3d_patient_id_list += [patient_id] * len(npy_data_3d_file_list)
data_3d_series_instance_uid_list += [series_instance_uid] * len(npy_data_3d_file_list)
if generate_2d_npy_data_flag:
assert len(npy_data_2d_file_list) == len(npy_data_2d_z_index_list) == len(npy_data_2d_rotate_count_list)
data_2d_node_list += [node_time] * len(npy_data_2d_file_list)
data_2d_label_id_list += [label_id] * len(npy_data_2d_file_list)
data_2d_file_list += npy_data_2d_file_list
data_2d_z_index_list += npy_data_2d_z_index_list
data_2d_rotate_count_list += npy_data_2d_rotate_count_list
data_2d_patient_id_list += [patient_id] * len(npy_data_2d_file_list)
data_2d_series_instance_uid_list += [series_instance_uid] * len(npy_data_2d_file_list)
if generate_3d_npy_data_flag:
npy_data_3d_df = pd.DataFrame({
"node": data_3d_node_list,
"label_id": data_3d_label_id_list,
"z_index": data_3d_z_index_list,
"rotate_count": data_3d_rotate_count_list,
"patient_id": data_3d_patient_id_list,
"series_instance_uid": data_3d_series_instance_uid_list,
"npy_file": data_3d_file_list,
})
npy_data_3d_df.to_csv(npy_data_3d_file, index=False, encoding="utf-8")
if generate_2d_npy_data_flag:
npy_data_2d_df = pd.DataFrame({
"node": data_2d_node_list,
"label_id": data_2d_label_id_list,
"z_index": data_2d_z_index_list,
"rotate_count": data_2d_rotate_count_list,
"patient_id": data_2d_patient_id_list,
"series_instance_uid": data_2d_series_instance_uid_list,
"npy_file": data_2d_file_list,
})
npy_data_2d_df.to_csv(npy_data_2d_file, index=False, encoding="utf-8")
if generate_3d_npy_data_flag:
logger.info(f"数据处理,保存 npy csv , npy data_3d -> {npy_data_3d_file}")
if generate_2d_npy_data_flag:
logger.info(f"数据处理,保存 npy csv , npy data_2d -> {npy_data_2d_file}")
return
def get_node_time_all_label_ids_df(node_time=None, csv_data_dir=""):
if node_time is None:
return None
df = generate_node_all_label_id_df(node_time=node_time)
task_info = datetime.now().strftime("%Y%m%d_%H%M%S")
csv_file = f"{csv_data_dir}/{node_time}/{node_time}_{task_info}_rotate_10.csv"
Path(csv_file).parent.mkdir(parents=True, exist_ok=True)
df.to_csv(csv_file, index=False, encoding="utf-8")
logger.info(f"save csv data -> {csv_file}")
return csv_file
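# Illustrative call (values taken from the commented-out driver code further below):
# csv_file = get_node_time_all_label_ids_df(node_time=2021, csv_data_dir="/df_lung/cls_train_data/csv_data")
# -> e.g. /df_lung/cls_train_data/csv_data/2021/2021_<timestamp>_rotate_10.csv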
def generate_train_npy_csv_file(node_npy_pos_neg_list = None, net_id_list = None, net_id_crop_size_dict = None, node_net_id_npy_file_dict = None, csv_data_dir="", train_csv_dir="", train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, is_pad_df=True, is_save_csv=False, seed=100004):
def check_npy_file(node=None, net_id=None, crop_size=None, npy_file_list=None):
check_list = []
if net_id == "3d":
crop_size_str = f"_{crop_size[0]}_{crop_size[1]}_{crop_size[2]}_"
elif net_id == "2d":
crop_size_str = f"_{crop_size[0]}_{crop_size[1]}_"
elif net_id == "2d3d":
crop_size_str_2d = f"_{crop_size[1]}_{crop_size[2]}_"
crop_size_str_3d = f"_{crop_size[0]}_{crop_size[1]}_{crop_size[2]}_"
elif net_id == "s3d":
crop_size_str = f"_{crop_size[0]}_{crop_size[1]}_{crop_size[2]}_"
elif net_id == "d2d":
crop_size_str = f"_{crop_size[0]}_{crop_size[1]}_"
if net_id == "2d3d":
for idx_npy_file in npy_file_list:
if f"{node}_" in idx_npy_file and (crop_size_str_2d in idx_npy_file or crop_size_str_3d in idx_npy_file):
check_list.append(True)
else:
for idx_npy_file in npy_file_list:
if f"{node}_" in idx_npy_file and crop_size_str in idx_npy_file:
check_list.append(True)
return len(check_list) == len(npy_file_list)
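# Sanity check over node_net_id_npy_file_dict: every index csv name must contain the "{node}_"
# prefix and the crop-size substring expected for its net_id; a mismatch is only printed below,
# it does not abort the run.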
for node_net, npy_file_list in node_net_id_npy_file_dict.items():
node, net_id = node_net[0], node_net[1]
crop_size = net_id_crop_size_dict[net_id]
if net_id == "2d3d":
npy_file_list = npy_file_list["2d"] + npy_file_list["3d"]
if not check_npy_file(node=node, net_id=net_id, crop_size=crop_size, npy_file_list=npy_file_list):
print(f"{node_net} npy_file_list check failed")
from sklearn.model_selection import train_test_split
def pad_df(df, max_len):
if len(df) == max_len:
return df
elif len(df) > max_len:
return df[:max_len]
else:
pad_df_list = [df]
lens = len(df)
while lens < max_len:
pad_df_list.append(df)
lens += len(df)
padded = pd.concat(pad_df_list, ignore_index=True)
return padded[:max_len]
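# pad_df sketch: a 3-row frame padded to max_len=7 is repeated cyclically and then truncated,
# i.e. rows [a, b, c] -> [a, b, c, a, b, c, a]; frames longer than max_len are simply cut to max_len.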
# Iterate over the positive/negative node groups and over each net_id, and build the corresponding training data
node_net_file_dict = {}
for idx_node_pos_neg_dict in node_npy_pos_neg_list:
idx_node_pos_list = idx_node_pos_neg_dict["pos"]
idx_node_neg_list = idx_node_pos_neg_dict["neg"]
idx_node_pos_list_str = "_".join(list(map(str, idx_node_pos_list)))
idx_node_neg_list_str = "_".join(list(map(str, idx_node_neg_list)))
for idx_net_id in net_id_list:
idx_task_str = f"{idx_node_neg_list_str}_{idx_node_pos_list_str}_net_id_{idx_net_id}"
if idx_net_id == "2d3d":
idx_pos_npy_file_list_2d = []
idx_pos_npy_file_list_3d = []
idx_neg_npy_file_list_2d = []
idx_neg_npy_file_list_3d = []
for idx_node_pos in idx_node_pos_list:
idx_pos_npy_file_list_2d += node_net_id_npy_file_dict[(idx_node_pos, idx_net_id)]["2d"]
idx_pos_npy_file_list_3d += node_net_id_npy_file_dict[(idx_node_pos, idx_net_id)]["3d"]
for idx_node_neg in idx_node_neg_list:
idx_neg_npy_file_list_2d += node_net_id_npy_file_dict[(idx_node_neg, idx_net_id)]["2d"]
idx_neg_npy_file_list_3d += node_net_id_npy_file_dict[(idx_node_neg, idx_net_id)]["3d"]
node_net_file_dict[(idx_task_str, idx_net_id)] = {
"pos_2d": idx_pos_npy_file_list_2d,
"pos_3d": idx_pos_npy_file_list_3d,
"neg_2d": idx_neg_npy_file_list_2d,
"neg_3d": idx_neg_npy_file_list_3d
}
else:
idx_pos_npy_file_list = []
idx_neg_npy_file_list = []
for idx_node_pos in idx_node_pos_list:
idx_pos_npy_file_list += node_net_id_npy_file_dict[(idx_node_pos, idx_net_id)]
for idx_node_neg in idx_node_neg_list:
idx_neg_npy_file_list += node_net_id_npy_file_dict[(idx_node_neg, idx_net_id)]
node_net_file_dict[(idx_task_str, idx_net_id)] = {
"pos": idx_pos_npy_file_list,
"neg": idx_neg_npy_file_list
}
for idx_task_str_net_id, idx_pos_neg_npy_file_dict in node_net_file_dict.items():
idx_task_str, idx_net_id = idx_task_str_net_id[0], idx_task_str_net_id[1]
if idx_net_id == "2d3d":
idx_pos_2d_npy_file_list = idx_pos_neg_npy_file_dict["pos_2d"]
idx_pos_3d_npy_file_list = idx_pos_neg_npy_file_dict["pos_3d"]
idx_neg_2d_npy_file_list = idx_pos_neg_npy_file_dict["neg_2d"]
idx_neg_3d_npy_file_list = idx_pos_neg_npy_file_dict["neg_3d"]
idx_pos_2d_node_list = [idx_pos_2d_npy_file.split('_')[0] for idx_pos_2d_npy_file in idx_pos_2d_npy_file_list]
idx_pos_3d_node_list = [idx_pos_3d_npy_file.split('_')[0] for idx_pos_3d_npy_file in idx_pos_3d_npy_file_list]
idx_neg_2d_node_list = [idx_neg_2d_npy_file.split('_')[0] for idx_neg_2d_npy_file in idx_neg_2d_npy_file_list]
idx_neg_3d_node_list = [idx_neg_3d_npy_file.split('_')[0] for idx_neg_3d_npy_file in idx_neg_3d_npy_file_list]
logger.info(f"idx_task_str: {idx_task_str}, idx_pos_2d_node_list: {idx_pos_2d_node_list}, {idx_pos_2d_npy_file_list}\nidx_pos_3d_node_list: {idx_pos_3d_node_list}, {idx_pos_3d_npy_file_list}\nidx_neg_2d_node_list: {idx_neg_2d_node_list}, {idx_neg_2d_npy_file_list}\nidx_neg_3d_node_list: {idx_neg_3d_node_list}, {idx_neg_3d_npy_file_list}")
_idx_pos_2d_df_list = [pd.read_csv(f"{csv_data_dir}/{idx_pos_2d_npy_file.split('_')[0]}/{idx_pos_2d_npy_file}") for idx_pos_2d_npy_file in idx_pos_2d_npy_file_list]
_idx_pos_3d_df_list = [pd.read_csv(f"{csv_data_dir}/{idx_pos_3d_npy_file.split('_')[0]}/{idx_pos_3d_npy_file}") for idx_pos_3d_npy_file in idx_pos_3d_npy_file_list]
_idx_neg_2d_df_list = [pd.read_csv(f"{csv_data_dir}/{idx_neg_2d_npy_file.split('_')[0]}/{idx_neg_2d_npy_file}") for idx_neg_2d_npy_file in idx_neg_2d_npy_file_list]
_idx_neg_3d_df_list = [pd.read_csv(f"{csv_data_dir}/{idx_neg_3d_npy_file.split('_')[0]}/{idx_neg_3d_npy_file}") for idx_neg_3d_npy_file in idx_neg_3d_npy_file_list]
idx_pos_2d_df_list = []
idx_pos_3d_df_list = []
idx_neg_2d_df_list = []
idx_neg_3d_df_list = []
for idx_node, idx_df in zip(idx_pos_2d_node_list, _idx_pos_2d_df_list):
idx_df["node"] = idx_node
idx_pos_2d_df_list.append(idx_df)
for idx_node, idx_df in zip(idx_pos_3d_node_list, _idx_pos_3d_df_list):
idx_df["node"] = idx_node
idx_pos_3d_df_list.append(idx_df)
for idx_node, idx_df in zip(idx_neg_2d_node_list, _idx_neg_2d_df_list):
idx_df["node"] = idx_node
idx_neg_2d_df_list.append(idx_df)
for idx_node, idx_df in zip(idx_neg_3d_node_list, _idx_neg_3d_df_list):
idx_df["node"] = idx_node
idx_neg_3d_df_list.append(idx_df)
idx_pos_2d_train_df_list = []
idx_pos_2d_val_df_list = []
idx_pos_2d_test_df_list = []
idx_pos_3d_train_df_list = []
idx_pos_3d_val_df_list = []
idx_pos_3d_test_df_list = []
idx_neg_2d_train_df_list = []
idx_neg_2d_val_df_list = []
idx_neg_2d_test_df_list = []
idx_neg_3d_train_df_list = []
idx_neg_3d_val_df_list = []
idx_neg_3d_test_df_list = []
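# Each per-node frame is split independently: the first train_test_split holds out (1 - train_ratio)
# of the rows, the second divides that hold-out by test_ratio / (val_ratio + test_ratio); with the
# default 0.8/0.1/0.1 ratios this gives an 80/10/10 train/val/test split per frame.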
for idx_pos_2d_df in idx_pos_2d_df_list:
idx_pos_2d_train_df, idx_pos_2d_test_val_df = train_test_split(idx_pos_2d_df, test_size=1-train_ratio, random_state=seed)
idx_pos_2d_val_df, idx_pos_2d_test_df = train_test_split(idx_pos_2d_test_val_df, test_size=test_ratio/(val_ratio+test_ratio), random_state=seed)
idx_pos_2d_train_df_list.append(idx_pos_2d_train_df)
idx_pos_2d_val_df_list.append(idx_pos_2d_val_df)
idx_pos_2d_test_df_list.append(idx_pos_2d_test_df)
for idx_pos_3d_df in idx_pos_3d_df_list:
idx_pos_3d_train_df, idx_pos_3d_test_val_df = train_test_split(idx_pos_3d_df, test_size=1-train_ratio, random_state=seed)
idx_pos_3d_val_df, idx_pos_3d_test_df = train_test_split(idx_pos_3d_test_val_df, test_size=test_ratio/(val_ratio+test_ratio), random_state=seed)
idx_pos_3d_train_df_list.append(idx_pos_3d_train_df)
idx_pos_3d_val_df_list.append(idx_pos_3d_val_df)
idx_pos_3d_test_df_list.append(idx_pos_3d_test_df)
for idx_neg_2d_df in idx_neg_2d_df_list:
idx_neg_2d_train_df, idx_neg_2d_test_val_df = train_test_split(idx_neg_2d_df, test_size=1-train_ratio, random_state=seed)
idx_neg_2d_val_df, idx_neg_2d_test_df = train_test_split(idx_neg_2d_test_val_df, test_size=test_ratio/(val_ratio+test_ratio), random_state=seed)
idx_neg_2d_train_df_list.append(idx_neg_2d_train_df)
idx_neg_2d_val_df_list.append(idx_neg_2d_val_df)
idx_neg_2d_test_df_list.append(idx_neg_2d_test_df)
for idx_neg_3d_df in idx_neg_3d_df_list:
idx_neg_3d_train_df, idx_neg_3d_test_val_df = train_test_split(idx_neg_3d_df, test_size=1-train_ratio, random_state=seed)
idx_neg_3d_val_df, idx_neg_3d_test_df = train_test_split(idx_neg_3d_test_val_df, test_size=test_ratio/(val_ratio+test_ratio), random_state=seed)
idx_neg_3d_train_df_list.append(idx_neg_3d_train_df)
idx_neg_3d_val_df_list.append(idx_neg_3d_val_df)
idx_neg_3d_test_df_list.append(idx_neg_3d_test_df)
idx_pos_2d_train_df_lens_list = [len(idx_df) for idx_df in idx_pos_2d_train_df_list]
idx_pos_2d_val_df_lens_list = [len(idx_df) for idx_df in idx_pos_2d_val_df_list]
idx_pos_2d_test_df_lens_list = [len(idx_df) for idx_df in idx_pos_2d_test_df_list]
idx_pos_3d_train_df_lens_list = [len(idx_df) for idx_df in idx_pos_3d_train_df_list]
idx_pos_3d_val_df_lens_list = [len(idx_df) for idx_df in idx_pos_3d_val_df_list]
idx_pos_3d_test_df_lens_list = [len(idx_df) for idx_df in idx_pos_3d_test_df_list]
idx_neg_2d_train_df_lens_list = [len(idx_df) for idx_df in idx_neg_2d_train_df_list]
idx_neg_2d_val_df_lens_list = [len(idx_df) for idx_df in idx_neg_2d_val_df_list]
idx_neg_2d_test_df_lens_list = [len(idx_df) for idx_df in idx_neg_2d_test_df_list]
idx_neg_3d_train_df_lens_list = [len(idx_df) for idx_df in idx_neg_3d_train_df_list]
idx_neg_3d_val_df_lens_list = [len(idx_df) for idx_df in idx_neg_3d_val_df_list]
idx_neg_3d_test_df_lens_list = [len(idx_df) for idx_df in idx_neg_3d_test_df_list]
print(f"idx_task_str: {idx_task_str}, 同类数据填充 data_2d, before\n训练集: {idx_pos_2d_train_df_lens_list}, {idx_neg_2d_train_df_lens_list}\n验证集: {idx_pos_2d_val_df_lens_list}, {idx_neg_2d_val_df_lens_list}\n测试集: {idx_pos_2d_test_df_lens_list}, {idx_neg_2d_test_df_lens_list}")
print(f"idx_task_str: {idx_task_str}, 同类数据填充 data_3d, before\n训练集: {idx_pos_3d_train_df_lens_list}, {idx_neg_3d_train_df_lens_list}\n验证集: {idx_pos_3d_val_df_lens_list}, {idx_neg_3d_val_df_lens_list}\n测试集: {idx_pos_3d_test_df_lens_list}, {idx_neg_3d_test_df_lens_list}")
if is_pad_df:
idx_pos_2d_train_df_list = [pad_df(idx_df, max(idx_pos_2d_train_df_lens_list)) for idx_df in idx_pos_2d_train_df_list]
# idx_pos_2d_val_df_list = [pad_df(idx_df, max(idx_pos_2d_val_df_lens_list)) for idx_df in idx_pos_2d_val_df_list]
# idx_pos_2d_test_df_list = [pad_df(idx_df, max(idx_pos_2d_test_df_lens_list)) for idx_df in idx_pos_2d_test_df_list]
idx_pos_3d_train_df_list = [pad_df(idx_df, max(idx_pos_3d_train_df_lens_list)) for idx_df in idx_pos_3d_train_df_list]
# idx_pos_3d_val_df_list = [pad_df(idx_df, max(idx_pos_3d_val_df_lens_list)) for idx_df in idx_pos_3d_val_df_list]
# idx_pos_3d_test_df_list = [pad_df(idx_df, max(idx_pos_3d_test_df_lens_list)) for idx_df in idx_pos_3d_test_df_list]
idx_neg_2d_train_df_list = [pad_df(idx_df, max(idx_neg_2d_train_df_lens_list)) for idx_df in idx_neg_2d_train_df_list]
# idx_neg_2d_val_df_list = [pad_df(idx_df, max(idx_neg_2d_val_df_lens_list)) for idx_df in idx_neg_2d_val_df_list]
# idx_neg_2d_test_df_list = [pad_df(idx_df, max(idx_neg_2d_test_df_lens_list)) for idx_df in idx_neg_2d_test_df_list]
idx_neg_3d_train_df_list = [pad_df(idx_df, max(idx_neg_3d_train_df_lens_list)) for idx_df in idx_neg_3d_train_df_list]
# idx_neg_3d_val_df_list = [pad_df(idx_df, max(idx_neg_3d_val_df_lens_list)) for idx_df in idx_neg_3d_val_df_list]
# idx_neg_3d_test_df_list = [pad_df(idx_df, max(idx_neg_3d_test_df_lens_list)) for idx_df in idx_neg_3d_test_df_list]
idx_pos_2d_train_df_lens_list = [len(idx_df) for idx_df in idx_pos_2d_train_df_list]
idx_pos_2d_val_df_lens_list = [len(idx_df) for idx_df in idx_pos_2d_val_df_list]
idx_pos_2d_test_df_lens_list = [len(idx_df) for idx_df in idx_pos_2d_test_df_list]
idx_pos_3d_train_df_lens_list = [len(idx_df) for idx_df in idx_pos_3d_train_df_list]
idx_pos_3d_val_df_lens_list = [len(idx_df) for idx_df in idx_pos_3d_val_df_list]
idx_pos_3d_test_df_lens_list = [len(idx_df) for idx_df in idx_pos_3d_test_df_list]
idx_neg_2d_train_df_lens_list = [len(idx_df) for idx_df in idx_neg_2d_train_df_list]
idx_neg_2d_val_df_lens_list = [len(idx_df) for idx_df in idx_neg_2d_val_df_list]
idx_neg_2d_test_df_lens_list = [len(idx_df) for idx_df in idx_neg_2d_test_df_list]
idx_neg_3d_train_df_lens_list = [len(idx_df) for idx_df in idx_neg_3d_train_df_list]
idx_neg_3d_val_df_lens_list = [len(idx_df) for idx_df in idx_neg_3d_val_df_list]
idx_neg_3d_test_df_lens_list = [len(idx_df) for idx_df in idx_neg_3d_test_df_list]
print(f"idx_task_str: {idx_task_str}, 同类数据填充 data_2d, after\n训练集: {idx_pos_2d_train_df_lens_list}, {idx_neg_2d_train_df_lens_list}\n验证集: {idx_pos_2d_val_df_lens_list}, {idx_neg_2d_val_df_lens_list}\n测试集: {idx_pos_2d_test_df_lens_list}, {idx_neg_2d_test_df_lens_list}")
print(f"idx_task_str: {idx_task_str}, 同类数据填充 data_3d, after\n训练集: {idx_pos_3d_train_df_lens_list}, {idx_neg_3d_train_df_lens_list}\n验证集: {idx_pos_3d_val_df_lens_list}, {idx_neg_3d_val_df_lens_list}\n测试集: {idx_pos_3d_test_df_lens_list}, {idx_neg_3d_test_df_lens_list}")
if is_pad_df:
# training set
idx_pos_2d_train_df_lens = sum(idx_pos_2d_train_df_lens_list)
idx_pos_3d_train_df_lens = sum(idx_pos_3d_train_df_lens_list)
idx_neg_2d_train_df_lens = sum(idx_neg_2d_train_df_lens_list)
idx_neg_3d_train_df_lens = sum(idx_neg_3d_train_df_lens_list)
if idx_pos_2d_train_df_lens > idx_neg_2d_train_df_lens:
idx_neg_2d_each_lens = idx_pos_2d_train_df_lens // len(idx_neg_2d_train_df_lens_list)
idx_neg_2d_train_df_list = [pad_df(idx_df, idx_neg_2d_each_lens) for idx_df in idx_neg_2d_train_df_list]
elif idx_pos_2d_train_df_lens < idx_neg_2d_train_df_lens:
idx_pos_2d_each_lens = idx_neg_2d_train_df_lens // len(idx_pos_2d_train_df_lens_list)
idx_pos_2d_train_df_list = [pad_df(idx_df, idx_pos_2d_each_lens) for idx_df in idx_pos_2d_train_df_list]
if idx_pos_3d_train_df_lens > idx_neg_3d_train_df_lens:
idx_neg_3d_each_lens = idx_pos_3d_train_df_lens // len(idx_neg_3d_train_df_lens_list)
idx_neg_3d_train_df_list = [pad_df(idx_df, idx_neg_3d_each_lens) for idx_df in idx_neg_3d_train_df_list]
elif idx_pos_3d_train_df_lens < idx_neg_3d_train_df_lens:
idx_pos_3d_each_lens = idx_neg_3d_train_df_lens // len(idx_pos_3d_train_df_lens_list)
idx_pos_3d_train_df_list = [pad_df(idx_df, idx_pos_3d_each_lens) for idx_df in idx_pos_3d_train_df_list]
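# Training-set class balancing: whichever side (pos or neg) has fewer rows in total is padded so each
# of its frames reaches roughly total_of_the_larger_side // number_of_frames; the validation and test
# sets are intentionally left unbalanced (the corresponding blocks below stay commented out).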
# # validation set
# idx_pos_2d_val_df_lens = sum(idx_pos_2d_val_df_lens_list)
# idx_neg_2d_val_df_lens = sum(idx_neg_2d_val_df_lens_list)
# idx_pos_3d_val_df_lens = sum(idx_pos_3d_val_df_lens_list)
# idx_neg_3d_val_df_lens = sum(idx_neg_3d_val_df_lens_list)
# if idx_pos_2d_val_df_lens > idx_neg_2d_val_df_lens:
# idx_neg_2d_each_lens = idx_pos_2d_val_df_lens // len(idx_neg_2d_val_df_lens_list)
# idx_neg_2d_val_df_list = [pad_df(idx_df, idx_neg_2d_each_lens) for idx_df in idx_neg_2d_val_df_list]
# elif idx_pos_2d_val_df_lens < idx_neg_2d_val_df_lens:
# idx_pos_2d_each_lens = idx_neg_2d_val_df_lens // len(idx_pos_2d_val_df_lens_list)
# idx_pos_2d_val_df_list = [pad_df(idx_df, idx_pos_2d_each_lens) for idx_df in idx_pos_2d_val_df_list]
# if idx_pos_3d_val_df_lens > idx_neg_3d_val_df_lens:
# idx_neg_3d_each_lens = idx_pos_3d_val_df_lens // len(idx_neg_3d_val_df_lens_list)
# idx_neg_3d_val_df_list = [pad_df(idx_df, idx_neg_3d_each_lens) for idx_df in idx_neg_3d_val_df_list]
# elif idx_pos_3d_val_df_lens < idx_neg_3d_val_df_lens:
# idx_pos_3d_each_lens = idx_neg_3d_val_df_lens // len(idx_pos_3d_val_df_lens_list)
# idx_pos_3d_val_df_list = [pad_df(idx_df, idx_pos_3d_each_lens) for idx_df in idx_pos_3d_val_df_list]
# # test set
# idx_pos_2d_test_df_lens = sum(idx_pos_2d_test_df_lens_list)
# idx_neg_2d_test_df_lens = sum(idx_neg_2d_test_df_lens_list)
# idx_pos_3d_test_df_lens = sum(idx_pos_3d_test_df_lens_list)
# idx_neg_3d_test_df_lens = sum(idx_neg_3d_test_df_lens_list)
# if idx_pos_2d_test_df_lens > idx_neg_2d_test_df_lens:
# idx_neg_2d_each_lens = idx_pos_2d_test_df_lens // len(idx_neg_2d_test_df_lens_list)
# idx_neg_2d_test_df_list = [pad_df(idx_df, idx_neg_2d_each_lens) for idx_df in idx_neg_2d_test_df_list]
# elif idx_pos_2d_test_df_lens < idx_neg_2d_test_df_lens:
# idx_pos_2d_each_lens = idx_neg_2d_test_df_lens // len(idx_pos_2d_test_df_lens_list)
# idx_pos_2d_test_df_list = [pad_df(idx_df, idx_pos_2d_each_lens) for idx_df in idx_pos_2d_test_df_list]
# if idx_pos_3d_test_df_lens > idx_neg_3d_test_df_lens:
# idx_neg_3d_each_lens = idx_pos_3d_test_df_lens // len(idx_neg_3d_test_df_lens_list)
# idx_neg_3d_test_df_list = [pad_df(idx_df, idx_neg_3d_each_lens) for idx_df in idx_neg_3d_test_df_list]
# elif idx_pos_3d_test_df_lens < idx_neg_3d_test_df_lens:
# idx_pos_3d_each_lens = idx_neg_3d_test_df_lens // len(idx_pos_3d_test_df_lens_list)
# idx_pos_3d_test_df_list = [pad_df(idx_df, idx_pos_3d_each_lens) for idx_df in idx_pos_3d_test_df_list]
idx_pos_2d_train_df_lens_list = [len(idx_df) for idx_df in idx_pos_2d_train_df_list]
idx_pos_2d_val_df_lens_list = [len(idx_df) for idx_df in idx_pos_2d_val_df_list]
idx_pos_2d_test_df_lens_list = [len(idx_df) for idx_df in idx_pos_2d_test_df_list]
idx_pos_3d_train_df_lens_list = [len(idx_df) for idx_df in idx_pos_3d_train_df_list]
idx_pos_3d_val_df_lens_list = [len(idx_df) for idx_df in idx_pos_3d_val_df_list]
idx_pos_3d_test_df_lens_list = [len(idx_df) for idx_df in idx_pos_3d_test_df_list]
idx_neg_2d_train_df_lens_list = [len(idx_df) for idx_df in idx_neg_2d_train_df_list]
idx_neg_2d_val_df_lens_list = [len(idx_df) for idx_df in idx_neg_2d_val_df_list]
idx_neg_2d_test_df_lens_list = [len(idx_df) for idx_df in idx_neg_2d_test_df_list]
idx_neg_3d_train_df_lens_list = [len(idx_df) for idx_df in idx_neg_3d_train_df_list]
idx_neg_3d_val_df_lens_list = [len(idx_df) for idx_df in idx_neg_3d_val_df_list]
idx_neg_3d_test_df_lens_list = [len(idx_df) for idx_df in idx_neg_3d_test_df_list]
print(f"idx_task_str: {idx_task_str}, 平衡后数据 data_2d\n训练集: {idx_pos_2d_train_df_lens_list}, {idx_neg_2d_train_df_lens_list}\n验证集: {idx_pos_2d_val_df_lens_list}, {idx_neg_2d_val_df_lens_list}\n测试集: {idx_pos_2d_test_df_lens_list}, {idx_neg_2d_test_df_lens_list}")
print(f"idx_task_str: {idx_task_str}, 平衡后数据 data_3d\n训练集: {idx_pos_3d_train_df_lens_list}, {idx_neg_3d_train_df_lens_list}\n验证集: {idx_pos_3d_val_df_lens_list}, {idx_neg_3d_val_df_lens_list}\n测试集: {idx_pos_3d_test_df_lens_list}, {idx_neg_3d_test_df_lens_list}")
idx_pos_2d_train_df = pd.concat(idx_pos_2d_train_df_list, ignore_index=True)
idx_pos_2d_train_df['label'] = 1
idx_neg_2d_train_df = pd.concat(idx_neg_2d_train_df_list, ignore_index=True)
idx_neg_2d_train_df['label'] = 0
idx_data_2d_train_df = pd.concat([idx_pos_2d_train_df, idx_neg_2d_train_df], ignore_index=True)
idx_pos_3d_train_df = pd.concat(idx_pos_3d_train_df_list, ignore_index=True)
idx_pos_3d_train_df['label'] = 1
idx_neg_3d_train_df = pd.concat(idx_neg_3d_train_df_list, ignore_index=True)
idx_neg_3d_train_df['label'] = 0
idx_data_3d_train_df = pd.concat([idx_pos_3d_train_df, idx_neg_3d_train_df], ignore_index=True)
idx_pos_2d_val_df = pd.concat(idx_pos_2d_val_df_list, ignore_index=True)
idx_pos_2d_val_df['label'] = 1
idx_neg_2d_val_df = pd.concat(idx_neg_2d_val_df_list, ignore_index=True)
idx_neg_2d_val_df['label'] = 0
idx_data_2d_val_df = pd.concat([idx_pos_2d_val_df, idx_neg_2d_val_df], ignore_index=True)
idx_pos_3d_val_df = pd.concat(idx_pos_3d_val_df_list, ignore_index=True)
idx_pos_3d_val_df['label'] = 1
idx_neg_3d_val_df = pd.concat(idx_neg_3d_val_df_list, ignore_index=True)
idx_neg_3d_val_df['label'] = 0
idx_data_3d_val_df = pd.concat([idx_pos_3d_val_df, idx_neg_3d_val_df], ignore_index=True)
idx_pos_2d_test_df = pd.concat(idx_pos_2d_test_df_list, ignore_index=True)
idx_pos_2d_test_df['label'] = 1
idx_neg_2d_test_df = pd.concat(idx_neg_2d_test_df_list, ignore_index=True)
idx_neg_2d_test_df['label'] = 0
idx_data_2d_test_df = pd.concat([idx_pos_2d_test_df, idx_neg_2d_test_df], ignore_index=True)
idx_pos_3d_test_df = pd.concat(idx_pos_3d_test_df_list, ignore_index=True)
idx_pos_3d_test_df['label'] = 1
idx_neg_3d_test_df = pd.concat(idx_neg_3d_test_df_list, ignore_index=True)
idx_neg_3d_test_df['label'] = 0
idx_data_3d_test_df = pd.concat([idx_pos_3d_test_df, idx_neg_3d_test_df], ignore_index=True)
idx_2d_train_df_file = f"{train_csv_dir}/{idx_task_str}_data_2d_train.csv"
idx_2d_val_df_file = f"{train_csv_dir}/{idx_task_str}_data_2d_val.csv"
idx_2d_test_df_file = f"{train_csv_dir}/{idx_task_str}_data_2d_test.csv"
idx_3d_train_df_file = f"{train_csv_dir}/{idx_task_str}_data_3d_train.csv"
idx_3d_val_df_file = f"{train_csv_dir}/{idx_task_str}_data_3d_val.csv"
idx_3d_test_df_file = f"{train_csv_dir}/{idx_task_str}_data_3d_test.csv"
data_2d_train_lens = len(idx_data_2d_train_df)
data_3d_train_lens = len(idx_data_3d_train_df)
print(f"idx_task_str: {idx_task_str}, data_2d, data_3d before 训练集\n{data_2d_train_lens}, {data_3d_train_lens}")
if data_2d_train_lens > data_3d_train_lens:
idx_data_3d_train_df = pad_df(idx_data_3d_train_df, data_2d_train_lens)
elif data_2d_train_lens < data_3d_train_lens:
idx_data_2d_train_df = pad_df(idx_data_2d_train_df, data_3d_train_lens)
data_2d_train_lens = len(idx_data_2d_train_df)
data_3d_train_lens = len(idx_data_3d_train_df)
print(f"idx_task_str: {idx_task_str}, data_2d, data_3d after 训练集\n{data_2d_train_lens}, {data_3d_train_lens}")
print(f"idx_task_str: {idx_task_str}\ntrain_2d_df: {len(idx_data_2d_train_df)}\nval_2d_df: {len(idx_data_2d_val_df)}\ntest_2d_df: {len(idx_data_2d_test_df)}\n")
print(f"idx_task_str: {idx_task_str}\ntrain_3d_df: {len(idx_data_3d_train_df)}\nval_3d_df: {len(idx_data_3d_val_df)}\ntest_3d_df: {len(idx_data_3d_test_df)}\n")
assert idx_data_2d_train_df['label'].isnull().sum() == 0
assert idx_data_3d_train_df['label'].isnull().sum() == 0
assert len(idx_data_2d_train_df) == len(idx_data_3d_train_df)
if is_save_csv:
idx_data_2d_train_df.to_csv(idx_2d_train_df_file, index=False, encoding="utf-8")
idx_data_2d_val_df.to_csv(idx_2d_val_df_file, index=False, encoding="utf-8")
idx_data_2d_test_df.to_csv(idx_2d_test_df_file, index=False, encoding="utf-8")
idx_data_3d_train_df.to_csv(idx_3d_train_df_file, index=False, encoding="utf-8")
idx_data_3d_val_df.to_csv(idx_3d_val_df_file, index=False, encoding="utf-8")
idx_data_3d_test_df.to_csv(idx_3d_test_df_file, index=False, encoding="utf-8")
logger.info(f"task_info: {idx_task_str}\ntrain_2d_df_file: {idx_2d_train_df_file}\nval_2d_df_file: {idx_2d_val_df_file}\ntest_2d_df_file: {idx_2d_test_df_file}\n")
logger.info(f"task_info: {idx_task_str}\ntrain_3d_df_file: {idx_3d_train_df_file}\nval_3d_df_file: {idx_3d_val_df_file}\ntest_3d_df_file: {idx_3d_test_df_file}\n")
else:
idx_pos_npy_file_list = idx_pos_neg_npy_file_dict["pos"]
idx_neg_npy_file_list = idx_pos_neg_npy_file_dict["neg"]
idx_pos_node_list = [idx_pos_npy_file.split('_')[0] for idx_pos_npy_file in idx_pos_npy_file_list]
idx_neg_node_list = [idx_neg_npy_file.split('_')[0] for idx_neg_npy_file in idx_neg_npy_file_list]
logger.info(f"idx_task_str: {idx_task_str}, idx_pos_node_list: {idx_pos_node_list}, {idx_pos_npy_file_list}\nidx_neg_node_list: {idx_neg_node_list}, {idx_neg_npy_file_list}")
_idx_pos_df_list = [pd.read_csv(f"{csv_data_dir}/{idx_pos_npy_file.split('_')[0]}/{idx_pos_npy_file}") for idx_pos_npy_file in idx_pos_npy_file_list]
_idx_neg_df_list = [pd.read_csv(f"{csv_data_dir}/{idx_neg_npy_file.split('_')[0]}/{idx_neg_npy_file}") for idx_neg_npy_file in idx_neg_npy_file_list]
idx_pos_df_list = []
idx_neg_df_list = []
for idx_node, idx_df in zip(idx_pos_node_list, _idx_pos_df_list):
idx_df["node"] = idx_node
idx_pos_df_list.append(idx_df)
for idx_node, idx_df in zip(idx_neg_node_list, _idx_neg_df_list):
idx_df["node"] = idx_node
idx_neg_df_list.append(idx_df)
idx_pos_train_df_list = []
idx_pos_val_df_list = []
idx_pos_test_df_list = []
idx_neg_train_df_list = []
idx_neg_val_df_list = []
idx_neg_test_df_list = []
for idx_pos_df in idx_pos_df_list:
logger.info(f"idx_task_str: {idx_task_str}, idx_pos_df: {len(idx_pos_df)}")
idx_pos_train_df, idx_pos_test_val_df = train_test_split(idx_pos_df, test_size=1-train_ratio, random_state=seed)
idx_pos_val_df, idx_pos_test_df = train_test_split(idx_pos_test_val_df, test_size=test_ratio/(val_ratio+test_ratio), random_state=seed)
idx_pos_train_df_list.append(idx_pos_train_df)
idx_pos_val_df_list.append(idx_pos_val_df)
idx_pos_test_df_list.append(idx_pos_test_df)
for idx_neg_df in idx_neg_df_list:
logger.info(f"idx_task_str: {idx_task_str}, idx_neg_df: {len(idx_neg_df)}")
idx_neg_train_df, idx_neg_test_val_df = train_test_split(idx_neg_df, test_size=1-train_ratio, random_state=seed)
idx_neg_val_df, idx_neg_test_df = train_test_split(idx_neg_test_val_df, test_size=test_ratio/(val_ratio+test_ratio), random_state=seed)
idx_neg_train_df_list.append(idx_neg_train_df)
idx_neg_val_df_list.append(idx_neg_val_df)
idx_neg_test_df_list.append(idx_neg_test_df)
idx_pos_train_df_lens_list = [len(idx_df) for idx_df in idx_pos_train_df_list]
idx_pos_val_df_lens_list = [len(idx_df) for idx_df in idx_pos_val_df_list]
idx_pos_test_df_lens_list = [len(idx_df) for idx_df in idx_pos_test_df_list]
idx_neg_train_df_lens_list = [len(idx_df) for idx_df in idx_neg_train_df_list]
idx_neg_val_df_lens_list = [len(idx_df) for idx_df in idx_neg_val_df_list]
idx_neg_test_df_lens_list = [len(idx_df) for idx_df in idx_neg_test_df_list]
print(f"idx_task_str: {idx_task_str}, 同类数据填充 before\n训练集: {idx_pos_train_df_lens_list}, {idx_neg_train_df_lens_list}\n验证集: {idx_pos_val_df_lens_list}, {idx_neg_val_df_lens_list}\n测试集: {idx_pos_test_df_lens_list}, {idx_neg_test_df_lens_list}")
if is_pad_df:
idx_pos_train_df_list = [pad_df(idx_df, max(idx_pos_train_df_lens_list)) for idx_df in idx_pos_train_df_list]
idx_neg_train_df_list = [pad_df(idx_df, max(idx_neg_train_df_lens_list)) for idx_df in idx_neg_train_df_list]
# idx_pos_val_df_list = [pad_df(idx_df, max(idx_pos_val_df_lens_list)) for idx_df in idx_pos_val_df_list]
# idx_neg_val_df_list = [pad_df(idx_df, max(idx_neg_val_df_lens_list)) for idx_df in idx_neg_val_df_list]
# idx_pos_test_df_list = [pad_df(idx_df, max(idx_pos_test_df_lens_list)) for idx_df in idx_pos_test_df_list]
# idx_neg_test_df_list = [pad_df(idx_df, max(idx_neg_test_df_lens_list)) for idx_df in idx_neg_test_df_list]
idx_pos_train_df_lens_list = [len(idx_df) for idx_df in idx_pos_train_df_list]
idx_pos_val_df_lens_list = [len(idx_df) for idx_df in idx_pos_val_df_list]
idx_pos_test_df_lens_list = [len(idx_df) for idx_df in idx_pos_test_df_list]
idx_neg_train_df_lens_list = [len(idx_df) for idx_df in idx_neg_train_df_list]
idx_neg_val_df_lens_list = [len(idx_df) for idx_df in idx_neg_val_df_list]
idx_neg_test_df_lens_list = [len(idx_df) for idx_df in idx_neg_test_df_list]
print(f"idx_task_str: {idx_task_str}, 同类数据填充 after\n训练集: {idx_pos_train_df_lens_list}, {idx_neg_train_df_lens_list}\n验证集: {idx_pos_val_df_lens_list}, {idx_neg_val_df_lens_list}\n测试集: {idx_pos_test_df_lens_list}, {idx_neg_test_df_lens_list}")
if is_pad_df:
# training set
idx_pos_train_df_lens = sum(idx_pos_train_df_lens_list)
idx_neg_train_df_lens = sum(idx_neg_train_df_lens_list)
if idx_pos_train_df_lens > idx_neg_train_df_lens:
idx_neg_each_lens = idx_pos_train_df_lens // len(idx_neg_train_df_lens_list)
idx_neg_train_df_list = [pad_df(idx_df, idx_neg_each_lens) for idx_df in idx_neg_train_df_list]
elif idx_pos_train_df_lens < idx_neg_train_df_lens:
idx_pos_each_lens = idx_neg_train_df_lens // len(idx_pos_train_df_lens_list)
idx_pos_train_df_list = [pad_df(idx_df, idx_pos_each_lens) for idx_df in idx_pos_train_df_list]
# # validation set
# idx_pos_val_df_lens = sum(idx_pos_val_df_lens_list)
# idx_neg_val_df_lens = sum(idx_neg_val_df_lens_list)
# if idx_pos_val_df_lens > idx_neg_val_df_lens:
# idx_neg_each_lens = idx_pos_val_df_lens // len(idx_neg_val_df_lens_list)
# idx_neg_val_df_list = [pad_df(idx_df, idx_neg_each_lens) for idx_df in idx_neg_val_df_list]
# elif idx_pos_val_df_lens < idx_neg_val_df_lens:
# idx_pos_each_lens = idx_neg_val_df_lens // len(idx_pos_val_df_lens_list)
# idx_pos_val_df_list = [pad_df(idx_df, idx_pos_each_lens) for idx_df in idx_pos_val_df_list]
# # test set
# idx_pos_test_df_lens = sum(idx_pos_test_df_lens_list)
# idx_neg_test_df_lens = sum(idx_neg_test_df_lens_list)
# if idx_pos_test_df_lens > idx_neg_test_df_lens:
# idx_neg_each_lens = idx_pos_test_df_lens // len(idx_neg_test_df_lens_list)
# idx_neg_test_df_list = [pad_df(idx_df, idx_neg_each_lens) for idx_df in idx_neg_test_df_list]
# elif idx_pos_test_df_lens < idx_neg_test_df_lens:
# idx_pos_each_lens = idx_neg_test_df_lens // len(idx_pos_test_df_lens_list)
# idx_pos_test_df_list = [pad_df(idx_df, idx_pos_each_lens) for idx_df in idx_pos_test_df_list]
idx_pos_train_df_lens_list = [len(idx_df) for idx_df in idx_pos_train_df_list]
idx_pos_val_df_lens_list = [len(idx_df) for idx_df in idx_pos_val_df_list]
idx_pos_test_df_lens_list = [len(idx_df) for idx_df in idx_pos_test_df_list]
idx_neg_train_df_lens_list = [len(idx_df) for idx_df in idx_neg_train_df_list]
idx_neg_val_df_lens_list = [len(idx_df) for idx_df in idx_neg_val_df_list]
idx_neg_test_df_lens_list = [len(idx_df) for idx_df in idx_neg_test_df_list]
print(f"idx_task_str: {idx_task_str}, 平衡后数据\n训练集: {idx_pos_train_df_lens_list}, {idx_neg_train_df_lens_list}\n验证集: {idx_pos_val_df_lens_list}, {idx_neg_val_df_lens_list}\n测试集: {idx_pos_test_df_lens_list}, {idx_neg_test_df_lens_list}")
idx_pos_train_df = pd.concat(idx_pos_train_df_list, ignore_index=True)
idx_pos_train_df['label'] = 1
idx_neg_train_df = pd.concat(idx_neg_train_df_list, ignore_index=True)
idx_neg_train_df['label'] = 0
idx_train_df = pd.concat([idx_pos_train_df, idx_neg_train_df], ignore_index=True)
idx_pos_val_df = pd.concat(idx_pos_val_df_list, ignore_index=True)
idx_pos_val_df['label'] = 1
idx_neg_val_df = pd.concat(idx_neg_val_df_list, ignore_index=True)
idx_neg_val_df['label'] = 0
idx_val_df = pd.concat([idx_pos_val_df, idx_neg_val_df], ignore_index=True)
idx_pos_test_df = pd.concat(idx_pos_test_df_list, ignore_index=True)
idx_pos_test_df['label'] = 1
idx_neg_test_df = pd.concat(idx_neg_test_df_list, ignore_index=True)
idx_neg_test_df['label'] = 0
idx_test_df = pd.concat([idx_pos_test_df, idx_neg_test_df], ignore_index=True)
idx_train_df_file = f"{train_csv_dir}/{idx_task_str}_train.csv"
idx_val_df_file = f"{train_csv_dir}/{idx_task_str}_val.csv"
idx_test_df_file = f"{train_csv_dir}/{idx_task_str}_test.csv"
print(f"idx_task_str: {idx_task_str}\ntrain_df: {len(idx_train_df)}\nval_df: {len(idx_val_df)}\ntest_df: {len(idx_test_df)}\n")
assert idx_train_df['label'].isnull().sum() == 0
if is_save_csv:
idx_train_df.to_csv(idx_train_df_file, index=False, encoding="utf-8")
idx_val_df.to_csv(idx_val_df_file, index=False, encoding="utf-8")
idx_test_df.to_csv(idx_test_df_file, index=False, encoding="utf-8")
logger.info(f"task_info: {idx_task_str}\ntrain_df_file: {idx_train_df_file}\nval_df_file: {idx_val_df_file}\ntest_df_file: {idx_test_df_file}\n")
return
csv_data_dir = "/df_lung/cls_train_data/csv_data"
npy_data_dir = "/df_lung/cls_train_data/npy_data"
def get_train_data_info_csv(node_time_list=[]):
for node_time in node_time_list:
csv_file = get_node_time_all_label_ids_df(node_time=node_time, csv_data_dir=csv_data_dir)
logger.info(f"{node_time}: {csv_file}\n")
def process_npy(args):
generate_npy_data_by_all_label_id_df(
csv_file=args[0],
npy_data_3d_file=args[1],
npy_data_2d_file=args[2],
dicom_folder=args[3],
generate_3d_npy_data_flag=args[4],
generate_2d_npy_data_flag=args[5],
crop_size_3d=args[6],
crop_size_2d=args[7],
rotate_count=args[8],
expand=args[9],
regular_class_3d=args[10],
regular_class_2d=args[11],
save_path=args[12]
)
logger.info(f"process_npy finished: csv_file: {args[0]}, crop_size_3d: {args[6]}, crop_size_2d: {args[7]}")
def get_npy_data(node_csv_file_list=[], crop_size_list=[]):
# every csv file name must contain its node_time prefix, e.g. "2046_..." for node 2046
if any(f"{idx_node_time}_" not in idx_csv_file for idx_node_time, idx_csv_file in node_csv_file_list):
raise ValueError(f"node_csv_file_list: {node_csv_file_list}")
process_args_list = []
for node_time, csv_file in node_csv_file_list:
for idx_crop_size_dict in crop_size_list:
crop_size_3d = idx_crop_size_dict["crop_size_3d"]
crop_size_2d = idx_crop_size_dict["crop_size_2d"]
generate_3d_npy_data_flag = idx_crop_size_dict["generate_3d_npy_data_flag"]
generate_2d_npy_data_flag = idx_crop_size_dict["generate_2d_npy_data_flag"]
regular_class_3d = idx_crop_size_dict["regular_class_3d"]
regular_class_2d = idx_crop_size_dict["regular_class_2d"]
save_path = f"{npy_data_dir}/{node_time}"
Path(save_path).mkdir(parents=True, exist_ok=True)
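# Note: csv_file is an absolute path, so the os.path.join calls below return the (already absolute)
# index-csv path unchanged and save_path is effectively ignored for these csv files; the generated
# index csv therefore lands next to the source csv under csv_data_dir/{node_time}/, which is where
# generate_train_npy_csv_file reads it back from.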
idx_npy_data_3d_file = None
if generate_3d_npy_data_flag:
idx_npy_data_3d_file = f"{csv_file.replace('.csv', f'_{crop_size_3d[0]}_{crop_size_3d[1]}_{crop_size_3d[2]}_npy_data_3d')}.csv"
idx_npy_data_3d_file = os.path.join(save_path, idx_npy_data_3d_file)
idx_npy_data_2d_file = None
if generate_2d_npy_data_flag:
idx_npy_data_2d_file = f"{csv_file.replace('.csv', f'_{crop_size_2d[0]}_{crop_size_2d[1]}_npy_data_2d')}.csv"
idx_npy_data_2d_file = os.path.join(save_path, idx_npy_data_2d_file)
process_args_list.append(
[
csv_file,
idx_npy_data_3d_file,
idx_npy_data_2d_file,
"/opt/lung/ai",
generate_3d_npy_data_flag,
generate_2d_npy_data_flag,
crop_size_3d,
crop_size_2d,
10,
40,
regular_class_3d,
regular_class_2d,
save_path
]
)
# generate_npy_data_by_all_label_id_df(
# csv_file=csv_file,
# npy_data_3d_file=idx_npy_data_3d_file,
# npy_data_2d_file=idx_npy_data_2d_file,
# dicom_folder="/opt/lung/ai",
# generate_3d_npy_data_flag=generate_3d_npy_data_flag,
# generate_2d_npy_data_flag=generate_2d_npy_data_flag,
# crop_size_3d=crop_size_3d,
# crop_size_2d=crop_size_2d,
# rotate_count=10,
# expand=40,
# regular_class_3d=regular_class_3d,
# regular_class_2d=regular_class_2d,
# save_path=save_path
# )
# run the npy generation in parallel, one process per argument set
process_count = len(process_args_list)
process_list = []
for idx in range(process_count):
process_args = process_args_list[idx]
idx_process = Process(target=process_npy, args=(process_args,))
idx_process.start()
process_list.append(idx_process)
for idx_process in process_list:
idx_process.join()
return
# # generate the label-id csv data
# node_time_list = [2046, 2047, 2048, 2060, 2061, 2062, 3001, 4001, 5001, 6001, 1016]
# get_train_data_info_csv(node_time_list=node_time_list)
'''
2021: /df_lung/cls_train_data/csv_data/2021/2021_20241204_094025_rotate_10.csv
2031: /df_lung/cls_train_data/csv_data/2031/2031_20241204_094025_rotate_10.csv
2041: /df_lung/cls_train_data/csv_data/2041/2041_20241204_094026_rotate_10.csv
1010: /df_lung/cls_train_data/csv_data/1010/1010_20241204_093726_rotate_10.csv
1020: /df_lung/cls_train_data/csv_data/1020/1020_20241204_093726_rotate_10.csv
2011: /df_lung/cls_train_data/csv_data/2011/2011_20241204_093726_rotate_10.csv
2046: /df_lung/cls_train_data/csv_data/2046/2046_20241211_155642_rotate_10.csv
2047: /df_lung/cls_train_data/csv_data/2047/2047_20241211_155642_rotate_10.csv
2048: /df_lung/cls_train_data/csv_data/2048/2048_20241211_155642_rotate_10.csv
2060: /df_lung/cls_train_data/csv_data/2060/2060_20241211_155643_rotate_10.csv
2061: /df_lung/cls_train_data/csv_data/2061/2061_20241211_155643_rotate_10.csv
2062: /df_lung/cls_train_data/csv_data/2062/2062_20241211_155643_rotate_10.csv
3001: /df_lung/cls_train_data/csv_data/3001/3001_20241211_155643_rotate_10.csv
4001: /df_lung/cls_train_data/csv_data/4001/4001_20241211_155643_rotate_10.csv
5001: /df_lung/cls_train_data/csv_data/5001/5001_20241211_155643_rotate_10.csv
6001: /df_lung/cls_train_data/csv_data/6001/6001_20241211_155643_rotate_10.csv
1016: /df_lung/cls_train_data/csv_data/1016/1016_20241211_155643_rotate_10.csv
'''
# # generate the npy data
# node_csv_file_list = [
# (2046, "/df_lung/cls_train_data/csv_data/2046/2046_20241211_155642_rotate_10.csv"),
# (2047, "/df_lung/cls_train_data/csv_data/2047/2047_20241211_155642_rotate_10.csv"),
# (2048, "/df_lung/cls_train_data/csv_data/2048/2048_20241211_155642_rotate_10.csv"),
# (2060, "/df_lung/cls_train_data/csv_data/2060/2060_20241211_155643_rotate_10.csv"),
# (2061, "/df_lung/cls_train_data/csv_data/2061/2061_20241211_155643_rotate_10.csv"),
# (2062, "/df_lung/cls_train_data/csv_data/2062/2062_20241211_155643_rotate_10.csv"),
# (3001, "/df_lung/cls_train_data/csv_data/3001/3001_20241211_155643_rotate_10.csv"),
# (4001, "/df_lung/cls_train_data/csv_data/4001/4001_20241211_155643_rotate_10.csv"),
# (5001, "/df_lung/cls_train_data/csv_data/5001/5001_20241211_155643_rotate_10.csv"),
# (6001, "/df_lung/cls_train_data/csv_data/6001/6001_20241211_155643_rotate_10.csv"),
# (1016, "/df_lung/cls_train_data/csv_data/1016/1016_20241211_155643_rotate_10.csv")
# ]
# crop_size_list = [
# {
# "crop_size_3d": [48, 256, 256],
# "crop_size_2d": [256, 256],
# "generate_3d_npy_data_flag": True,
# "generate_2d_npy_data_flag": True,
# "regular_class_3d": normalize_net_3d,
# "regular_class_2d": normalize_net_2d
# },
# {
# "crop_size_3d": [128, 128, 128],
# "crop_size_2d": [128, 128],
# "generate_3d_npy_data_flag": True,
# "generate_2d_npy_data_flag": False,
# "regular_class_3d": s3d_normalize_3d,
# "regular_class_2d": None
# },
# {
# "crop_size_3d": [48, 280, 280],
# "crop_size_2d": [280, 280],
# "generate_3d_npy_data_flag": False,
# "generate_2d_npy_data_flag": True,
# "regular_class_3d": None,
# "regular_class_2d": d2d_normalize
# }
# ]
# get_npy_data(node_csv_file_list=node_csv_file_list, crop_size_list=crop_size_list)
'''
'''
# generate the training csv data
net_id_list = ["3d", "2d", "2d3d", "s3d", "d2d"]
node_npy_pos_neg_list = [
{
"pos": [2031],
"neg": [2021]
},
{
"pos": [2031],
"neg": [2041]
},
{
"pos": [2031],
"neg": [1010,1020,2011,2021,2041]
},
{
"pos": [2041],
"neg": [1010,1020,2011,2021,2031]
},
{
"pos": [2031],
"neg": [1020, 2011, 2021]
},
{
"pos": [2041],
"neg": [2046, 2047, 2048, 2060, 2061, 2062, 3001, 4001, 5001, 6001, 2021, 2031]
},
{
"pos": [1010, 1020, 1016],
"neg": [2011, 2021, 2031, 2041]
},
{
"pos": [2041],
"neg": [2011, 2021, 2031]
},
]
net_id_crop_size_dict = {
"3d": [48, 256, 256],
"2d": [256, 256],
"2d3d": [48, 256, 256],
"s3d": [128, 128, 128],
"d2d": [280, 280]
}
node_net_id_npy_file_dict = {
(2021, "3d"): ["2021_20241204_094025_rotate_10_48_256_256_npy_data_3d.csv"],
(2021, "2d"): ["2021_20241204_094025_rotate_10_256_256_npy_data_2d.csv"],
(2021, "2d3d"): {
"2d": ["2021_20241204_094025_rotate_10_256_256_npy_data_2d.csv"],
"3d": ["2021_20241204_094025_rotate_10_48_256_256_npy_data_3d.csv"]
},
(2021, "s3d"): ["2021_20241204_094025_rotate_10_128_128_128_npy_data_3d.csv"],
(2021, "d2d"): ["2021_20241204_094025_rotate_10_280_280_npy_data_2d.csv"],
(2031, "3d"): ["2031_20241204_094025_rotate_10_48_256_256_npy_data_3d.csv"],
(2031, "2d"): ["2031_20241204_094025_rotate_10_256_256_npy_data_2d.csv"],
(2031, "2d3d"): {
"2d": ["2031_20241204_094025_rotate_10_256_256_npy_data_2d.csv"],
"3d": ["2031_20241204_094025_rotate_10_48_256_256_npy_data_3d.csv"]
},
(2031, "s3d"): ["2031_20241204_094025_rotate_10_128_128_128_npy_data_3d.csv"],
(2031, "d2d"): ["2031_20241204_094025_rotate_10_280_280_npy_data_2d.csv"],
(2041, "3d"): ["2041_20241204_094026_rotate_10_48_256_256_npy_data_3d.csv"],
(2041, "2d"): ["2041_20241204_094026_rotate_10_256_256_npy_data_2d.csv"],
(2041, "2d3d"): {
"2d": ["2041_20241204_094026_rotate_10_256_256_npy_data_2d.csv"],
"3d": ["2041_20241204_094026_rotate_10_48_256_256_npy_data_3d.csv"]
},
(2041, "s3d"): ["2041_20241204_094026_rotate_10_128_128_128_npy_data_3d.csv"],
(2041, "d2d"): ["2041_20241204_094026_rotate_10_280_280_npy_data_2d.csv"],
(1010, "3d"): ["1010_20241204_093726_rotate_10_48_256_256_npy_data_3d.csv"],
(1010, "2d"): ["1010_20241204_093726_rotate_10_256_256_npy_data_2d.csv"],
(1010, "2d3d"): {
"2d": ["1010_20241204_093726_rotate_10_256_256_npy_data_2d.csv"],
"3d": ["1010_20241204_093726_rotate_10_48_256_256_npy_data_3d.csv"]
},
(1010, "s3d"): ["1010_20241204_093726_rotate_10_128_128_128_npy_data_3d.csv"],
(1010, "d2d"): ["1010_20241204_093726_rotate_10_280_280_npy_data_2d.csv"],
(1020, "3d"): ["1020_20241204_093726_rotate_10_48_256_256_npy_data_3d.csv"],
(1020, "2d"): ["1020_20241204_093726_rotate_10_256_256_npy_data_2d.csv"],
(1020, "2d3d"): {
"2d": ["1020_20241204_093726_rotate_10_256_256_npy_data_2d.csv"],
"3d": ["1020_20241204_093726_rotate_10_48_256_256_npy_data_3d.csv"]
},
(1020, "s3d"): ["1020_20241204_093726_rotate_10_128_128_128_npy_data_3d.csv"],
(1020, "d2d"): ["1020_20241204_093726_rotate_10_280_280_npy_data_2d.csv"],
(2011, "3d"): ["2011_20241204_093726_rotate_10_48_256_256_npy_data_3d.csv"],
(2011, "2d"): ["2011_20241204_093726_rotate_10_256_256_npy_data_2d.csv"],
(2011, "2d3d"): {
"2d": ["2011_20241204_093726_rotate_10_256_256_npy_data_2d.csv"],
"3d": ["2011_20241204_093726_rotate_10_48_256_256_npy_data_3d.csv"]
},
(2011, "s3d"): ["2011_20241204_093726_rotate_10_128_128_128_npy_data_3d.csv"],
(2011, "d2d"): ["2011_20241204_093726_rotate_10_280_280_npy_data_2d.csv"],
(2046, "3d"): ["2046_20241211_155642_rotate_10_48_256_256_npy_data_3d.csv"],
(2046, "2d"): ["2046_20241211_155642_rotate_10_256_256_npy_data_2d.csv"],
(2046, "2d3d"): {
"2d": ["2046_20241211_155642_rotate_10_256_256_npy_data_2d.csv"],
"3d": ["2046_20241211_155642_rotate_10_48_256_256_npy_data_3d.csv"]
},
(2046, "s3d"): ["2046_20241211_155642_rotate_10_128_128_128_npy_data_3d.csv"],
(2046, "d2d"): ["2046_20241211_155642_rotate_10_280_280_npy_data_2d.csv"],
(2047, "3d"): ["2047_20241211_155642_rotate_10_48_256_256_npy_data_3d.csv"],
(2047, "2d"): ["2047_20241211_155642_rotate_10_256_256_npy_data_2d.csv"],
(2047, "2d3d"): {
"2d": ["2047_20241211_155642_rotate_10_256_256_npy_data_2d.csv"],
"3d": ["2047_20241211_155642_rotate_10_48_256_256_npy_data_3d.csv"]
},
(2047, "s3d"): ["2047_20241211_155642_rotate_10_128_128_128_npy_data_3d.csv"],
(2047, "d2d"): ["2047_20241211_155642_rotate_10_280_280_npy_data_2d.csv"],
(2048, "3d"): ["2048_20241211_155642_rotate_10_48_256_256_npy_data_3d.csv"],
(2048, "2d"): ["2048_20241211_155642_rotate_10_256_256_npy_data_2d.csv"],
(2048, "2d3d"): {
"2d": ["2048_20241211_155642_rotate_10_256_256_npy_data_2d.csv"],
"3d": ["2048_20241211_155642_rotate_10_48_256_256_npy_data_3d.csv"]
},
(2048, "s3d"): ["2048_20241211_155642_rotate_10_128_128_128_npy_data_3d.csv"],
(2048, "d2d"): ["2048_20241211_155642_rotate_10_280_280_npy_data_2d.csv"],
(2060, "3d"): ["2060_20241211_155643_rotate_10_48_256_256_npy_data_3d.csv"],
(2060, "2d"): ["2060_20241211_155643_rotate_10_256_256_npy_data_2d.csv"],
(2060, "2d3d"): {
"2d": ["2060_20241211_155643_rotate_10_256_256_npy_data_2d.csv"],
"3d": ["2060_20241211_155643_rotate_10_48_256_256_npy_data_3d.csv"]
},
(2060, "s3d"): ["2060_20241211_155643_rotate_10_128_128_128_npy_data_3d.csv"],
(2060, "d2d"): ["2060_20241211_155643_rotate_10_280_280_npy_data_2d.csv"],
(2061, "3d"): ["2061_20241211_155643_rotate_10_48_256_256_npy_data_3d.csv"],
(2061, "2d"): ["2061_20241211_155643_rotate_10_256_256_npy_data_2d.csv"],
(2061, "2d3d"): {
"2d": ["2061_20241211_155643_rotate_10_256_256_npy_data_2d.csv"],
"3d": ["2061_20241211_155643_rotate_10_48_256_256_npy_data_3d.csv"]
},
(2061, "s3d"): ["2061_20241211_155643_rotate_10_128_128_128_npy_data_3d.csv"],
(2061, "d2d"): ["2061_20241211_155643_rotate_10_280_280_npy_data_2d.csv"],
(2062, "3d"): ["2062_20241211_155643_rotate_10_48_256_256_npy_data_3d.csv"],
(2062, "2d"): ["2062_20241211_155643_rotate_10_256_256_npy_data_2d.csv"],
(2062, "2d3d"): {
"2d": ["2062_20241211_155643_rotate_10_256_256_npy_data_2d.csv"],
"3d": ["2062_20241211_155643_rotate_10_48_256_256_npy_data_3d.csv"]
},
(2062, "s3d"): ["2062_20241211_155643_rotate_10_128_128_128_npy_data_3d.csv"],
(2062, "d2d"): ["2062_20241211_155643_rotate_10_280_280_npy_data_2d.csv"],
(3001, "3d"): ["3001_20241211_155643_rotate_10_48_256_256_npy_data_3d.csv"],
(3001, "2d"): ["3001_20241211_155643_rotate_10_256_256_npy_data_2d.csv"],
(3001, "2d3d"): {
"2d": ["3001_20241211_155643_rotate_10_256_256_npy_data_2d.csv"],
"3d": ["3001_20241211_155643_rotate_10_48_256_256_npy_data_3d.csv"]
},
(3001, "s3d"): ["3001_20241211_155643_rotate_10_128_128_128_npy_data_3d.csv"],
(3001, "d2d"): ["3001_20241211_155643_rotate_10_280_280_npy_data_2d.csv"],
(4001, "3d"): ["4001_20241211_155643_rotate_10_48_256_256_npy_data_3d.csv"],
(4001, "2d"): ["4001_20241211_155643_rotate_10_256_256_npy_data_2d.csv"],
(4001, "2d3d"): {
"2d": ["4001_20241211_155643_rotate_10_256_256_npy_data_2d.csv"],
"3d": ["4001_20241211_155643_rotate_10_48_256_256_npy_data_3d.csv"]
},
(4001, "s3d"): ["4001_20241211_155643_rotate_10_128_128_128_npy_data_3d.csv"],
(4001, "d2d"): ["4001_20241211_155643_rotate_10_280_280_npy_data_2d.csv"],
(5001, "3d"): ["5001_20241211_155643_rotate_10_48_256_256_npy_data_3d.csv"],
(5001, "2d"): ["5001_20241211_155643_rotate_10_256_256_npy_data_2d.csv"],
(5001, "2d3d"): {
"2d": ["5001_20241211_155643_rotate_10_256_256_npy_data_2d.csv"],
"3d": ["5001_20241211_155643_rotate_10_48_256_256_npy_data_3d.csv"]
},
(5001, "s3d"): ["5001_20241211_155643_rotate_10_128_128_128_npy_data_3d.csv"],
(5001, "d2d"): ["5001_20241211_155643_rotate_10_280_280_npy_data_2d.csv"],
(6001, "3d"): ["6001_20241211_155643_rotate_10_48_256_256_npy_data_3d.csv"],
(6001, "2d"): ["6001_20241211_155643_rotate_10_256_256_npy_data_2d.csv"],
(6001, "2d3d"): {
"2d": ["6001_20241211_155643_rotate_10_256_256_npy_data_2d.csv"],
"3d": ["6001_20241211_155643_rotate_10_48_256_256_npy_data_3d.csv"]
},
(6001, "s3d"): ["6001_20241211_155643_rotate_10_128_128_128_npy_data_3d.csv"],
(6001, "d2d"): ["6001_20241211_155643_rotate_10_280_280_npy_data_2d.csv"],
(1016, "3d"): ["1016_20241211_155643_rotate_10_48_256_256_npy_data_3d.csv"],
(1016, "2d"): ["1016_20241211_155643_rotate_10_256_256_npy_data_2d.csv"],
(1016, "2d3d"): {
"2d": ["1016_20241211_155643_rotate_10_256_256_npy_data_2d.csv"],
"3d": ["1016_20241211_155643_rotate_10_48_256_256_npy_data_3d.csv"]
},
(1016, "s3d"): ["1016_20241211_155643_rotate_10_128_128_128_npy_data_3d.csv"],
(1016, "d2d"): ["1016_20241211_155643_rotate_10_280_280_npy_data_2d.csv"],
}
csv_data_dir = "/df_lung/cls_train_data/csv_data"
train_csv_dir = "/df_lung/cls_train_data/train_csv_data"
is_pad_df = True
is_save_csv = False
seed = 100004
generate_train_npy_csv_file(
node_npy_pos_neg_list = node_npy_pos_neg_list,
net_id_list = net_id_list,
net_id_crop_size_dict = net_id_crop_size_dict,
node_net_id_npy_file_dict = node_net_id_npy_file_dict,
csv_data_dir=csv_data_dir,
train_csv_dir=train_csv_dir,
is_pad_df=is_pad_df,
is_save_csv=is_save_csv,
seed=seed
)
import logging
from datetime import datetime
import json
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
from sqlalchemy.schema import Column
from sqlalchemy.types import Integer, String, DateTime, DECIMAL
# create the ORM declarative base class
Base = declarative_base()
class AnalysisResult(Base):
# table name
__tablename__ = 'analysis_result'
# primary key
id = Column(Integer, primary_key=True)
study_id = Column(Integer)
series_id = Column(Integer)
strategy_id = Column(String)
strategy_name = Column(String)
error_msg = Column(String)
status = Column(Integer)
created_time = Column(DateTime)
updated_time = Column(DateTime)
class AnalysisResultSlice(Base):
__tablename__ = 'analysis_result_slice'
id = Column(Integer, primary_key=True)
result_id = Column(Integer)
study_id = Column(Integer)
series_id = Column(Integer)
dicom_id = Column(String)
status = Column(Integer)
meta = Column(String)
indexs = Column(Integer)
z_index = Column(String)
created_time = Column(DateTime)
updated_time = Column(DateTime)
class AlchemyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj.__class__, DeclarativeMeta):
# an SQLAlchemy class
fields = {}
for field in [x for x in dir(obj) if not x.startswith('_') and x != 'metadata']:
data = obj.__getattribute__(field)
try:
if isinstance(data, datetime):
data = data.strftime('%Y-%m-%d %H:%M:%S')
json.dumps(data) # this will fail on non-encodable values, like other classes
fields[field] = data
except TypeError:
fields[field] = None
# a json-encodable dict
return fields
return json.JSONEncoder.default(self, obj)
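# Usage sketch (hypothetical `row` ORM instance): json.dumps(row, cls=AlchemyEncoder) serializes the
# mapped columns to JSON; attributes that cannot be encoded are emitted as null.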
def new_alchemy_encoder():
_visited_objs = []
class AlchemyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj.__class__, DeclarativeMeta):
# don't re-visit self
if obj in _visited_objs:
return None
_visited_objs.append(obj)
# an SQLAlchemy class
fields = {}
for field in [x for x in dir(obj) if not x.startswith('_') and x != 'metadata']:
data = obj.__getattribute__(field)
try:
if isinstance(data, datetime):
data = data.strftime('%Y-%m-%d %H:%M:%S')
json.dumps(data) # this will fail on non-encodable values, like other classes
fields[field] = data
except Exception:
fields[field] = None
return fields
return json.JSONEncoder.default(self, obj)
return AlchemyEncoder
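# new_alchemy_encoder() returns an encoder *class* whose closure keeps a shared visited-object list,
# so mutually referencing ORM objects are serialized once and then emitted as null instead of
# recursing forever. Usage sketch (hypothetical `row`): json.dumps(row, cls=new_alchemy_encoder())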
class DicomSeries(Base):
__tablename__ = 'dicom_file_series'
id = Column(Integer, primary_key=True)
patient_info_id = Column(Integer)
study_id = Column(Integer)
patient_id = Column(String)
patient_name = Column(String)
patient_sex = Column(String)
patient_age = Column(Integer)
series_instance_uid = Column(String)
study_uid = Column(String)
series_number = Column(String)
accession_number = Column(String)
type = Column(String)
status = Column(Integer)
error_msg = Column(String)
dicom_count = Column(Integer)
created_time = Column(DateTime)
updated_time = Column(DateTime)
acq_date = Column(DateTime)
spacing_between_slices = Column(String)
pixel_spacing_x = Column(String)
pixel_spacing_y = Column(String)
rows = Column(Integer)
columns = Column(Integer)
folder_name = Column(String)
class DicomFile(Base):
__tablename__ = 'dicom_file'
id = Column(Integer, primary_key=True)
name = Column(String)
series_id = Column(Integer)
patient_id = Column(String)
patient_name = Column(String)
patient_sex = Column(String)
patient_birthday = Column(DateTime)
patient_age = Column(Integer)
manufacturer = Column(String)
study_uid = Column(String)
study_date = Column(DateTime)
acq_date = Column(DateTime)
study_id = Column(String)
study_desc = Column(String)
series_uid = Column(String)
series_date = Column(DateTime)
series_number = Column(String)
series_modality = Column(String)
series_institution = Column(String)
series_desc = Column(String)
series_exam_part = Column(String)
sop_uid = Column(String)
transfer_syntax = Column(String)
instance_number = Column(Integer)
rows = Column(Integer)
columns = Column(Integer)
slice_location = Column(String)
slice_thickness = Column(String)
kvp = Column(String)
xray_tube_current = Column(String)
pixel_spacing = Column(String)
spacing_between_slices = Column(String)
image_position = Column(String)
image_orientation = Column(String)
patient_position = Column(String)
accession_number = Column(String)
full_str = Column(String)
byte_url = Column(String)
image_url = Column(String)
status = Column(Integer)
created_time = Column(DateTime)
updated_time = Column(DateTime)
class UploadFile(Base):
__tablename__ = 'upload_file'
id = Column(Integer, primary_key=True)
file_name = Column(String)
file_original_name = Column(String)
user_id = Column(Integer)
user_name = Column(String)
file_url = Column(String)
file_path = Column(String)
batch_uuid = Column(String)
file_md5 = Column(String)
parrent_id = Column(Integer)
is_parrent = Column(Integer)
dicom_bundle_id = Column(String)
patient_id = Column(String)
patient_name = Column(String)
status = Column(Integer)
error_msg = Column(String)
acq_date = Column(DateTime)
created_time = Column(DateTime)
updated_time = Column(DateTime)
class DicomStudy(Base):
__tablename__ = 'dicom_file_study'
id = Column(Integer, primary_key=True)
patient_info_id = Column(Integer)
patient_sex = Column(String)
patient_age = Column(Integer)
file_size = Column(DECIMAL)
dicom_count = Column(Integer)
study_uid = Column(String)
folder_name = Column(String)
accession_number = Column(String)
status = Column(Integer)
created_time = Column(DateTime)
updated_time = Column(DateTime)
acq_date = Column(DateTime)
error_msg = Column(String)
class PatientInfo(Base):
__tablename__ = 'patient_info'
id = Column(Integer, primary_key=True)
patient_sex = Column(String)
patient_age = Column(Integer)
patient_id = Column(String)
patient_name = Column(String)
series_institution = Column(String)
status = Column(Integer)
created_time = Column(DateTime)
updated_time = Column(DateTime)
class AnalysisStrategy(Base):
__tablename__ = 'analysis_strategy'
id = Column(Integer, primary_key=True)
name = Column(String)
description = Column(String)
conf_json = Column(String)
status = Column(Integer)
created_time = Column(DateTime)
updated_time = Column(DateTime)
class UserLabel(Base):
__tablename__ = 'user_label'
id = Column(Integer, primary_key=True)
user_id = Column(Integer)
study_id = Column(Integer)
series_id = Column(Integer)
result_id = Column(Integer)
label_name = Column(String)
color = Column(String)
box_count = Column(Integer)
status = Column(Integer)
node_time = Column(Integer)
created_time = Column(DateTime)
updated_time = Column(DateTime)
deleted_time = Column(DateTime)
box_info = Column(String)
area = Column(String)
class UserLabelDelineation(Base):
__tablename__ = 'user_label_delineation'
id = Column(Integer, primary_key=True)
user_id = Column(Integer)
label_id = Column(Integer)
study_id = Column(Integer)
series_id = Column(Integer)
dicom_id = Column(Integer)
status = Column(Integer)
meta = Column(String)
indexs = Column(String)
z_index = Column(Integer)
created_time = Column(DateTime)
updated_time = Column(DateTime)
contour = Column(String)
\ No newline at end of file
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def init_modules(modules):
for m in modules:
if isinstance(m, nn.Conv3d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm3d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.constant_(m.bias, 0)
def init_modules_2d(modules):
for m in modules:
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.constant_(m.bias, 0)
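# Usage sketch (hypothetical `model`): call once after building a network, e.g.
# init_modules(model.modules()) for the 3d nets or init_modules_2d(model.modules()) for the 2d nets;
# note that only conv weights and norm/linear biases are initialized here, linear weights keep
# their default initialization.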
def extend_mask(mask, kernel_size=(1, 7, 7), padding=(0, 3, 3), min_value=0.01):
mask = F.avg_pool3d(mask.float(), kernel_size=kernel_size, padding=padding, stride=1)
# Compare the average-pooled result with min_value: positions >= min_value become True (1), otherwise False (0)
mask = mask.ge(min_value)
return mask
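# extend_mask sketch (assuming an N x C x D x H x W mask): avg_pool3d with kernel (1, 7, 7), stride 1
# and padding (0, 3, 3) keeps the spatial size, so mask.ge(min_value) marks every voxel whose 7x7
# in-plane neighbourhood contains at least one positive voxel, i.e. an in-plane dilation of the mask.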
def std_mean(data, kernel_size=(9, 9)):
result = torch.empty((data.size(0), data.size(1), data.size(2), data.size(3), data.size(4),
kernel_size[0] * kernel_size[1]), device=data.device)
pdata = F.pad(data, [kernel_size[1] // 2, kernel_size[1] // 2, kernel_size[0] // 2, kernel_size[0] // 2, 0, 0],
mode='constant', value=0)
for w in range(0, kernel_size[0], 1):
for h in range(0, kernel_size[1], 1):
# flat window index: w * kernel_size[1] + h (the original kernel_size[0] * w + h is only equivalent for square kernels)
result[:, :, :, :, :, w * kernel_size[1] + h] = pdata[:, :, :, w:w + data.size(3), h:h + data.size(4)]
std_mean_result = torch.std_mean(result, dim=5)
return std_mean_result
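# std_mean sketch: torch.std_mean over the last dim returns a (std, mean) pair, so index [0] for the
# per-voxel local standard deviation and [1] for the local mean of each kernel_size[0] x kernel_size[1]
# in-plane window (get_data below concatenates the mean ([1]) and std ([0]) for several kernel sizes).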
def get_data(x, mask=None):
x = torch.tensor(x)
std_mean_result_k5 = std_mean(x, kernel_size=(5, 5))
std_mean_result_k7 = std_mean(x, kernel_size=(7, 7))
std_mean_result_k9 = std_mean(x, kernel_size=(9, 9))
y = x
data = torch.cat((x, y,
std_mean_result_k5[1], std_mean_result_k5[0],
std_mean_result_k7[1], std_mean_result_k7[0],
std_mean_result_k9[1], std_mean_result_k9[0]), 1)
# squeeze out the batch dimension
data = torch.squeeze(data, dim=0)
return data
\ No newline at end of file
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def create_conv_block(in_planes, out_planes, kernel_size=1, stride=1, padding=0,
dilation=1, groups=1, bias=False, bn=True, activation=True):
layers = []
layers.append(nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias))
if bn:
layers.append(nn.InstanceNorm3d(out_planes))
if activation:
layers.append(nn.LeakyReLU(inplace=True))
return nn.Sequential(*layers)
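# Usage sketch: create_conv_block(1, 16, kernel_size=3, padding=1) builds
# Conv3d -> InstanceNorm3d -> LeakyReLU; note that the `bn` flag actually adds InstanceNorm3d,
# not BatchNorm3d.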
def create_conv_block_k1(in_planes, out_planes, kernel_size=1, stride=1, padding=0,
dilation=1, groups=1, bias=False, bn=True, activation=True):
return create_conv_block(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, bn=bn, activation=activation)
def create_conv_block_k2(in_planes, out_planes, kernel_size=2, stride=1, padding=0,
dilation=1, groups=1, bias=False, bn=True, activation=True):
return create_conv_block(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, bn=bn, activation=activation)
def create_conv_block_k3(in_planes, out_planes, kernel_size=3, stride=1, padding=1,
dilation=1, groups=1, bias=False, bn=True, activation=True):
return create_conv_block(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, bn=bn, activation=activation)
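# Receptive-field feature block: eight parallel two-conv branches with growing kernel sizes and
# dilations; their outputs are concatenated channel-wise (interleaved per group when groups > 1).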
class RfbfBlock3d(nn.Module):
def __init__(self, in_planes, out_planes, stride=1, groups=1, droprate=0):
super(RfbfBlock3d, self).__init__()
inter_planes = max(int(np.ceil(out_planes / 8)), 1)
self.groups = groups
self.group_num = inter_planes // groups
self.droprate = droprate
self.branch1 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 2, 2),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 2, 2),
dilation=(1, 1, 1), groups=groups)
)
self.branch2 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 2, 2),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 4, 4),
dilation=(1, 2, 2), groups=groups)
)
self.branch3 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 2, 2),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 6, 6),
dilation=(1, 3, 3), groups=groups)
)
self.branch4 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(3, 7, 7), padding=(1, 3, 3),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 2, 2),
dilation=(1, 1, 1), groups=groups)
)
self.branch5 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(3, 7, 7), padding=(1, 3, 3),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 4, 4),
dilation=(1, 2, 2), groups=groups)
)
self.branch6 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(3, 7, 7), padding=(1, 3, 3),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 6, 6),
dilation=(1, 3, 3), groups=groups)
)
self.branch7 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(5, 9, 9), padding=(2, 4, 4),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 2, 2),
dilation=(1, 1, 1), groups=groups)
)
self.branch8 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(5, 9, 9), padding=(2, 4, 4),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 4, 4),
dilation=(1, 2, 2), groups=groups)
)
def forward(self, x):
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
x4 = self.branch4(x)
x5 = self.branch5(x)
x6 = self.branch6(x)
x7 = self.branch7(x)
x8 = self.branch8(x)
if self.groups == 1:
out = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), 1)
else:
for group in range(self.groups):
group_out = torch.cat((x1[:, group * self.group_num:(group + 1) * self.group_num],
x2[:, group * self.group_num:(group + 1) * self.group_num],
x3[:, group * self.group_num:(group + 1) * self.group_num],
x4[:, group * self.group_num:(group + 1) * self.group_num],
x5[:, group * self.group_num:(group + 1) * self.group_num],
x6[:, group * self.group_num:(group + 1) * self.group_num],
x7[:, group * self.group_num:(group + 1) * self.group_num],
x8[:, group * self.group_num:(group + 1) * self.group_num]), 1)
if group == 0:
out = group_out
else:
out = torch.cat((out, group_out), 1)
if self.droprate > 0:
out = F.dropout3d(out, p=self.droprate, training=self.training)
return out
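# RFB-style residual block: four dilated branches, a 1x1x1 fusion conv and a 1x1x1 projection shortcut,
# combined by addition and a final LeakyReLU.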
class RfbeBlock3d(nn.Module):
def __init__(self, in_planes, out_planes, stride=1, groups=1):
super(RfbeBlock3d, self).__init__()
inter_planes = max(int(np.ceil(out_planes / 8)), 2)
self.groups = groups
self.group_num = 2 * inter_planes // groups
self.branch1 = nn.Sequential(
create_conv_block(in_planes, 2 * inter_planes, kernel_size=3, stride=stride, padding=1, groups=groups),
create_conv_block_k3(2 * inter_planes, 2 * inter_planes, padding=1, dilation=1, groups=groups)
)
self.branch2 = nn.Sequential(
create_conv_block(in_planes, 2 * inter_planes, kernel_size=3, stride=stride, padding=1, groups=groups),
create_conv_block_k3(2 * inter_planes, 2 * inter_planes, padding=2, dilation=2, groups=groups)
)
self.branch3 = nn.Sequential(
create_conv_block(in_planes, 2 * inter_planes, kernel_size=5, stride=stride, padding=2, groups=groups),
create_conv_block_k3(2 * inter_planes, 2 * inter_planes, padding=3, dilation=3, groups=groups)
)
self.branch4 = nn.Sequential(
create_conv_block(in_planes, 2 * inter_planes, kernel_size=7, stride=stride, padding=3, groups=groups),
create_conv_block_k3(2 * inter_planes, 2 * inter_planes, padding=4, dilation=4, groups=groups)
)
self.concat_conv = create_conv_block_k1(8 * inter_planes, out_planes, groups=groups, activation=False)
self.shortcut = create_conv_block_k1(in_planes, out_planes, stride=stride, groups=groups, activation=False)
def forward(self, x):
identity = self.shortcut(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
x4 = self.branch4(x)
if self.groups == 1:
out = torch.cat((x1, x2, x3, x4), 1)
else:
for group in range(self.groups):
group_out = torch.cat((x1[:, group * self.group_num:(group + 1) * self.group_num],
x2[:, group * self.group_num:(group + 1) * self.group_num],
x3[:, group * self.group_num:(group + 1) * self.group_num],
x4[:, group * self.group_num:(group + 1) * self.group_num]), 1)
if group == 0:
out = group_out
else:
out = torch.cat((out, group_out), 1)
out = self.concat_conv(out)
out = out + identity
out = F.leaky_relu(out, inplace=True)
return out
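# Lighter RFB residual block: three branches whose stems use 1x1x1, 3x3x3 and 5x5x5 convs, each followed
# by a dilated 3x3x3 conv; outputs are fused by a 1x1x1 conv and added to a projected shortcut.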
class RfbBlock3d(nn.Module):
def __init__(self, in_planes, out_planes, stride=1, groups=1):
super(RfbBlock3d, self).__init__()
inter_planes = max(int(np.ceil(out_planes / 8)), 2)
self.groups = groups
self.group_num = 2 * inter_planes // groups
self.branch1 = nn.Sequential(
create_conv_block_k1(in_planes, 2 * inter_planes, stride=stride, groups=groups),
create_conv_block_k3(2 * inter_planes, 2 * inter_planes, padding=1, dilation=1, groups=groups)
)
self.branch2 = nn.Sequential(
create_conv_block_k1(in_planes, inter_planes, groups=groups),
create_conv_block(inter_planes, 2 * inter_planes, kernel_size=3, stride=stride, padding=1, groups=groups),
create_conv_block_k3(2 * inter_planes, 2 * inter_planes, padding=2, dilation=2, groups=groups)
)
self.branch3 = nn.Sequential(
create_conv_block_k1(in_planes, inter_planes, groups=groups),
create_conv_block(inter_planes, 2 * inter_planes, kernel_size=5, stride=stride, padding=2, groups=groups),
create_conv_block_k3(2 * inter_planes, 2 * inter_planes, padding=3, dilation=3, groups=groups)
)
self.concat_conv = create_conv_block_k1(6 * inter_planes, out_planes, groups=groups, activation=False)
self.shortcut = create_conv_block_k1(in_planes, out_planes, stride=stride, groups=groups, activation=False)
def forward(self, x):
identity = self.shortcut(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
if self.groups == 1:
out = torch.cat((x1, x2, x3), 1)
else:
for group in range(self.groups):
group_out = torch.cat((x1[:, group * self.group_num:(group + 1) * self.group_num],
x2[:, group * self.group_num:(group + 1) * self.group_num],
x3[:, group * self.group_num:(group + 1) * self.group_num]), 1)
if group == 0:
out = group_out
else:
out = torch.cat((out, group_out), 1)
out = self.concat_conv(out)
out = out + identity
out = F.leaky_relu(out, inplace=True)
return out
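# Basic residual block (two 3x3x3 convs) with optional 3D dropout, optional squeeze-and-excitation
# gating, and a 1x1x1 projection shortcut whenever the stride or channel count changes.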
class ResBasicBlock3d(nn.Module):
def __init__(self, in_planes, out_planes, stride=1, groups=1, droprate=0, se=False):
super(ResBasicBlock3d, self).__init__()
self.conv1 = create_conv_block_k3(in_planes, out_planes, stride=stride, groups=groups)
self.conv2 = create_conv_block_k3(out_planes, out_planes, groups=groups, activation=False)
self.shortcut = None
if stride != 1 or in_planes != out_planes:
self.shortcut = create_conv_block_k1(in_planes, out_planes, stride=stride, groups=groups, activation=False)
self.droprate = droprate
self.se = se
if se:
self.fc1 = nn.Linear(in_features=out_planes, out_features=out_planes // 4)
self.fc2 = nn.Linear(in_features=out_planes // 4, out_features=out_planes)
def forward(self, x):
identity = self.shortcut(x) if self.shortcut is not None else x
out = self.conv1(x)
if self.droprate > 0:
out = F.dropout3d(out, p=self.droprate, training=self.training)
out = self.conv2(out)
if self.se:
original_out = out
out = F.adaptive_avg_pool3d(out, (1, 1, 1))
out = torch.flatten(out, 1)
out = self.fc1(out)
out = F.leaky_relu(out, inplace=True)
out = self.fc2(out)
out = out.sigmoid()
out = out.view(out.size(0), out.size(1), 1, 1, 1)
out = out * original_out
out = out + identity
out = F.leaky_relu(out, inplace=True)
return out
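# Decoder fusion block: upsample the coarse feature map by scale_factor (nearest neighbor), project both
# inputs with 3x3x3 convs and merge them by addition followed by LeakyReLU.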
class UpBlock3d(nn.Module):
def __init__(self, in_planes1, in_planes2, out_planes, groups=1, scale_factor=2):
super(UpBlock3d, self).__init__()
self.scale_factor = scale_factor
self.conv1 = create_conv_block_k3(in_planes1, out_planes, groups=groups, activation=False)
self.conv2 = create_conv_block_k3(in_planes2, out_planes, groups=groups, activation=False)
def forward(self, x1, x2):
if self.scale_factor != 1 and self.scale_factor != (1, 1, 1):
x1 = F.interpolate(x1, scale_factor=self.scale_factor, mode='nearest')
out1 = self.conv1(x1)
out2 = self.conv2(x2)
out = out1 + out2
out = F.leaky_relu(out, inplace=True)
return out
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def create_conv_block(in_planes, out_planes, kernel_size=1, stride=1, padding=0,
dilation=1, groups=1, bias=False, bn=True, activation=True):
layers = []
layers.append(nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias))
if bn:
layers.append(nn.InstanceNorm3d(out_planes))
if activation:
layers.append(nn.LeakyReLU(inplace=True))
return nn.Sequential(*layers)
def create_conv_block_k1(in_planes, out_planes, kernel_size=1, stride=1, padding=0,
dilation=1, groups=1, bias=False, bn=True, activation=True):
return create_conv_block(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, bn=bn, activation=activation)
def create_conv_block_k2(in_planes, out_planes, kernel_size=2, stride=1, padding=0,
dilation=1, groups=1, bias=False, bn=True, activation=True):
return create_conv_block(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, bn=bn, activation=activation)
def create_conv_block_k3(in_planes, out_planes, kernel_size=3, stride=1, padding=1,
dilation=1, groups=1, bias=False, bn=True, activation=True):
return create_conv_block(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, bn=bn, activation=activation)
class RfbfBlock3d(nn.Module):
def __init__(self, in_planes, out_planes, stride=1, groups=1, droprate=0):
super(RfbfBlock3d, self).__init__()
inter_planes = max(int(np.ceil(out_planes / 8)), 1)
self.groups = groups
self.group_num = inter_planes // groups
self.droprate = droprate
self.branch1 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(1, 5, 5), padding=(1, 2, 2),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(1, 5, 5), padding=(1, 2, 2),
dilation=(1, 1, 1), groups=groups)
)
self.branch2 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(1, 5, 5), padding=(1, 2, 2),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(1, 5, 5), padding=(1, 2, 2),
dilation=(1, 1, 1), groups=groups)
)
self.branch3 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(1, 5, 5), padding=(1, 2, 2),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(1, 5, 5), padding=(1, 2, 2),
dilation=(1, 1, 1), groups=groups)
)
self.branch4 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(1, 7, 7), padding=(1, 3, 3),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(1, 5, 5), padding=(1, 2, 2),
dilation=(1, 1, 1), groups=groups)
)
self.branch5 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(1, 7, 7), padding=(1, 3, 3),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(1, 5, 5), padding=(1, 2, 2),
dilation=(1, 1, 1), groups=groups)
)
self.branch6 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(1, 7, 7), padding=(1, 3, 3),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(1, 5, 5), padding=(1, 2, 2),
dilation=(1, 1, 1), groups=groups)
)
self.branch7 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(1, 9, 9), padding=(1, 4, 4),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(1, 5, 5), padding=(1, 2, 2),
dilation=(1, 1, 1), groups=groups)
)
self.branch8 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(1, 9, 9), padding=(1, 4, 4),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(1, 5, 5), padding=(1, 2, 2),
dilation=(1, 1, 1), groups=groups)
)
def forward(self, x):
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
x4 = self.branch4(x)
x5 = self.branch5(x)
x6 = self.branch6(x)
x7 = self.branch7(x)
x8 = self.branch8(x)
#print(x1.shape, x2.shape, x3.shape, x4.shape, x5.shape, x6.shape, x7.shape, x8.shape)
if self.groups == 1:
out = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), 1)
else:
for group in range(self.groups):
group_out = torch.cat((x1[:, group * self.group_num:(group + 1) * self.group_num],
x2[:, group * self.group_num:(group + 1) * self.group_num],
x3[:, group * self.group_num:(group + 1) * self.group_num],
x4[:, group * self.group_num:(group + 1) * self.group_num],
x5[:, group * self.group_num:(group + 1) * self.group_num],
x6[:, group * self.group_num:(group + 1) * self.group_num],
x7[:, group * self.group_num:(group + 1) * self.group_num],
x8[:, group * self.group_num:(group + 1) * self.group_num]), 1)
if group == 0:
out = group_out
else:
out = torch.cat((out, group_out), 1)
if self.droprate > 0:
out = F.dropout3d(out, p=self.droprate, training=self.training)
return out
class ResBasicBlock3d(nn.Module):
def __init__(self, in_planes, out_planes, stride=1, groups=1, droprate=0, se=False):
super(ResBasicBlock3d, self).__init__()
self.conv1 = create_conv_block_k3(in_planes, out_planes, stride=stride, groups=groups)
self.conv2 = create_conv_block_k3(out_planes, out_planes, groups=groups, activation=False)
self.shortcut = None
if stride != 1 or in_planes != out_planes:
self.shortcut = create_conv_block_k1(in_planes, out_planes, stride=stride, groups=groups, activation=False)
self.droprate = droprate
self.se = se
if se:
self.fc1 = nn.Linear(in_features=out_planes, out_features=out_planes // 4)
self.fc2 = nn.Linear(in_features=out_planes // 4, out_features=out_planes)
def forward(self, x):
identity = self.shortcut(x) if self.shortcut is not None else x
out = self.conv1(x)
if self.droprate > 0:
out = F.dropout3d(out, p=self.droprate, training=self.training)
out = self.conv2(out)
if self.se:
original_out = out
out = F.adaptive_avg_pool3d(out, (1, 1, 1))
out = torch.flatten(out, 1)
out = self.fc1(out)
out = F.leaky_relu(out, inplace=True)
out = self.fc2(out)
out = out.sigmoid()
out = out.view(out.size(0), out.size(1), 1, 1, 1)
out = out * original_out
out = out + identity
out = F.leaky_relu(out, inplace=True)
return out
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def create_conv_block(in_planes, out_planes, kernel_size=1, stride=1, padding=0,
dilation=1, groups=1, bias=False, bn=True, activation=True):
layers = []
layers.append(nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias))
if bn:
layers.append(nn.InstanceNorm3d(out_planes))
if activation:
layers.append(nn.LeakyReLU(inplace=True))
return nn.Sequential(*layers)
def create_conv_block_k1(in_planes, out_planes, kernel_size=1, stride=1, padding=0,
dilation=1, groups=1, bias=False, bn=True, activation=True):
return create_conv_block(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, bn=bn, activation=activation)
def create_conv_block_k2(in_planes, out_planes, kernel_size=2, stride=1, padding=0,
dilation=1, groups=1, bias=False, bn=True, activation=True):
return create_conv_block(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, bn=bn, activation=activation)
def create_conv_block_k3(in_planes, out_planes, kernel_size=3, stride=1, padding=1,
dilation=1, groups=1, bias=False, bn=True, activation=True):
return create_conv_block(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, bn=bn, activation=activation)
class RfbfBlock3d(nn.Module):
def __init__(self, in_planes, out_planes, stride=1, groups=1, droprate=0):
super(RfbfBlock3d, self).__init__()
inter_planes = max(int(np.ceil(out_planes / 8)), 1)
self.groups = groups
self.group_num = inter_planes // groups
self.droprate = droprate
self.branch1 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 2, 2),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 2, 2),
dilation=(1, 1, 1), groups=groups)
)
self.branch2 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 2, 2),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 4, 4),
dilation=(1, 2, 2), groups=groups)
)
self.branch3 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 2, 2),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 6, 6),
dilation=(1, 3, 3), groups=groups)
)
self.branch4 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(3, 7, 7), padding=(1, 3, 3),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 2, 2),
dilation=(1, 1, 1), groups=groups)
)
self.branch5 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(3, 7, 7), padding=(1, 3, 3),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 4, 4),
dilation=(1, 2, 2), groups=groups)
)
self.branch6 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(3, 7, 7), padding=(1, 3, 3),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 6, 6),
dilation=(1, 3, 3), groups=groups)
)
self.branch7 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(5, 9, 9), padding=(2, 4, 4),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 2, 2),
dilation=(1, 1, 1), groups=groups)
)
self.branch8 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(5, 9, 9), padding=(2, 4, 4),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(3, 5, 5), padding=(1, 4, 4),
dilation=(1, 2, 2), groups=groups)
)
def forward(self, x):
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
x4 = self.branch4(x)
x5 = self.branch5(x)
x6 = self.branch6(x)
x7 = self.branch7(x)
x8 = self.branch8(x)
if self.groups == 1:
out = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), 1)
else:
for group in range(self.groups):
group_out = torch.cat((x1[:, group * self.group_num:(group + 1) * self.group_num],
x2[:, group * self.group_num:(group + 1) * self.group_num],
x3[:, group * self.group_num:(group + 1) * self.group_num],
x4[:, group * self.group_num:(group + 1) * self.group_num],
x5[:, group * self.group_num:(group + 1) * self.group_num],
x6[:, group * self.group_num:(group + 1) * self.group_num],
x7[:, group * self.group_num:(group + 1) * self.group_num],
x8[:, group * self.group_num:(group + 1) * self.group_num]), 1)
if group == 0:
out = group_out
else:
out = torch.cat((out, group_out), 1)
if self.droprate > 0:
out = F.dropout3d(out, p=self.droprate, training=self.training)
return out
class RfbeBlock3d(nn.Module):
def __init__(self, in_planes, out_planes, stride=1, groups=1):
super(RfbeBlock3d, self).__init__()
inter_planes = max(int(np.ceil(out_planes / 8)), 2)
self.groups = groups
self.group_num = 2 * inter_planes // groups
self.branch1 = nn.Sequential(
create_conv_block(in_planes, 2 * inter_planes, kernel_size=3, stride=stride, padding=1, groups=groups),
create_conv_block_k3(2 * inter_planes, 2 * inter_planes, padding=1, dilation=1, groups=groups)
)
self.branch2 = nn.Sequential(
create_conv_block(in_planes, 2 * inter_planes, kernel_size=3, stride=stride, padding=1, groups=groups),
create_conv_block_k3(2 * inter_planes, 2 * inter_planes, padding=2, dilation=2, groups=groups)
)
self.branch3 = nn.Sequential(
create_conv_block(in_planes, 2 * inter_planes, kernel_size=5, stride=stride, padding=2, groups=groups),
create_conv_block_k3(2 * inter_planes, 2 * inter_planes, padding=3, dilation=3, groups=groups)
)
self.branch4 = nn.Sequential(
create_conv_block(in_planes, 2 * inter_planes, kernel_size=7, stride=stride, padding=3, groups=groups),
create_conv_block_k3(2 * inter_planes, 2 * inter_planes, padding=4, dilation=4, groups=groups)
)
self.concat_conv = create_conv_block_k1(8 * inter_planes, out_planes, groups=groups, activation=False)
self.shortcut = create_conv_block_k1(in_planes, out_planes, stride=stride, groups=groups, activation=False)
def forward(self, x):
identity = self.shortcut(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
x4 = self.branch4(x)
if self.groups == 1:
out = torch.cat((x1, x2, x3, x4), 1)
else:
for group in range(self.groups):
group_out = torch.cat((x1[:, group * self.group_num:(group + 1) * self.group_num],
x2[:, group * self.group_num:(group + 1) * self.group_num],
x3[:, group * self.group_num:(group + 1) * self.group_num],
x4[:, group * self.group_num:(group + 1) * self.group_num]), 1)
if group == 0:
out = group_out
else:
out = torch.cat((out, group_out), 1)
out = self.concat_conv(out)
out = out + identity
out = F.leaky_relu(out, inplace=True)
return out
class RfbBlock3d(nn.Module):
def __init__(self, in_planes, out_planes, stride=1, groups=1):
super(RfbBlock3d, self).__init__()
inter_planes = max(int(np.ceil(out_planes / 8)), 2)
self.groups = groups
self.group_num = 2 * inter_planes // groups
self.branch1 = nn.Sequential(
create_conv_block_k1(in_planes, 2 * inter_planes, stride=stride, groups=groups),
create_conv_block_k3(2 * inter_planes, 2 * inter_planes, padding=1, dilation=1, groups=groups)
)
self.branch2 = nn.Sequential(
create_conv_block_k1(in_planes, inter_planes, groups=groups),
create_conv_block(inter_planes, 2 * inter_planes, kernel_size=3, stride=stride, padding=1, groups=groups),
create_conv_block_k3(2 * inter_planes, 2 * inter_planes, padding=2, dilation=2, groups=groups)
)
self.branch3 = nn.Sequential(
create_conv_block_k1(in_planes, inter_planes, groups=groups),
create_conv_block(inter_planes, 2 * inter_planes, kernel_size=5, stride=stride, padding=2, groups=groups),
create_conv_block_k3(2 * inter_planes, 2 * inter_planes, padding=3, dilation=3, groups=groups)
)
self.concat_conv = create_conv_block_k1(6 * inter_planes, out_planes, groups=groups, activation=False)
self.shortcut = create_conv_block_k1(in_planes, out_planes, stride=stride, groups=groups, activation=False)
def forward(self, x):
identity = self.shortcut(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
if self.groups == 1:
out = torch.cat((x1, x2, x3), 1)
else:
for group in range(self.groups):
group_out = torch.cat((x1[:, group * self.group_num:(group + 1) * self.group_num],
x2[:, group * self.group_num:(group + 1) * self.group_num],
x3[:, group * self.group_num:(group + 1) * self.group_num]), 1)
if group == 0:
out = group_out
else:
out = torch.cat((out, group_out), 1)
out = self.concat_conv(out)
out = out + identity
out = F.leaky_relu(out, inplace=True)
return out
class ResBasicBlock3d(nn.Module):
def __init__(self, in_planes, out_planes, stride=1, groups=1, droprate=0, se=False):
super(ResBasicBlock3d, self).__init__()
self.conv1 = create_conv_block_k3(in_planes, out_planes, stride=stride, groups=groups)
self.conv2 = create_conv_block_k3(out_planes, out_planes, groups=groups, activation=False)
self.shortcut = None
if stride != 1 or in_planes != out_planes:
self.shortcut = create_conv_block_k1(in_planes, out_planes, stride=stride, groups=groups, activation=False)
self.droprate = droprate
self.se = se
if se:
self.fc1 = nn.Linear(in_features=out_planes, out_features=out_planes // 4)
self.fc2 = nn.Linear(in_features=out_planes // 4, out_features=out_planes)
def forward(self, x):
identity = self.shortcut(x) if self.shortcut is not None else x
out = self.conv1(x)
if self.droprate > 0:
out = F.dropout3d(out, p=self.droprate, training=self.training)
out = self.conv2(out)
if self.se:
original_out = out
out = F.adaptive_avg_pool3d(out, (1, 1, 1))
out = torch.flatten(out, 1)
out = self.fc1(out)
out = F.leaky_relu(out, inplace=True)
out = self.fc2(out)
out = out.sigmoid()
out = out.view(out.size(0), out.size(1), 1, 1, 1)
out = out * original_out
out = out + identity
out = F.leaky_relu(out, inplace=True)
return out
class UpBlock3d(nn.Module):
def __init__(self, in_planes1, in_planes2, out_planes, groups=1, scale_factor=2):
super(UpBlock3d, self).__init__()
self.scale_factor = scale_factor
self.conv1 = create_conv_block_k3(in_planes1, out_planes, groups=groups, activation=False)
self.conv2 = create_conv_block_k3(in_planes2, out_planes, groups=groups, activation=False)
def forward(self, x1, x2):
if self.scale_factor != 1 and self.scale_factor != (1, 1, 1):
x1 = F.interpolate(x1, scale_factor=self.scale_factor, mode='nearest')
out1 = self.conv1(x1)
out2 = self.conv2(x2)
out = out1 + out2
out = F.leaky_relu(out, inplace=True)
return out
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def create_conv_block(in_planes, out_planes, kernel_size=1, stride=1, padding=0,
dilation=1, groups=1, bias=False, bn=True, activation=True):
layers = []
layers.append(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias))
if bn:
layers.append(nn.InstanceNorm2d(out_planes))
if activation:
layers.append(nn.LeakyReLU(inplace=True))
return nn.Sequential(*layers)
def create_conv_block_k1(in_planes, out_planes, kernel_size=1, stride=1, padding=0,
dilation=1, groups=1, bias=False, bn=True, activation=True):
return create_conv_block(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, bn=bn, activation=activation)
def create_conv_block_k2(in_planes, out_planes, kernel_size=2, stride=1, padding=0,
dilation=1, groups=1, bias=False, bn=True, activation=True):
return create_conv_block(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, bn=bn, activation=activation)
def create_conv_block_k3(in_planes, out_planes, kernel_size=3, stride=1, padding=1,
dilation=1, groups=1, bias=False, bn=True, activation=True):
return create_conv_block(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, bn=bn, activation=activation)
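# 2D counterpart of RfbfBlock3d: eight parallel branches of 5x5/7x7/9x9 convolutions whose outputs are
# concatenated channel-wise.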
class RfbfBlock2d(nn.Module):
def __init__(self, in_planes, out_planes, stride=1, groups=1, droprate=0):
super(RfbfBlock2d, self).__init__()
inter_planes = max(int(np.ceil(out_planes / 8)), 1)
self.groups = groups
self.group_num = inter_planes // groups
self.droprate = droprate
self.branch1 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(5, 5), padding=(2, 2),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(5, 5), padding=(2, 2),
dilation=(1, 1), groups=groups)
)
self.branch2 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(5, 5), padding=(2, 2),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(5, 5), padding=(2, 2),
dilation=(1, 1), groups=groups)
)
self.branch3 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(5, 5), padding=(2, 2),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(5, 5), padding=(2, 2),
dilation=(1, 1), groups=groups)
)
self.branch4 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(7, 7), padding=(3, 3),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(5, 5), padding=(2, 2),
dilation=(1, 1), groups=groups)
)
self.branch5 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(7, 7), padding=(3, 3),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(5, 5), padding=(2, 2),
dilation=(1, 1), groups=groups)
)
self.branch6 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(7, 7), padding=(3, 3),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(5, 5), padding=(2, 2),
dilation=(1, 1), groups=groups)
)
self.branch7 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(9, 9), padding=(4, 4),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(5, 5), padding=(2, 2),
dilation=(1, 1), groups=groups)
)
self.branch8 = nn.Sequential(
create_conv_block(in_planes, inter_planes, kernel_size=(9, 9), padding=(4, 4),
stride=stride, groups=groups),
create_conv_block(inter_planes, inter_planes, kernel_size=(5, 5), padding=(2, 2),
dilation=(1, 1), groups=groups)
)
def forward(self, x):
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
x4 = self.branch4(x)
x5 = self.branch5(x)
x6 = self.branch6(x)
x7 = self.branch7(x)
x8 = self.branch8(x)
if self.groups == 1:
out = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), 1)
else:
for group in range(self.groups):
group_out = torch.cat((x1[:, group * self.group_num:(group + 1) * self.group_num],
x2[:, group * self.group_num:(group + 1) * self.group_num],
x3[:, group * self.group_num:(group + 1) * self.group_num],
x4[:, group * self.group_num:(group + 1) * self.group_num],
x5[:, group * self.group_num:(group + 1) * self.group_num],
x6[:, group * self.group_num:(group + 1) * self.group_num],
x7[:, group * self.group_num:(group + 1) * self.group_num],
x8[:, group * self.group_num:(group + 1) * self.group_num]), 1)
if group == 0:
out = group_out
else:
out = torch.cat((out, group_out), 1)
if self.droprate > 0:
out = F.dropout2d(out, p=self.droprate, training=self.training)
return out
class ResBasicBlock2d(nn.Module):
def __init__(self, in_planes, out_planes, stride=1, groups=1, droprate=0, se=False):
super(ResBasicBlock2d, self).__init__()
self.conv1 = create_conv_block_k3(in_planes, out_planes, stride=stride, groups=groups)
self.conv2 = create_conv_block_k3(out_planes, out_planes, groups=groups, activation=False)
self.shortcut = None
if stride != 1 or in_planes != out_planes:
self.shortcut = create_conv_block_k1(in_planes, out_planes, stride=stride, groups=groups, activation=False)
self.droprate = droprate
self.se = se
if se:
self.fc1 = nn.Linear(in_features=out_planes, out_features=out_planes // 4)
self.fc2 = nn.Linear(in_features=out_planes // 4, out_features=out_planes)
def forward(self, x):
identity = self.shortcut(x) if self.shortcut is not None else x
out = self.conv1(x)
if self.droprate > 0:
out = F.dropout2d(out, p=self.droprate, training=self.training)
out = self.conv2(out)
if self.se:
original_out = out
out = F.adaptive_avg_pool2d(out, (1, 1))
out = torch.flatten(out, 1)
out = self.fc1(out)
out = F.leaky_relu(out, inplace=True)
out = self.fc2(out)
out = out.sigmoid()
#this reshape still needs to be tested
out = out.view(out.size(0), out.size(1), 1, 1)
out = out * original_out
out = out + identity
out = F.leaky_relu(out, inplace=True)
return out
import torch
import torch.nn as nn
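# Binary focal loss on logits: sigmoid the input, down-weight easy examples via the gamma exponent and
# balance positives/negatives with alpha.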
class BCEFocalLoss(nn.Module):
def __init__(self, gamma=2, alpha=0.5, reduction='elementwise_mean') -> None:
super().__init__()
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
def forward(self, _input, target):
pt = torch.sigmoid(_input)
alpha = self.alpha
#standard binary focal loss; the remainder of this forward is an assumed completion of the truncated source
loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) \
- (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
if self.reduction == 'elementwise_mean':
loss = torch.mean(loss)
elif self.reduction == 'sum':
loss = torch.sum(loss)
return loss
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
from net.component_i import RfbfBlock3d, ResBasicBlock3d
from net.component_i import create_conv_block_k3
from net.component_c import init_modules
#imported for testing only
from cls_utils.data import test_save_ckpt
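# Downsampling-only backbone: an RfbfBlock3d stem followed by stacks of ResBasicBlock3d stages
# (SE-enabled from level 4 on) that progressively reduce the resolution; only the deepest feature map
# is returned.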
class FPN(nn.Module):
def __init__(self, n_channels, n_base_filters, groups=1):
super(FPN, self).__init__()
self.down_level1 = nn.Sequential(
RfbfBlock3d(n_channels, n_base_filters, groups=groups)
)
self.down_level2 = nn.Sequential(
ResBasicBlock3d(1 * n_base_filters, 2 * n_base_filters, stride=(1, 2, 2), groups=groups),
ResBasicBlock3d(2 * n_base_filters, 2 * n_base_filters, groups=groups)
)
self.down_level3 = nn.Sequential(
ResBasicBlock3d(2 * n_base_filters, 4 * n_base_filters, stride=(1, 2, 2), groups=groups),
ResBasicBlock3d(4 * n_base_filters, 4 * n_base_filters, groups=groups)
)
self.down_level4 = nn.Sequential(
ResBasicBlock3d(4 * n_base_filters, 8 * n_base_filters, stride=2, groups=groups, se=True),
ResBasicBlock3d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
ResBasicBlock3d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True)
)
self.down_level5 = nn.Sequential(
ResBasicBlock3d(8 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True),
ResBasicBlock3d(16 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
ResBasicBlock3d(8 * n_base_filters, 16 * n_base_filters, groups=groups, se=True)
)
self.down_level6 = nn.Sequential(
ResBasicBlock3d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
)
self.down_level7 = nn.Sequential(
ResBasicBlock3d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
)
self.down_level8 = nn.Sequential(
create_conv_block_k3(16 * n_base_filters, 16 * n_base_filters, padding=0, bn=False)
)
def forward(self, x):
down_out1 = self.down_level1(x)
down_out2 = self.down_level2(down_out1)
down_out3 = self.down_level3(down_out2)
down_out4 = self.down_level4(down_out3)
down_out5 = self.down_level5(down_out4)
down_out6 = self.down_level6(down_out5)
down_out7 = self.down_level7(down_out6)
down_out8 = self.down_level8(down_out7)
return down_out8
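# Classification head: FPN backbone, global average pooling and a single linear diff classifier.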
class Net(nn.Module):
def __init__(self, n_channels=1, n_diff_classes=1, n_base_filters=8):
super(Net, self).__init__()
self.fpn = FPN(8 * n_channels, n_base_filters)
self.diff_classifier = nn.Linear(16 * n_base_filters, n_diff_classes)
#initialize model parameters
init_modules(self.modules())
def forward(self, data):
out = self.fpn(data)
out_avg = F.adaptive_avg_pool3d(out, (1, 1, 1))
out = torch.flatten(out_avg, 1)
diff_output = self.diff_classifier(out)
return diff_output
def net_test():
cfg = dict()
cfg['n_channels'] = 1
cfg['n_diff_classes'] = 1
cfg['training_crop_size'] = [48, 256, 256]
cfg['pretrain_ckpt'] = ''
batch_size = 1
x = torch.rand(batch_size, cfg['n_channels'] * 8,
cfg['training_crop_size'][0], cfg['training_crop_size'][1], cfg['training_crop_size'][2])
model = Net(n_channels=cfg.get('n_channels'), n_diff_classes=cfg.get('n_diff_classes'))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
x = x.to(device)
print(x.shape, x.device)
#load model parameters
pretrain_ckpt_path = './cls_train/best_cls/cls_model218_1024u_20230407.ckpt'
model_param = torch.load(pretrain_ckpt_path)
model.load_state_dict(model_param['state_dict'])
model = model.to(device)
print('Parameters loaded successfully')
model.eval()
with torch.no_grad():
diff_output = model(x)
print(diff_output.shape)
#test_save_ckpt(model=model)
if __name__ == '__main__':
net_test()
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
from net.component_i_230425 import RfbfBlock3d, ResBasicBlock3d
from net.component_i_230425 import create_conv_block_k3
from net.component_c import init_modules
#for testing
class FPN(nn.Module):
def __init__(self, n_channels, n_base_filters, groups=1):
super(FPN, self).__init__()
self.down_level1 = nn.Sequential(
RfbfBlock3d(n_channels, n_base_filters, groups=groups)
)
self.down_level2 = nn.Sequential(
ResBasicBlock3d(1 * n_base_filters, 2 * n_base_filters, stride=(1, 2, 2), groups=groups),
ResBasicBlock3d(2 * n_base_filters, 2 * n_base_filters, groups=groups)
)
self.down_level3 = nn.Sequential(
ResBasicBlock3d(2 * n_base_filters, 4 * n_base_filters, stride=(1, 2, 2), groups=groups),
ResBasicBlock3d(4 * n_base_filters, 4 * n_base_filters, groups=groups)
)
self.down_level4 = nn.Sequential(
ResBasicBlock3d(4 * n_base_filters, 8 * n_base_filters, stride=2, groups=groups, se=True),
ResBasicBlock3d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
ResBasicBlock3d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True)
)
self.down_level5 = nn.Sequential(
ResBasicBlock3d(8 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True),
ResBasicBlock3d(16 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
ResBasicBlock3d(8 * n_base_filters, 16 * n_base_filters, groups=groups, se=True)
)
self.down_level6 = nn.Sequential(
ResBasicBlock3d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
)
self.down_level7 = nn.Sequential(
ResBasicBlock3d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
)
self.down_level8 = nn.Sequential(
create_conv_block_k3(16 * n_base_filters, 16 * n_base_filters, padding=0, bn=False)
)
def forward(self, x):
down_out1 = self.down_level1(x)
down_out2 = self.down_level2(down_out1)
down_out3 = self.down_level3(down_out2)
down_out4 = self.down_level4(down_out3)
down_out5 = self.down_level5(down_out4)
down_out6 = self.down_level6(down_out5)
down_out7 = self.down_level7(down_out6)
down_out8 = self.down_level8(down_out7)
return down_out8
class Net(nn.Module):
def __init__(self, n_channels=1, n_diff_classes=1, n_base_filters=8):
super(Net, self).__init__()
self.fpn = FPN(n_channels, n_base_filters)
self.diff_classifier = nn.Linear(16 * n_base_filters, n_diff_classes)
#initialize model parameters
init_modules(self.modules())
def forward(self, data):
out = self.fpn(data)
#use max pooling instead of average pooling
out_max = F.adaptive_max_pool3d(out, (1, 1, 1))
#out_max = F.max_pool3d(out, (1, 1, 1))
#out_avg = F.adaptive_avg_pool3d(out, (1, 1, 1))
out = torch.flatten(out_max, 1)
#print(out_avg.shape)
#print(out_max.shape)
diff_output = self.diff_classifier(out)
return diff_output
def net_test():
cfg = dict()
cfg['n_channels'] = 1
cfg['n_diff_classes'] = 1
cfg['training_crop_size'] = [48, 256, 256]
cfg['pretrain_ckpt'] = ''
batch_size = 1
x = torch.rand(batch_size, cfg['n_channels'],
cfg['training_crop_size'][0], cfg['training_crop_size'][1], cfg['training_crop_size'][2])
print(x.shape)
model = Net(n_channels=cfg.get('n_channels'), n_diff_classes=cfg.get('n_diff_classes'))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
x = x.to(device)
print(x.shape, x.device)
#load model parameters
"""pretrain_ckpt_path = './cls_train/best_cls/cls_model218_1024u_20230407.ckpt'
model_param = torch.load(pretrain_ckpt_path)
model.load_state_dict(model_param['state_dict'])"""
model = model.to(device)
#print('Parameters loaded successfully')
model.eval()
with torch.no_grad():
diff_output = model(x)
print(diff_output.shape)
#test_save_ckpt(model=model)
if __name__ == '__main__':
net_test()
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
from net.component_i_231025 import RfbfBlock3d, ResBasicBlock3d
from net.component_i_231025 import create_conv_block_k3
from net.component_c import init_modules
#for testing
class FPN(nn.Module):
def __init__(self, n_channels, n_base_filters, groups=1):
super(FPN, self).__init__()
self.down_level1 = nn.Sequential(
RfbfBlock3d(n_channels, n_base_filters, groups=groups)
)
self.down_level2 = nn.Sequential(
ResBasicBlock3d(1 * n_base_filters, 2 * n_base_filters, stride=(1, 2, 2), groups=groups),
ResBasicBlock3d(2 * n_base_filters, 2 * n_base_filters, groups=groups)
)
self.down_level3 = nn.Sequential(
ResBasicBlock3d(2 * n_base_filters, 4 * n_base_filters, stride=(1, 2, 2), groups=groups),
ResBasicBlock3d(4 * n_base_filters, 4 * n_base_filters, groups=groups)
)
self.down_level4 = nn.Sequential(
ResBasicBlock3d(4 * n_base_filters, 8 * n_base_filters, stride=2, groups=groups, se=True),
ResBasicBlock3d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
ResBasicBlock3d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True)
)
self.down_level5 = nn.Sequential(
ResBasicBlock3d(8 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True),
ResBasicBlock3d(16 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
ResBasicBlock3d(8 * n_base_filters, 16 * n_base_filters, groups=groups, se=True)
)
self.down_level6 = nn.Sequential(
ResBasicBlock3d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
)
self.down_level7 = nn.Sequential(
ResBasicBlock3d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
)
self.down_level8 = nn.Sequential(
create_conv_block_k3(16 * n_base_filters, 16 * n_base_filters, padding=0, bn=False)
)
def forward(self, x):
down_out1 = self.down_level1(x)
down_out2 = self.down_level2(down_out1)
down_out3 = self.down_level3(down_out2)
down_out4 = self.down_level4(down_out3)
down_out5 = self.down_level5(down_out4)
down_out6 = self.down_level6(down_out5)
down_out7 = self.down_level7(down_out6)
down_out8 = self.down_level8(down_out7)
return down_out8
class Net(nn.Module):
def __init__(self, n_channels=1, n_diff_classes=1, n_base_filters=8):
super(Net, self).__init__()
self.fpn = FPN(n_channels, n_base_filters)
self.diff_classifier = nn.Linear(16 * n_base_filters, n_diff_classes)
#initialize model parameters
init_modules(self.modules())
def forward(self, data):
out = self.fpn(data)
#use max pooling instead of average pooling
out_max = F.adaptive_max_pool3d(out, (1, 1, 1))
#out_max = F.max_pool3d(out, (1, 1, 1))
#out_avg = F.adaptive_avg_pool3d(out, (1, 1, 1))
out = torch.flatten(out_max, 1)
#print(out_avg.shape)
#print(out_max.shape)
diff_output = self.diff_classifier(out)
diff_output = torch.sigmoid(diff_output)
return diff_output
def net_test():
cfg = dict()
cfg['n_channels'] = 1
cfg['n_diff_classes'] = 1
cfg['training_crop_size'] = [48, 256, 256]
cfg['pretrain_ckpt'] = ''
batch_size = 1
x = torch.rand(batch_size, cfg['n_channels'],
cfg['training_crop_size'][0], cfg['training_crop_size'][1], cfg['training_crop_size'][2])
print(x.shape)
model = Net(n_channels=cfg.get('n_channels'), n_diff_classes=cfg.get('n_diff_classes'))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
x = x.to(device)
print(x.shape, x.device)
#load model parameters
"""pretrain_ckpt_path = './cls_train/best_cls/cls_model218_1024u_20230407.ckpt'
model_param = torch.load(pretrain_ckpt_path)
model.load_state_dict(model_param['state_dict'])"""
model = model.to(device)
#print('Parameters loaded successfully')
model.eval()
with torch.no_grad():
diff_output = model(x)
print(diff_output.shape)
dummy_input = torch.randn(1, 1, 48, 256, 256).to(device)  # adjust to the actual input size
torch.onnx.export(model,
dummy_input,
'cls_train_net_cls_1024u_231025_20241113.onnx',
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'},
'output': {0: 'batch_size'}})
if __name__ == '__main__':
net_test()
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
from net.component_i_2d import RfbfBlock2d, ResBasicBlock2d
from net.component_i_2d import create_conv_block_k3
from net.component_c import init_modules_2d
#imported for testing only
from cls_utils.data import test_save_ckpt
class FPN(nn.Module):
def __init__(self, n_channels, n_base_filters, groups=1):
super(FPN, self).__init__()
self.down_level1 = nn.Sequential(
RfbfBlock2d(n_channels, n_base_filters, groups=groups)
)
self.down_level2 = nn.Sequential(
ResBasicBlock2d(1 * n_base_filters, 2 * n_base_filters, stride=(2, 2), groups=groups),
ResBasicBlock2d(2 * n_base_filters, 2 * n_base_filters, groups=groups)
)
self.down_level3 = nn.Sequential(
ResBasicBlock2d(2 * n_base_filters, 4 * n_base_filters, stride=(2, 2), groups=groups),
ResBasicBlock2d(4 * n_base_filters, 4 * n_base_filters, groups=groups)
)
self.down_level4 = nn.Sequential(
ResBasicBlock2d(4 * n_base_filters, 8 * n_base_filters, stride=2, groups=groups, se=True),
ResBasicBlock2d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
ResBasicBlock2d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True)
)
self.down_level5 = nn.Sequential(
ResBasicBlock2d(8 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True),
ResBasicBlock2d(16 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
ResBasicBlock2d(8 * n_base_filters, 16 * n_base_filters, groups=groups, se=True)
)
self.down_level6 = nn.Sequential(
ResBasicBlock2d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
)
self.down_level7 = nn.Sequential(
ResBasicBlock2d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
)
self.down_level8 = nn.Sequential(
create_conv_block_k3(16 * n_base_filters, 16 * n_base_filters, padding=0, bn=False)
)
def forward(self, x):
down_out1 = self.down_level1(x)
down_out2 = self.down_level2(down_out1)
down_out3 = self.down_level3(down_out2)
down_out4 = self.down_level4(down_out3)
down_out5 = self.down_level5(down_out4)
down_out6 = self.down_level6(down_out5)
down_out7 = self.down_level7(down_out6)
down_out8 = self.down_level8(down_out7)
return down_out8
class Net(nn.Module):
def __init__(self, n_channels=1, n_diff_classes=1, n_base_filters=8):
super(Net, self).__init__()
self.fpn = FPN(8 * n_channels, n_base_filters)
self.diff_classifier = nn.Linear(16 * n_base_filters, n_diff_classes)
#initialize model parameters
init_modules_2d(self.modules())
def forward(self, data):
out = self.fpn(data)
out_avg = F.adaptive_avg_pool2d(out, (1, 1))
out = torch.flatten(out_avg, 1)
diff_output = self.diff_classifier(out)
return diff_output
def net_test():
cfg = dict()
cfg['n_channels'] = 1
cfg['n_diff_classes'] = 1
cfg['training_crop_size'] = [128, 128]
cfg['pretrain_ckpt'] = ''
batch_size = 1
x = torch.rand(batch_size, cfg['n_channels'] * 8,
cfg['training_crop_size'][0], cfg['training_crop_size'][1])
print('x.shape:', x.shape)
model = Net(n_channels=cfg.get('n_channels'), n_diff_classes=cfg.get('n_diff_classes'))
print('Model structure:')
print(model)
print('------------------------------------------')
"""
print(type(model))
for layer in model:
#X = layer(x)
print(layer.__class__.__name__, f'output size: ')"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
x = x.to(device)
print(x.shape, x.device)
#load model parameters
model = model.to(device)
#print('Parameters loaded successfully')
model.eval()
with torch.no_grad():
diff_output = model(x)
print(diff_output.shape)
#test_save_ckpt(model=model)
if __name__ == '__main__':
net_test()
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
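# Clip intensities to [min_value, max_value] (defaults appear chosen for CT Hounsfield units) and
# rescale them linearly to [-1, 1].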
class Normalize(nn.Module):
def __init__(self):
super(Normalize, self).__init__()
def forward(self, data, min_value=-1000, max_value=600):
new_data = data
new_data[new_data < min_value] = min_value
new_data[new_data > max_value] = max_value
# normalize to [-1, 1]
new_data = 2.0 * (new_data - min_value) / (max_value - min_value) - 1
return new_data
import argparse
import copy
import os, sys
import pathlib
current_dir = pathlib.Path(__file__).parent.resolve()
while "cls_train" != os.path.basename(current_dir):
current_dir = current_dir.parent
sys.path.append(current_dir.as_posix())
os.environ["OMP_NUM_THREADS"] = "1"
from cls_utils.log_utils import get_logger
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from pytorch_train.classification_model_2d3d import Net2d3d, Net2d, Net3d, init_modules
from data.dataset_2d3d import ClassificationDataset2d3d, ClassificationDataset3d, ClassificationDataset2d, custom_collate_fn_2d, custom_collate_fn_3d, custom_collate_fn_2d3d, cls_report_dict_to_string, ClassificationDatasetError2d, ClassificationDatasetError3d, ClassificationDatasetError2d3d, custom_collate_fn_2d_error, custom_collate_fn_3d_error, custom_collate_fn_2d3d_error
from pytorch_train.encoder_cls import S3dClassifier as NetS3d
from pytorch_train.encoder_cls import D2dClassifier as NetD2d
from data.dataset_2d3d import ClassificationDatasetS3d, ClassificationDatasetErrorS3d, custom_collate_fn_s3d, custom_collate_fn_s3d_error, ClassificationDatasetD2d, ClassificationDatasetErrorD2d, custom_collate_fn_d2d, custom_collate_fn_d2d_error
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import multiprocessing as mp
import traceback
import pandas as pd
from pathlib import Path
from torch.optim import lr_scheduler, Adam
import random
import numpy as np
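# Seed python, numpy and torch (CPU and all GPUs) for reproducible runs.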
def set_seed(seed=1004):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def get_model_params(model):
params = ""
with torch.no_grad():
for name, param in model.named_parameters():
if "net2d.diff_classifier.weight" in name or "net3d.diff_classifier.weight" in name:
params += f"{name}\n{param[0][:20]}\n\n"
return params
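# Evaluate saved checkpoints on a single GPU: build train/val/test loaders for the selected net_id,
# load each requested epoch's weights, run inference without gradients and write the per-sample
# predictions to CSV files under save_dir.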
def test_on_single_gpu(init_model, config, device_index, train_2d3d_data_2d_csv_file, val_2d3d_data_2d_csv_file, test_2d3d_data_2d_csv_file, train_2d3d_data_3d_csv_file, val_2d3d_data_3d_csv_file, test_2d3d_data_3d_csv_file, train_csv_file, val_csv_file, test_csv_file):
logger = config['logger']
net_id = config['net_id']
save_dir = config['save_dir']
current_device = torch.device(f"cuda:{device_index}")
test_train_dataset_info_dict = config['test_train_dataset_info_dict']
test_val_dataset_info_dict = config['test_val_dataset_info_dict']
test_test_dataset_info_dict = config['test_test_dataset_info_dict']
train_batch_size = config['train_batch_size']
val_batch_size = config['val_batch_size']
test_batch_size = config['test_batch_size']
num_workers = config['num_workers']
if net_id == "2d":
train_dataset = ClassificationDataset2d(train_csv_file, data_info="train_dataset_2d")
val_dataset = ClassificationDataset2d(val_csv_file, data_info="val_dataset_2d")
test_dataset = ClassificationDataset2d(test_csv_file, data_info="test_dataset_2d")
custom_collate_fn = custom_collate_fn_2d
elif net_id == "3d":
train_dataset = ClassificationDataset3d(train_csv_file, data_info="train_dataset_3d")
val_dataset = ClassificationDataset3d(val_csv_file, data_info="val_dataset_3d")
test_dataset = ClassificationDataset3d(test_csv_file, data_info="test_dataset_3d")
custom_collate_fn = custom_collate_fn_3d
elif net_id == "2d3d":
train_dataset = ClassificationDataset2d3d(train_2d3d_data_2d_csv_file, train_2d3d_data_3d_csv_file, data_info="train_dataset_2d3d")
val_dataset = ClassificationDataset2d3d(val_2d3d_data_2d_csv_file, val_2d3d_data_3d_csv_file, data_info="val_dataset_2d3d")
test_dataset = ClassificationDataset2d3d(test_2d3d_data_2d_csv_file, test_2d3d_data_3d_csv_file, data_info="test_dataset_2d3d")
custom_collate_fn = custom_collate_fn_2d3d
elif net_id == "s3d":
train_dataset = ClassificationDatasetS3d(train_csv_file, data_info="train_dataset_S3d")
val_dataset = ClassificationDatasetS3d(val_csv_file, data_info="val_dataset_S3d")
test_dataset = ClassificationDatasetS3d(test_csv_file, data_info="test_dataset_s3d")
custom_collate_fn = custom_collate_fn_s3d
elif net_id == "d2d":
train_dataset = ClassificationDatasetD2d(train_csv_file, data_info="train_dataset_d2d")
val_dataset = ClassificationDatasetD2d(val_csv_file, data_info="val_dataset_d2d")
test_dataset = ClassificationDatasetD2d(test_csv_file, data_info="test_dataset_d2d")
custom_collate_fn = custom_collate_fn_d2d
else:
raise ValueError(f"net_id {net_id} not supported")
logger.info(f"before start test, train_dataset: {len(train_dataset)}\n val_dataset: {len(val_dataset)}\n test_dataset: {len(test_dataset)}")
train_data_loader = DataLoader(
train_dataset,
batch_size=train_batch_size,
drop_last=False,
shuffle=False,
num_workers=num_workers,
collate_fn=custom_collate_fn
)
val_data_loader = DataLoader(
val_dataset,
batch_size=val_batch_size,
drop_last=False,
shuffle=False,
num_workers=num_workers,
collate_fn=custom_collate_fn
)
test_data_loader = DataLoader(
test_dataset,
batch_size=test_batch_size,
drop_last=False,
shuffle=False,
num_workers=num_workers,
collate_fn=custom_collate_fn
)
logger.info(f"local rank: {device_index}, train data total batch: {len(train_data_loader)}, val data total batch: {len(val_data_loader)}, test data total batch: {len(test_data_loader)}")
params = get_model_params(copy.deepcopy(init_model))
logger.info(f"local rank: {device_index}\nbefore start test, load init, params:\n{params}")
# evaluate the training set
torch.cuda.empty_cache()
test_train_dataset_flag = test_train_dataset_info_dict["test_train_dataset_flag"]
if test_train_dataset_flag:
test_train_epoch_list = test_train_dataset_info_dict["test_train_epoch_list"]
test_train_epoch_list = sorted(test_train_epoch_list)
logger.info(f"local rank: {device_index}, start test_[train]_dataset")
for idx_test_train_epoch in test_train_epoch_list:
idx_epoch_pt_file = f"train_epoch_{idx_test_train_epoch}_net_{net_id}.pt"
idx_epoch_pt_file = os.path.join(save_dir, idx_epoch_pt_file)
model = copy.deepcopy(init_model)
model.load_state_dict(torch.load(idx_epoch_pt_file, map_location="cpu"))
model = model.to(current_device)
logger.info(f"local rank: {device_index}, test_[train]_dataset, start test_[train]_dataset epoch: {idx_test_train_epoch}")
logger.info(f"local rank: {device_index}, test_[train]_dataset, after load epoch {idx_test_train_epoch}\nparams:\n{get_model_params(model)}")
model.eval()
train_all_label_id = []
train_all_label_id_2d = []
train_all_label_id_3d = []
train_all_z_index = []
train_all_z_index_2d = []
train_all_z_index_3d = []
train_all_rotate_count = []
train_all_rotate_count_2d = []
train_all_rotate_count_3d = []
train_all_labels = []
train_all_labels_2d = []
train_all_labels_3d = []
train_all_preds = []
train_all_preds_2d = []
train_all_preds_3d = []
model.eval()
with torch.no_grad():
logger.info(f"local rank: {device_index}, test_[train]_dataset, start test_[train]_dataset epoch: {idx_test_train_epoch}, train_data_loader: {len(train_data_loader)}")
for step, batch_data in enumerate(train_data_loader):
if net_id == "2d3d":
label_id_2d, data_2d, label_2d, z_index_2d, rotate_count_2d, label_id_3d, data_3d, label_3d, z_index_3d, rotate_count_3d = batch_data
train_all_label_id_2d.extend(label_id_2d)
train_all_label_id_3d.extend(label_id_3d)
train_all_z_index_2d.extend(z_index_2d)
train_all_z_index_3d.extend(z_index_3d)
train_all_rotate_count_2d.extend(rotate_count_2d)
train_all_rotate_count_3d.extend(rotate_count_3d)
train_all_labels_2d.append(label_2d.cpu())
train_all_labels_3d.append(label_3d.cpu())
data_2d = data_2d.to(current_device)
data_3d = data_3d.to(current_device)
label_2d = label_2d.to(current_device)
label_3d = label_3d.to(current_device)
y_pred_2d, y_pred_3d = model(data_2d, data_3d)
train_all_preds_2d.append(y_pred_2d.detach().cpu())
train_all_preds_3d.append(y_pred_3d.detach().cpu())
else:
label_id, data, label, z_index, rotate_count = batch_data
train_all_label_id.extend(label_id)
train_all_z_index.extend(z_index)
train_all_rotate_count.extend(rotate_count)
train_all_labels.append(label.cpu())
data = data.to(current_device)
label = label.to(current_device)
y_pred = model(data)
train_all_preds.append(y_pred.detach().cpu())
logger.info(f"local rank: {device_index}, test_[train]_dataset, step: {step}, train_data_loader: {len(train_data_loader)}")
if net_id == "2d3d":
del model, data_2d, data_3d, label_2d, label_3d, y_pred_2d, y_pred_3d, z_index_2d, z_index_3d, rotate_count_2d, rotate_count_3d
else:
del model, data, label, y_pred, z_index, rotate_count
torch.cuda.empty_cache()
if net_id == "2d3d":
train_all_label_id_2d = train_all_label_id_2d[:]
train_all_label_id_3d = train_all_label_id_3d[:]
train_all_net_type_2d = ["2d_net"] * len(train_all_label_id_2d)
train_all_net_type_3d = ["3d_net"] * len(train_all_label_id_3d)
train_all_labels_2d = torch.cat(train_all_labels_2d, dim=0)
train_all_labels_3d = torch.cat(train_all_labels_3d, dim=0)
train_all_preds_2d = torch.cat(train_all_preds_2d, dim=0)
train_all_preds_3d = torch.cat(train_all_preds_3d, dim=0)
concat_label_id = train_all_label_id_2d + train_all_label_id_3d
concat_z_index = train_all_z_index_2d[:] + train_all_z_index_3d[:]
concat_rotate_count = train_all_rotate_count_2d[:] + train_all_rotate_count_3d[:]
concat_net_type = train_all_net_type_2d + train_all_net_type_3d
concat_label = torch.cat([train_all_labels_2d, train_all_labels_3d], dim=0)
concat_pred = torch.cat([train_all_preds_2d, train_all_preds_3d], dim=0)
logger.info(f"local rank: {device_index}, test_[train]_dataset, concat_label_id: {len(concat_label_id)}, concat_label: {concat_label.shape}, concat_pred: {concat_pred.shape}")
lens = len(concat_label_id)
idx_test_train_epoch_df = pd.DataFrame({
"test_type": ["train_dataset"] * lens,
"local_rank": [device_index] * lens,
"net_id": [net_id] * lens,
"net_type": concat_net_type,
"epoch": [idx_test_train_epoch] * lens,
"label_id": concat_label_id,
"z_index": concat_z_index,
"rotate_count": concat_rotate_count,
"label": concat_label.tolist(),
"pred": concat_pred.tolist()
})
idx_test_train_epoch_df_file = f"local_rank_{device_index}_test_type_train_dataset_epoch_{idx_test_train_epoch}_net_{net_id}.csv"
idx_test_train_epoch_df_file = os.path.join(save_dir, idx_test_train_epoch_df_file)
idx_test_train_epoch_df.to_csv(idx_test_train_epoch_df_file, index=False, encoding="utf-8")
logger.info(f"local rank: {device_index}, test_[train]_dataset epoch: {idx_test_train_epoch}, save to {idx_test_train_epoch_df_file}")
else:
train_all_label_id = train_all_label_id[:]
train_all_labels = torch.cat(train_all_labels, dim=0)
train_all_preds = torch.cat(train_all_preds, dim=0)
lens = len(train_all_label_id)
idx_test_train_epoch_df = pd.DataFrame({
"test_type": ["train_dataset"] * lens,
"local_rank": [device_index] * lens,
"net_id": [net_id] * lens,
"net_type": [f"{net_id}_net"] * lens,
"epoch": [idx_test_train_epoch] * lens,
"label_id": train_all_label_id,
"z_index": train_all_z_index[:],
"rotate_count": train_all_rotate_count[:],
"label": train_all_labels.tolist(),
"pred": train_all_preds.tolist()
})
idx_test_train_epoch_df_file = f"local_rank_{device_index}_test_type_train_dataset_epoch_{idx_test_train_epoch}_net_{net_id}.csv"
idx_test_train_epoch_df_file = os.path.join(save_dir, idx_test_train_epoch_df_file)
idx_test_train_epoch_df.to_csv(idx_test_train_epoch_df_file, index=False, encoding="utf-8")
logger.info(f"local rank: {device_index}, test_[train]_dataset epoch: {idx_test_train_epoch}, save to {idx_test_train_epoch_df_file}")
logger.info(f"local rank: {device_index}, test_[train]_dataset finished")
# Evaluate the validation set
torch.cuda.empty_cache()
test_val_dataset_flag = test_val_dataset_info_dict["test_val_dataset_flag"]
if test_val_dataset_flag:
test_val_epoch_list = test_val_dataset_info_dict["test_val_epoch_list"]
test_val_epoch_list = sorted(test_val_epoch_list)
logger.info(f"local rank: {device_index}, start test_[val]_dataset")
for idx_test_val_epoch in test_val_epoch_list:
idx_epoch_pt_file = f"train_epoch_{idx_test_val_epoch}_net_{net_id}.pt"
idx_epoch_pt_file = os.path.join(save_dir, idx_epoch_pt_file)
model = copy.deepcopy(init_model)
model.load_state_dict(torch.load(idx_epoch_pt_file, map_location="cpu"))
model = model.to(current_device)
logger.info(f"local rank: {device_index}, test_[val]_dataset, start test_[val]_dataset epoch: {idx_test_val_epoch}")
logger.info(f"local rank: {device_index}, test_[val]_dataset, after load epoch {idx_test_val_epoch}\nparams:\n{get_model_params(model)}")
model.eval()
val_all_label_id = []
val_all_label_id_2d = []
val_all_label_id_3d = []
val_all_z_index = []
val_all_z_index_2d = []
val_all_z_index_3d = []
val_all_rotate_count = []
val_all_rotate_count_2d = []
val_all_rotate_count_3d = []
val_all_labels = []
val_all_labels_2d = []
val_all_labels_3d = []
val_all_preds = []
val_all_preds_2d = []
val_all_preds_3d = []
model.eval()
with torch.no_grad():
logger.info(f"local rank: {device_index}, test_[val]_dataset, start test_[val]_dataset epoch: {idx_test_val_epoch}, val_data_loader: {len(val_data_loader)}")
for step, batch_data in enumerate(val_data_loader):
if net_id == "2d3d":
label_id_2d, data_2d, label_2d, z_index_2d, rotate_count_2d, label_id_3d, data_3d, label_3d, z_index_3d, rotate_count_3d = batch_data
val_all_label_id_2d.extend(label_id_2d)
val_all_label_id_3d.extend(label_id_3d)
val_all_z_index_2d.extend(z_index_2d)
val_all_z_index_3d.extend(z_index_3d)
val_all_rotate_count_2d.extend(rotate_count_2d)
val_all_rotate_count_3d.extend(rotate_count_3d)
val_all_labels_2d.append(label_2d.cpu())
val_all_labels_3d.append(label_3d.cpu())
data_2d = data_2d.to(current_device)
data_3d = data_3d.to(current_device)
label_2d = label_2d.to(current_device)
label_3d = label_3d.to(current_device)
y_pred_2d, y_pred_3d = model(data_2d, data_3d)
val_all_preds_2d.append(y_pred_2d.detach().cpu())
val_all_preds_3d.append(y_pred_3d.detach().cpu())
else:
label_id, data, label, z_index, rotate_count = batch_data
val_all_label_id.extend(label_id)
val_all_z_index.extend(z_index)
val_all_rotate_count.extend(rotate_count)
val_all_labels.append(label.cpu())
data = data.to(current_device)
label = label.to(current_device)
y_pred = model(data)
val_all_preds.append(y_pred.detach().cpu())
logger.info(f"local rank: {device_index}, test_[val]_dataset, step: {step}, val_data_loader: {len(val_data_loader)}")
if net_id == "2d3d":
del model, data_2d, data_3d, label_2d, label_3d, y_pred_2d, y_pred_3d, z_index_2d, z_index_3d, rotate_count_2d, rotate_count_3d
else:
del model, data, label, y_pred, z_index, rotate_count
torch.cuda.empty_cache()
if net_id == "2d3d":
val_all_label_id_2d = val_all_label_id_2d[:]
val_all_label_id_3d = val_all_label_id_3d[:]
val_all_net_type_2d = ["2d_net"] * len(val_all_label_id_2d)
val_all_net_type_3d = ["3d_net"] * len(val_all_label_id_3d)
val_all_labels_2d = torch.cat(val_all_labels_2d, dim=0)
val_all_labels_3d = torch.cat(val_all_labels_3d, dim=0)
val_all_preds_2d = torch.cat(val_all_preds_2d, dim=0)
val_all_preds_3d = torch.cat(val_all_preds_3d, dim=0)
concat_label_id = val_all_label_id_2d + val_all_label_id_3d
concat_z_index = val_all_z_index_2d[:] + val_all_z_index_3d[:]
concat_rotate_count = val_all_rotate_count_2d[:] + val_all_rotate_count_3d[:]
concat_net_type = val_all_net_type_2d + val_all_net_type_3d
concat_label = torch.cat([val_all_labels_2d, val_all_labels_3d], dim=0)
concat_pred = torch.cat([val_all_preds_2d, val_all_preds_3d], dim=0)
logger.info(f"local rank: {device_index}, test_[val]_dataset, concat_label_id: {len(concat_label_id)}, concat_label: {concat_label.shape}, concat_pred: {concat_pred.shape}")
lens = len(concat_label_id)
idx_test_val_epoch_df = pd.DataFrame({
"test_type": ["val_dataset"] * lens,
"local_rank": [device_index] * lens,
"net_id": [net_id] * lens,
"net_type": concat_net_type,
"epoch": [idx_test_val_epoch] * lens,
"label_id": concat_label_id,
"z_index": concat_z_index,
"rotate_count": concat_rotate_count,
"label": concat_label.tolist(),
"pred": concat_pred.tolist()
})
idx_test_val_epoch_df_file = f"local_rank_{device_index}_test_type_val_dataset_epoch_{idx_test_val_epoch}_net_{net_id}.csv"
idx_test_val_epoch_df_file = os.path.join(save_dir, idx_test_val_epoch_df_file)
idx_test_val_epoch_df.to_csv(idx_test_val_epoch_df_file, index=False, encoding="utf-8")
logger.info(f"local rank: {device_index}, test_[val]_dataset epoch: {idx_test_val_epoch}, save to {idx_test_val_epoch_df_file}")
else:
val_all_label_id = val_all_label_id[:]
val_all_labels = torch.cat(val_all_labels, dim=0)
val_all_preds = torch.cat(val_all_preds, dim=0)
lens = len(val_all_label_id)
idx_test_val_epoch_df = pd.DataFrame({
"test_type": ["val_dataset"] * lens,
"local_rank": [device_index] * lens,
"net_id": [net_id] * lens,
"net_type": [f"{net_id}_net"] * lens,
"epoch": [idx_test_val_epoch] * lens,
"label_id": val_all_label_id,
"z_index": val_all_z_index[:],
"rotate_count": val_all_rotate_count[:],
"label": val_all_labels.tolist(),
"pred": val_all_preds.tolist()
})
idx_test_val_epoch_df_file = f"local_rank_{device_index}_test_type_val_dataset_epoch_{idx_test_val_epoch}_net_{net_id}.csv"
idx_test_val_epoch_df_file = os.path.join(save_dir, idx_test_val_epoch_df_file)
idx_test_val_epoch_df.to_csv(idx_test_val_epoch_df_file, index=False, encoding="utf-8")
logger.info(f"local rank: {device_index}, test_[val]_dataset epoch: {idx_test_val_epoch}, save to {idx_test_val_epoch_df_file}")
logger.info(f"local rank: {device_index}, test_[val]_dataset finished")
# Evaluate the test set
torch.cuda.empty_cache()
test_test_dataset_flag = test_test_dataset_info_dict["test_test_dataset_flag"]
if test_test_dataset_flag:
test_test_epoch_list = test_test_dataset_info_dict["test_test_epoch_list"]
test_test_epoch_list = sorted(test_test_epoch_list)
logger.info(f"local rank: {device_index}, start test_[test]_dataset")
for idx_test_test_epoch in test_test_epoch_list:
idx_epoch_pt_file = f"train_epoch_{idx_test_test_epoch}_net_{net_id}.pt"
idx_epoch_pt_file = os.path.join(save_dir, idx_epoch_pt_file)
model = copy.deepcopy(init_model)
model.load_state_dict(torch.load(idx_epoch_pt_file, map_location="cpu"))
model = model.to(current_device)
logger.info(f"local rank: {device_index}, test_[test]_dataset, start test_[test]_dataset epoch: {idx_test_test_epoch}")
logger.info(f"local rank: {device_index}, test_[test]_dataset, after load epoch {idx_test_test_epoch}\nparams:\n{get_model_params(model)}")
model.eval()
test_all_label_id = []
test_all_label_id_2d = []
test_all_label_id_3d = []
test_all_z_index = []
test_all_z_index_2d = []
test_all_z_index_3d = []
test_all_rotate_count = []
test_all_rotate_count_2d = []
test_all_rotate_count_3d = []
test_all_labels = []
test_all_labels_2d = []
test_all_labels_3d = []
test_all_preds = []
test_all_preds_2d = []
test_all_preds_3d = []
model.eval()
with torch.no_grad():
logger.info(f"local rank: {device_index}, test_[test]_dataset, start test_[test]_dataset epoch: {idx_test_test_epoch}, test_data_loader: {len(test_data_loader)}")
for step, batch_data in enumerate(test_data_loader):
if net_id == "2d3d":
label_id_2d, data_2d, label_2d, z_index_2d, rotate_count_2d, label_id_3d, data_3d, label_3d, z_index_3d, rotate_count_3d = batch_data
test_all_label_id_2d.extend(label_id_2d)
test_all_label_id_3d.extend(label_id_3d)
test_all_z_index_2d.extend(z_index_2d)
test_all_z_index_3d.extend(z_index_3d)
test_all_rotate_count_2d.extend(rotate_count_2d)
test_all_rotate_count_3d.extend(rotate_count_3d)
test_all_labels_2d.append(label_2d.cpu())
test_all_labels_3d.append(label_3d.cpu())
data_2d = data_2d.to(current_device)
data_3d = data_3d.to(current_device)
label_2d = label_2d.to(current_device)
label_3d = label_3d.to(current_device)
y_pred_2d, y_pred_3d = model(data_2d, data_3d)
test_all_preds_2d.append(y_pred_2d.detach().cpu())
test_all_preds_3d.append(y_pred_3d.detach().cpu())
else:
label_id, data, label, z_index, rotate_count = batch_data
test_all_label_id.extend(label_id)
test_all_z_index.extend(z_index)
test_all_rotate_count.extend(rotate_count)
test_all_labels.append(label.cpu())
data = data.to(current_device)
label = label.to(current_device)
y_pred = model(data)
test_all_preds.append(y_pred.detach().cpu())
logger.info(f"local rank: {device_index}, test_[test]_dataset, step: {step}, test_data_loader: {len(test_data_loader)}")
if net_id == "2d3d":
del model, data_2d, data_3d, label_2d, label_3d, y_pred_2d, y_pred_3d, z_index_2d, z_index_3d, rotate_count_2d, rotate_count_3d
else:
del model, data, label, y_pred, z_index, rotate_count
torch.cuda.empty_cache()
if net_id == "2d3d":
test_all_label_id_2d = test_all_label_id_2d[:]
test_all_label_id_3d = test_all_label_id_3d[:]
test_all_net_type_2d = ["2d_net"] * len(test_all_label_id_2d)
test_all_net_type_3d = ["3d_net"] * len(test_all_label_id_3d)
test_all_labels_2d = torch.cat(test_all_labels_2d, dim=0)
test_all_labels_3d = torch.cat(test_all_labels_3d, dim=0)
test_all_preds_2d = torch.cat(test_all_preds_2d, dim=0)
test_all_preds_3d = torch.cat(test_all_preds_3d, dim=0)
concat_label_id = test_all_label_id_2d + test_all_label_id_3d
concat_z_index = test_all_z_index_2d[:] + test_all_z_index_3d[:]
concat_rotate_count = test_all_rotate_count_2d[:] + test_all_rotate_count_3d[:]
concat_net_type = test_all_net_type_2d + test_all_net_type_3d
concat_label = torch.cat([test_all_labels_2d, test_all_labels_3d], dim=0)
concat_pred = torch.cat([test_all_preds_2d, test_all_preds_3d], dim=0)
logger.info(f"local rank: {device_index}, test_[test]_dataset, concat_label_id: {len(concat_label_id)}, concat_label: {concat_label.shape}, concat_pred: {concat_pred.shape}")
lens = len(concat_label_id)
idx_test_test_epoch_df = pd.DataFrame({
"test_type": ["test_dataset"] * lens,
"local_rank": [device_index] * lens,
"net_id": [net_id] * lens,
"net_type": concat_net_type,
"epoch": [idx_test_test_epoch] * lens,
"label_id": concat_label_id,
"z_index": concat_z_index,
"rotate_count": concat_rotate_count,
"label": concat_label.tolist(),
"pred": concat_pred.tolist()
})
idx_test_test_epoch_df_file = f"local_rank_{device_index}_test_type_test_dataset_epoch_{idx_test_test_epoch}_net_{net_id}.csv"
idx_test_test_epoch_df_file = os.path.join(save_dir, idx_test_test_epoch_df_file)
idx_test_test_epoch_df.to_csv(idx_test_test_epoch_df_file, index=False, encoding="utf-8")
logger.info(f"local rank: {device_index}, test_[test]_dataset epoch: {idx_test_test_epoch}, save to {idx_test_test_epoch_df_file}")
else:
test_all_label_id = test_all_label_id[:]
test_all_labels = torch.cat(test_all_labels, dim=0)
test_all_preds = torch.cat(test_all_preds, dim=0)
lens = len(test_all_label_id)
idx_test_test_epoch_df = pd.DataFrame({
"test_type": ["test_dataset"] * lens,
"local_rank": [device_index] * lens,
"net_id": [net_id] * lens,
"net_type": [f"{net_id}_net"] * lens,
"epoch": [idx_test_test_epoch] * lens,
"label_id": test_all_label_id,
"z_index": test_all_z_index[:],
"rotate_count": test_all_rotate_count[:],
"label": test_all_labels.tolist(),
"pred": test_all_preds.tolist()
})
idx_test_test_epoch_df_file = f"local_rank_{device_index}_test_type_test_dataset_epoch_{idx_test_test_epoch}_net_{net_id}.csv"
idx_test_test_epoch_df_file = os.path.join(save_dir, idx_test_test_epoch_df_file)
idx_test_test_epoch_df.to_csv(idx_test_test_epoch_df_file, index=False, encoding="utf-8")
logger.info(f"local rank: {device_index}, test_[test]_dataset epoch: {idx_test_test_epoch}, save to {idx_test_test_epoch_df_file}")
logger.info(f"local rank: {device_index}, test_[test]_dataset finished")
def start_test_ddp(model, config):
logger = config['logger']
net_id = config['net_id']
train_2d3d_data_2d_csv_file = config['train_2d3d_data_2d_csv_file']
val_2d3d_data_2d_csv_file = config['val_2d3d_data_2d_csv_file']
test_2d3d_data_2d_csv_file = config['test_2d3d_data_2d_csv_file']
train_2d3d_data_3d_csv_file = config['train_2d3d_data_3d_csv_file']
val_2d3d_data_3d_csv_file = config['val_2d3d_data_3d_csv_file']
test_2d3d_data_3d_csv_file = config['test_2d3d_data_3d_csv_file']
train_csv_file = config['train_csv_file']
val_csv_file = config['val_csv_file']
test_csv_file = config['test_csv_file']
device_index = config['device_index']
process_count = len(device_index)
save_dir = config['save_dir']
if net_id == "2d3d":
train_data_2d_df = pd.read_csv(train_2d3d_data_2d_csv_file, header=0, encoding="utf-8")
train_data_3d_df = pd.read_csv(train_2d3d_data_3d_csv_file, header=0, encoding="utf-8")
val_data_2d_df = pd.read_csv(val_2d3d_data_2d_csv_file, header=0, encoding="utf-8")
val_data_3d_df = pd.read_csv(val_2d3d_data_3d_csv_file, header=0, encoding="utf-8")
test_data_2d_df = pd.read_csv(test_2d3d_data_2d_csv_file, header=0, encoding="utf-8")
test_data_3d_df = pd.read_csv(test_2d3d_data_3d_csv_file, header=0, encoding="utf-8")
train_data_2d_file_list = [
os.path.join(save_dir, f"local_rank_{idx_process}_test_type_train_data_2d_dataset_net_{net_id}.csv")
for idx_process in range(process_count)
]
train_data_3d_file_list = [
os.path.join(save_dir, f"local_rank_{idx_process}_test_type_train_data_3d_dataset_net_{net_id}.csv")
for idx_process in range(process_count)
]
val_data_2d_file_list = [
os.path.join(save_dir, f"local_rank_{idx_process}_test_type_val_data_2d_dataset_net_{net_id}.csv")
for idx_process in range(process_count)
]
val_data_3d_file_list = [
os.path.join(save_dir, f"local_rank_{idx_process}_test_type_val_data_3d_dataset_net_{net_id}.csv")
for idx_process in range(process_count)
]
test_data_2d_file_list = [
os.path.join(save_dir, f"local_rank_{idx_process}_test_type_test_data_2d_dataset_net_{net_id}.csv")
for idx_process in range(process_count)
]
test_data_3d_file_list = [
os.path.join(save_dir, f"local_rank_{idx_process}_test_type_test_data_3d_dataset_net_{net_id}.csv")
for idx_process in range(process_count)
]
train_data_2d_df_list = np.array_split(train_data_2d_df, process_count)
val_data_2d_df_list = np.array_split(val_data_2d_df, process_count)
test_data_2d_df_list = np.array_split(test_data_2d_df, process_count)
train_data_3d_df_list = np.array_split(train_data_3d_df, process_count)
val_data_3d_df_list = np.array_split(val_data_3d_df, process_count)
test_data_3d_df_list = np.array_split(test_data_3d_df, process_count)
for idx_process in range(process_count):
train_data_2d_df_list[idx_process].to_csv(train_data_2d_file_list[idx_process], index=False, encoding="utf-8")
val_data_2d_df_list[idx_process].to_csv(val_data_2d_file_list[idx_process], index=False, encoding="utf-8")
test_data_2d_df_list[idx_process].to_csv(test_data_2d_file_list[idx_process], index=False, encoding="utf-8")
train_data_3d_df_list[idx_process].to_csv(train_data_3d_file_list[idx_process], index=False, encoding="utf-8")
val_data_3d_df_list[idx_process].to_csv(val_data_3d_file_list[idx_process], index=False, encoding="utf-8")
test_data_3d_df_list[idx_process].to_csv(test_data_3d_file_list[idx_process], index=False, encoding="utf-8")
else:
train_df = pd.read_csv(train_csv_file, header=0, encoding="utf-8")
val_df = pd.read_csv(val_csv_file, header=0, encoding="utf-8")
test_df = pd.read_csv(test_csv_file, header=0, encoding="utf-8")
train_data_file_list = [
os.path.join(save_dir, f"local_rank_{idx_process}_test_type_train_dataset_net_{net_id}.csv")
for idx_process in range(process_count)
]
val_data_file_list = [
os.path.join(save_dir, f"local_rank_{idx_process}_test_type_val_dataset_net_{net_id}.csv")
for idx_process in range(process_count)
]
test_data_file_list = [
os.path.join(save_dir, f"local_rank_{idx_process}_test_type_test_dataset_net_{net_id}.csv")
for idx_process in range(process_count)
]
train_df_list = np.array_split(train_df, process_count)
val_df_list = np.array_split(val_df, process_count)
test_df_list = np.array_split(test_df, process_count)
for idx_process in range(process_count):
train_df_list[idx_process].to_csv(train_data_file_list[idx_process], index=False, encoding="utf-8")
val_df_list[idx_process].to_csv(val_data_file_list[idx_process], index=False, encoding="utf-8")
test_df_list[idx_process].to_csv(test_data_file_list[idx_process], index=False, encoding="utf-8")
process_count = len(device_index)
process_list = []
if net_id == "2d3d":
train_data_file_list = [None] * process_count
val_data_file_list = [None] * process_count
test_data_file_list = [None] * process_count
else:
train_data_2d_file_list = [None] * process_count
val_data_2d_file_list = [None] * process_count
test_data_2d_file_list = [None] * process_count
train_data_3d_file_list = [None] * process_count
val_data_3d_file_list = [None] * process_count
test_data_3d_file_list = [None] * process_count
for idx in range(process_count):
idx_process = mp.Process(
target=test_on_single_gpu,
args=(
model,
config,
device_index[idx],
train_data_2d_file_list[idx],
val_data_2d_file_list[idx],
test_data_2d_file_list[idx],
train_data_3d_file_list[idx],
val_data_3d_file_list[idx],
test_data_3d_file_list[idx],
train_data_file_list[idx],
val_data_file_list[idx],
test_data_file_list[idx]
)
)
idx_process.start()
process_list.append(idx_process)
for idx_process in process_list:
idx_process.join()
logger.info("test_ddp finished")
def test_ddp(model, config):
try:
set_seed()
start_test_ddp(model, config)
except Exception as e:
print(f"error in test_ddp: {traceback.format_exc()}")
raise e
finally:
print("test_ddp finished")
def gather_test_result(config):
net_id = config['net_id']
node_id = config['node_id']
date_id = config['date_id']
save_dir = config['save_dir']
logger = config['logger']
local_rank_list = config['device_index']
test_train_all_epoch_file = None
test_val_all_epoch_file = None
test_test_all_epoch_file = None
def parse_label_id(df=None, label_id_column="label_id"):
if df is None:
return df
for idx in range(len(df)):
idx_label_id = df.loc[idx, label_id_column]
idx_node = idx_label_id.split("_")[0]
idx_label = idx_label_id.split("_")[1]
df.loc[idx, "label_id"] = f"{int(float(idx_node))}_{int(float(idx_label))}"
return df
# Gather training-set results
test_train_dataset_info_dict = config['test_train_dataset_info_dict']
test_train_dataset_flag = test_train_dataset_info_dict['test_train_dataset_flag']
if test_train_dataset_flag:
print(f"start gather test_[train]_dataset")
test_train_epoch_list = test_train_dataset_info_dict['test_train_epoch_list']
test_train_df_list = []
for idx_test_train_epoch in test_train_epoch_list:
for idx_local_rank in local_rank_list:
idx_test_train_epoch_df_file = f"local_rank_{idx_local_rank}_test_type_train_dataset_epoch_{idx_test_train_epoch}_net_{net_id}.csv"
idx_test_train_epoch_df_file = os.path.join(config['save_dir'], idx_test_train_epoch_df_file)
idx_df = pd.read_csv(idx_test_train_epoch_df_file)
print(f"load idx_epoch_test_[train]_file: {idx_test_train_epoch_df_file}")
test_train_df_list.append(idx_df)
test_train_df = pd.concat(test_train_df_list, ignore_index=True)
test_train_df = parse_label_id(test_train_df)
test_train_df = test_train_df.drop_duplicates(subset=["epoch", "net_type", "label_id", "z_index", "rotate_count"], keep="first")
test_train_df = test_train_df.reset_index(drop=True)
test_train_df = test_train_df.sort_values(by=["test_type", "local_rank", "epoch", "net_type", "label_id", "z_index", "rotate_count"], ascending=[True, True, True, True, True, True, True])
test_train_all_epoch_file = f"gather_test_[train]_dataset_net_{net_id}_node_{node_id}_date_{date_id}.csv"
test_train_all_epoch_file = os.path.join(save_dir, test_train_all_epoch_file)
test_train_df.to_csv(test_train_all_epoch_file, index=False, encoding="utf-8")
print(f"gather test_[train]_dataset finished, save to {test_train_all_epoch_file}")
# Gather validation-set results
test_val_dataset_info_dict = config['test_val_dataset_info_dict']
test_val_dataset_flag = test_val_dataset_info_dict['test_val_dataset_flag']
if test_val_dataset_flag:
print(f"start gather test_[val]_dataset")
test_val_epoch_list = test_val_dataset_info_dict['test_val_epoch_list']
test_val_df_list = []
for idx_test_val_epoch in test_val_epoch_list:
for idx_local_rank in local_rank_list:
idx_test_val_epoch_df_file = f"local_rank_{idx_local_rank}_test_type_val_dataset_epoch_{idx_test_val_epoch}_net_{net_id}.csv"
idx_test_val_epoch_df_file = os.path.join(config['save_dir'], idx_test_val_epoch_df_file)
idx_df = pd.read_csv(idx_test_val_epoch_df_file)
print(f"load idx_epoch_test_[val]_file: {idx_test_val_epoch_df_file}")
test_val_df_list.append(idx_df)
test_val_df = pd.concat(test_val_df_list, ignore_index=True)
test_val_df = parse_label_id(test_val_df)
test_val_df = test_val_df.drop_duplicates(subset=["epoch", "net_type", "label_id", "z_index", "rotate_count"], keep="first")
test_val_df = test_val_df.reset_index(drop=True)
test_val_df = test_val_df.sort_values(by=["test_type", "local_rank", "epoch", "net_type", "label_id", "z_index", "rotate_count"], ascending=[True, True, True, True, True, True, True])
test_val_all_epoch_file = f"gather_test_[val]_dataset_net_{net_id}_node_{node_id}_date_{date_id}.csv"
test_val_all_epoch_file = os.path.join(save_dir, test_val_all_epoch_file)
test_val_df.to_csv(test_val_all_epoch_file, index=False, encoding="utf-8")
print(f"gather test_[val]_dataset finished, save to {test_val_all_epoch_file}")
# Gather test-set results
test_test_dataset_info_dict = config['test_test_dataset_info_dict']
test_test_dataset_flag = test_test_dataset_info_dict['test_test_dataset_flag']
if test_test_dataset_flag:
print(f"start gather test_[test]_dataset")
test_test_epoch_list = test_test_dataset_info_dict['test_test_epoch_list']
test_test_df_list = []
for idx_test_test_epoch in test_test_epoch_list:
for idx_local_rank in local_rank_list:
idx_test_test_epoch_df_file = f"local_rank_{idx_local_rank}_test_type_test_dataset_epoch_{idx_test_test_epoch}_net_{net_id}.csv"
idx_test_test_epoch_df_file = os.path.join(config['save_dir'], idx_test_test_epoch_df_file)
idx_df = pd.read_csv(idx_test_test_epoch_df_file)
print(f"load idx_epoch_test_[test]_file: {idx_test_test_epoch_df_file}")
test_test_df_list.append(idx_df)
test_test_df = pd.concat(test_test_df_list, ignore_index=True)
test_test_df = parse_label_id(test_test_df)
test_test_df = test_test_df.drop_duplicates(subset=["epoch", "net_type", "label_id", "z_index", "rotate_count"], keep="first")
test_test_df = test_test_df.reset_index(drop=True)
test_test_df = test_test_df.sort_values(by=["test_type", "local_rank", "epoch", "net_type", "label_id", "z_index", "rotate_count"], ascending=[True, True, True, True, True, True, True])
test_test_all_epoch_file = f"gather_test_[test]_dataset_net_{net_id}_node_{node_id}_date_{date_id}.csv"
test_test_all_epoch_file = os.path.join(save_dir, test_test_all_epoch_file)
test_test_df.to_csv(test_test_all_epoch_file, index=False, encoding="utf-8")
print(f"gather test_[test]_dataset finished, save to {test_test_all_epoch_file}")
return test_train_all_epoch_file, test_val_all_epoch_file, test_test_all_epoch_file
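# parse_label_id above normalizes label_id values such as "2021.0_37.0" (which
# appear when ids round-trip through CSV as floats) back to "2021_37".  A tiny
# self-contained check of that behaviour; the ids below are made up.
def _sketch_parse_label_id_example():
    import pandas as pd
    df = pd.DataFrame({"label_id": ["2021.0_37.0", "2031_8"]})
    for idx in range(len(df)):
        node, label = df.loc[idx, "label_id"].split("_")[:2]
        df.loc[idx, "label_id"] = f"{int(float(node))}_{int(float(label))}"
    assert df["label_id"].tolist() == ["2021_37", "2031_8"]
    return df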
def get_gather_result_train_dataset(gather_all_epoch_file, config, test_type="train_dataset", positive_threshold_list=[], negative_threshold_list=[], get_epoch_list=[]):
'''
Filter the gathered results for samples whose predictions fall outside the given positive/negative thresholds (i.e. mispredicted data).
'''
net_id = config['net_id']
save_dir = config['save_dir']
train_csv_file = config['train_csv_file']
train_df = pd.read_csv(train_csv_file)
label_id_patient_id_dict = {}
for idx in range(len(train_df)):
idx_node = train_df.loc[idx, "node_time"]
idx_label = train_df.loc[idx, "label_id"]
idx_patient_id = train_df.loc[idx, "patient_id"]
idx_label_id = f"{idx_node}_{idx_label}"
label_id_patient_id_dict[idx_label_id] = idx_patient_id
df = pd.read_csv(gather_all_epoch_file)
df = df[df['test_type'] == test_type].reset_index(drop=True)
print(f"df: {len(df)}")
epoch_list = df['epoch'].unique().tolist() if get_epoch_list == [] else get_epoch_list
epoch_list = sorted(epoch_list)
df_epoch_list = []
df_threshold_list = []
epoch_net_id_list = []
epoch_net_type_list = []
epoch_label_id_list = []
epoch_z_index_list = []
epoch_roate_count_list = []
epoch_label_list = []
epoch_pred_list = []
for idx_epoch in epoch_list:
df_epoch = df[df['epoch'] == idx_epoch].reset_index(drop=True)
if net_id == "2d3d":
for idx_positive_threshold, idx_negative_threshold in zip(positive_threshold_list, negative_threshold_list):
idx_cp_df_epoch = copy.deepcopy(df_epoch)
idx_df_2d = idx_cp_df_epoch[idx_cp_df_epoch['net_type'] == "2d_net"].reset_index(drop=True)
idx_df_3d = idx_cp_df_epoch[idx_cp_df_epoch['net_type'] == "3d_net"].reset_index(drop=True)
print(f"idx_df_2d: {len(idx_df_2d)}, idx_df_3d: {len(idx_df_3d)}")
for idx in range(len(idx_df_2d)):
idx_preds_2d = idx_df_2d.loc[idx, "pred"]
idx_label_2d = idx_df_2d.loc[idx, "label"]
idx_label_id = idx_df_2d.loc[idx, "label_id"]
idx_z_index = idx_df_2d.loc[idx, "z_index"]
idx_rotate_count = idx_df_2d.loc[idx, "rotate_count"]
idx_net_type = idx_df_2d.loc[idx, "net_type"]
if idx_label_2d == 1 and idx_preds_2d < idx_positive_threshold:
df_epoch_list.append(idx_epoch)
df_threshold_list.append(idx_positive_threshold)
epoch_net_id_list.append(net_id)
epoch_net_type_list.append(idx_net_type)
epoch_label_id_list.append(idx_label_id)
epoch_z_index_list.append(idx_z_index)
epoch_roate_count_list.append(idx_rotate_count)
epoch_label_list.append(idx_label_2d)
epoch_pred_list.append(idx_preds_2d)
elif idx_label_2d == 0 and idx_preds_2d > idx_negative_threshold:
df_epoch_list.append(idx_epoch)
df_threshold_list.append(idx_negative_threshold)
epoch_net_id_list.append(net_id)
epoch_net_type_list.append(idx_net_type)
epoch_label_id_list.append(idx_label_id)
epoch_z_index_list.append(idx_z_index)
epoch_roate_count_list.append(idx_rotate_count)
epoch_label_list.append(idx_label_2d)
epoch_pred_list.append(idx_preds_2d)
for idx in range(len(idx_df_3d)):
idx_preds_3d = idx_df_3d.loc[idx, "pred"]
idx_label_3d = idx_df_3d.loc[idx, "label"]
idx_label_id = idx_df_3d.loc[idx, "label_id"]
idx_z_index = idx_df_3d.loc[idx, "z_index"]
idx_rotate_count = idx_df_3d.loc[idx, "rotate_count"]
idx_net_type = idx_df_3d.loc[idx, "net_type"]
if idx_label_3d == 1 and idx_preds_3d < idx_positive_threshold:
df_epoch_list.append(idx_epoch)
df_threshold_list.append(idx_positive_threshold)
epoch_net_id_list.append(net_id)
epoch_net_type_list.append(idx_net_type)
epoch_label_id_list.append(idx_label_id)
epoch_z_index_list.append(idx_z_index)
epoch_roate_count_list.append(idx_rotate_count)
epoch_label_list.append(idx_label_3d)
epoch_pred_list.append(idx_preds_3d)
elif idx_label_3d == 0 and idx_preds_3d > idx_negative_threshold:
df_epoch_list.append(idx_epoch)
df_threshold_list.append(idx_negative_threshold)
epoch_net_id_list.append(net_id)
epoch_net_type_list.append(idx_net_type)
epoch_label_id_list.append(idx_label_id)
epoch_z_index_list.append(idx_z_index)
epoch_roate_count_list.append(idx_rotate_count)
epoch_label_list.append(idx_label_3d)
epoch_pred_list.append(idx_preds_3d)
else:
for idx_positive_threshold, idx_negative_threshold in zip(positive_threshold_list, negative_threshold_list):
idx_cp_df_epoch = copy.deepcopy(df_epoch)
for idx in range(len(idx_cp_df_epoch)):
idx_preds = idx_cp_df_epoch.loc[idx, "pred"]
idx_label = idx_cp_df_epoch.loc[idx, "label"]
idx_label_id = idx_cp_df_epoch.loc[idx, "label_id"]
idx_z_index = idx_cp_df_epoch.loc[idx, "z_index"]
idx_rotate_count = idx_cp_df_epoch.loc[idx, "rotate_count"]
idx_net_type = idx_cp_df_epoch.loc[idx, "net_type"]
if idx_label == 1 and idx_preds < idx_positive_threshold:
df_epoch_list.append(idx_epoch)
df_threshold_list.append(idx_positive_threshold)
epoch_net_id_list.append(net_id)
epoch_net_type_list.append(idx_net_type)
epoch_label_id_list.append(idx_label_id)
epoch_z_index_list.append(idx_z_index)
epoch_roate_count_list.append(idx_rotate_count)
epoch_label_list.append(idx_label)
epoch_pred_list.append(idx_preds)
elif idx_label == 0 and idx_preds > idx_negative_threshold:
df_epoch_list.append(idx_epoch)
df_threshold_list.append(idx_negative_threshold)
epoch_net_id_list.append(net_id)
epoch_net_type_list.append(idx_net_type)
epoch_label_id_list.append(idx_label_id)
epoch_z_index_list.append(idx_z_index)
epoch_roate_count_list.append(idx_rotate_count)
epoch_label_list.append(idx_label)
epoch_pred_list.append(idx_preds)
epoch_error_df = pd.DataFrame({
"epoch": df_epoch_list,
"threshold": df_threshold_list,
"net_id": epoch_net_id_list,
"net_type": epoch_net_type_list,
"label_id": epoch_label_id_list,
"z_index": epoch_z_index_list,
"rotate_count": epoch_roate_count_list,
"patient_id": [label_id_patient_id_dict[idx_label_id] for idx_label_id in epoch_label_id_list],
"label": epoch_label_list,
"pred": epoch_pred_list,
})
epoch_error_file = f"{gather_all_epoch_file.replace('.csv', '_epoch_预测不在阈值区间的数据.csv')}"
epoch_error_file = os.path.join(save_dir, epoch_error_file)
epoch_error_df.to_csv(epoch_error_file, index=False, encoding="utf-8")
print(f"epoch_预测不在阈值区间的数据, save epoch_error_df to {epoch_error_file}")
def get_gather_result(gather_all_epoch_file, config, test_type="test_dataset", positive_threshold_list=[], step=0.01):
net_id = config['net_id']
save_dir = config['save_dir']
if net_id == "2d3d":
if test_type == "train_dataset":
origin_file_data_2d = config['train_2d3d_data_2d_csv_file']
origin_file_data_3d = config['train_2d3d_data_3d_csv_file']
print(f"origin_file_data_2d: {origin_file_data_2d}\norigin_file_data_3d: {origin_file_data_3d}")
elif test_type == "test_dataset":
origin_file_data_2d = config['test_2d3d_data_2d_csv_file']
origin_file_data_3d = config['test_2d3d_data_3d_csv_file']
else:
origin_file_data_2d = config['val_2d3d_data_2d_csv_file']
origin_file_data_3d = config['val_2d3d_data_3d_csv_file']
origin_file_data_2d = os.path.join(save_dir, origin_file_data_2d)
origin_file_data_3d = os.path.join(save_dir, origin_file_data_3d)
origin_df_data_2d = pd.read_csv(origin_file_data_2d)
origin_df_data_3d = pd.read_csv(origin_file_data_3d)
origin_df = pd.concat([origin_df_data_2d, origin_df_data_3d], ignore_index=True)
else:
if test_type == "train_dataset":
origin_file = config["train_csv_file"]
elif test_type == "test_dataset":
origin_file = config["test_csv_file"]
else:
origin_file = config["val_csv_file"]
origin_file = os.path.join(save_dir, origin_file)
origin_df = pd.read_csv(origin_file)
label_id_patient_id_dict = {}
for idx in range(len(origin_df)):
idx_node = origin_df.loc[idx, "node"]
idx_label_id = origin_df.loc[idx, "label_id"]
idx_patient_id = origin_df.loc[idx, "patient_id"]
idx_label_id = f"{int(float(idx_node))}_{int(float(idx_label_id))}"
label_id_patient_id_dict[idx_label_id] = idx_patient_id
df = pd.read_csv(gather_all_epoch_file)
# Aggregate the evaluation results for the selected test_type
df = df[df['test_type'] == test_type].reset_index(drop=True)
lens = len(df)
# Drop duplicate rows
df = df.drop_duplicates(subset=['test_type', 'net_id', 'net_type', 'epoch', 'label_id', 'label', 'z_index', 'rotate_count'], keep="first")
print(f"去除重复的: {gather_all_epoch_file}\n{lens} --> {len(df)}")
# # 2d_net, average the predictions
# aggregated_df = df.groupby(['test_type', 'local_rank', 'net_id', 'net_type', 'epoch', 'label_id', 'label'])['pred'].mean().reset_index()
# df = copy.deepcopy(aggregated_df)
# avg_2d_df_file = f"{gather_all_epoch_file.replace('.csv', '_2d_net_avg.csv')}"
# avg_2d_df_file = os.path.join(save_dir, avg_2d_df_file)
# df.to_csv(avg_2d_df_file, index=False, encoding="utf-8")
# print(f"save 2d_net_avg to {avg_2d_df_file}")
if positive_threshold_list == []:
positive_threshold_list = [round(i, 2) for i in np.arange(0.1, 1.0, step)]
print(f"positive_threshold_list: {len(positive_threshold_list)}, {positive_threshold_list}")
epoch_list = df['epoch'].unique().tolist()
epoch_list = sorted(epoch_list)
print(f"epoch_list: {len(epoch_list)}, {epoch_list}")
df_epoch_list = []
df_threshold_list = []
epoch_accuracy_list = []
epoch_precision_list = []
epoch_recall_list = []
epoch_f1_list = []
epoch_error_list = []
epoch_positive_accuracy_list = []
epoch_negative_accuracy_list = []
epoch_positive_precision_list = []
epoch_positive_recall_list = []
epoch_positive_f1_list = []
epoch_negative_precision_list = []
epoch_negative_recall_list = []
epoch_negative_f1_list = []
epoch_accuracy_2d_list = []
epoch_precision_2d_list = []
epoch_recall_2d_list = []
epoch_f1_2d_list = []
epoch_error_2d_list = []
epoch_positive_accuracy_2d_list = []
epoch_negative_accuracy_2d_list = []
epoch_positive_precision_2d_list = []
epoch_positive_recall_2d_list = []
epoch_positive_f1_2d_list = []
epoch_negative_precision_2d_list = []
epoch_negative_recall_2d_list = []
epoch_negative_f1_2d_list = []
epoch_accuracy_3d_list = []
epoch_precision_3d_list = []
epoch_recall_3d_list = []
epoch_f1_3d_list = []
epoch_error_3d_list = []
epoch_positive_accuracy_3d_list = []
epoch_negative_accuracy_3d_list = []
epoch_positive_precision_3d_list = []
epoch_positive_recall_3d_list = []
epoch_positive_f1_3d_list = []
epoch_negative_precision_3d_list = []
epoch_negative_recall_3d_list = []
epoch_negative_f1_3d_list = []
for idx_epoch in epoch_list:
df_epoch = df[df['epoch'] == idx_epoch].reset_index(drop=True)
if net_id == "2d3d":
df_2d = df_epoch[df_epoch['net_type'] == "2d_net"].reset_index(drop=True)
df_3d = df_epoch[df_epoch['net_type'] == "3d_net"].reset_index(drop=True)
print(f"df_2d: {len(df_2d)}, df_3d: {len(df_3d)}")
epoch_df_2d_preds = torch.tensor(df_2d['pred'].tolist(), dtype=torch.float32)
epoch_df_2d_labels = torch.tensor(df_2d['label'].tolist())
epoch_df_3d_preds = torch.tensor(df_3d['pred'].tolist(), dtype=torch.float32)
epoch_df_3d_labels = torch.tensor(df_3d['label'].tolist())
epoch_df_2d_label_id = df_2d['label_id'].tolist()
epoch_df_3d_label_id = df_3d['label_id'].tolist()
epoch_df_2d_z_index = df_2d['z_index'].tolist()
epoch_df_3d_z_index = df_3d['z_index'].tolist()
epoch_df_2d_rotate_count = df_2d['rotate_count'].tolist()
epoch_df_3d_rotate_count = df_3d['rotate_count'].tolist()
for idx_positive_threshold in positive_threshold_list:
idx_epoch_df_2d_preds = copy.deepcopy(epoch_df_2d_preds)
idx_epoch_df_3d_preds = copy.deepcopy(epoch_df_3d_preds)
idx_epoch_df_2d_labels = copy.deepcopy(epoch_df_2d_labels)
idx_epoch_df_3d_labels = copy.deepcopy(epoch_df_3d_labels)
# Convert predicted probabilities to binary labels using the threshold
idx_epoch_df_2d_preds_binary = (idx_epoch_df_2d_preds >= idx_positive_threshold).to(torch.int)
idx_epoch_df_3d_preds_binary = (idx_epoch_df_3d_preds >= idx_positive_threshold).to(torch.int)
# Convert tensors to numpy arrays for metric computation
idx_epoch_df_2d_preds_binary = idx_epoch_df_2d_preds_binary.numpy()
idx_epoch_df_3d_preds_binary = idx_epoch_df_3d_preds_binary.numpy()
idx_epoch_df_2d_labels = idx_epoch_df_2d_labels.numpy()
idx_epoch_df_3d_labels = idx_epoch_df_3d_labels.numpy()
# Overall metrics (accuracy, precision, recall, F1)
acc_2d = accuracy_score(idx_epoch_df_2d_labels, idx_epoch_df_2d_preds_binary)
acc_3d = accuracy_score(idx_epoch_df_3d_labels, idx_epoch_df_3d_preds_binary)
prec_2d = precision_score(idx_epoch_df_2d_labels, idx_epoch_df_2d_preds_binary)
rec_2d = recall_score(idx_epoch_df_2d_labels, idx_epoch_df_2d_preds_binary)
f1_2d = f1_score(idx_epoch_df_2d_labels, idx_epoch_df_2d_preds_binary)
prec_3d = precision_score(idx_epoch_df_3d_labels, idx_epoch_df_3d_preds_binary)
rec_3d = recall_score(idx_epoch_df_3d_labels, idx_epoch_df_3d_preds_binary)
f1_3d = f1_score(idx_epoch_df_3d_labels, idx_epoch_df_3d_preds_binary)
# Confusion matrix
c_2d = confusion_matrix(idx_epoch_df_2d_labels, idx_epoch_df_2d_preds_binary)
c_3d = confusion_matrix(idx_epoch_df_3d_labels, idx_epoch_df_3d_preds_binary)
tp_2d, fp_2d, fn_2d, tn_2d = c_2d[1, 1], c_2d[0, 1], c_2d[1, 0], c_2d[0, 0]
tp_3d, fp_3d, fn_3d, tn_3d = c_3d[1, 1], c_3d[0, 1], c_3d[1, 0], c_3d[0, 0]
# Positive-class metrics
acc_pos_2d = (tp_2d) / (tp_2d + fn_2d) if (tp_2d + fn_2d) > 0 else 0
prec_pos_2d = tp_2d / (tp_2d + fp_2d) if (tp_2d + fp_2d) > 0 else 0
rec_pos_2d = tp_2d / (tp_2d + fn_2d) if (tp_2d + fn_2d) > 0 else 0
f1_pos_2d = 2 * (prec_pos_2d * rec_pos_2d) / (prec_pos_2d + rec_pos_2d) if (prec_pos_2d + rec_pos_2d) > 0 else 0
acc_pos_3d = (tp_3d) / (tp_3d + fn_3d) if (tp_3d + fn_3d) > 0 else 0
prec_pos_3d = tp_3d / (tp_3d + fp_3d) if (tp_3d + fp_3d) > 0 else 0
rec_pos_3d = tp_3d / (tp_3d + fn_3d) if (tp_3d + fn_3d) > 0 else 0
f1_pos_3d = 2 * (prec_pos_3d * rec_pos_3d) / (prec_pos_3d + rec_pos_3d) if (prec_pos_3d + rec_pos_3d) > 0 else 0
# Negative-class metrics
acc_neg_2d = (tn_2d) / (tn_2d + fp_2d) if (tn_2d + fp_2d) > 0 else 0
prec_neg_2d = tn_2d / (tn_2d + fn_2d) if (tn_2d + fn_2d) > 0 else 0
rec_neg_2d = tn_2d / (tn_2d + fp_2d) if (tn_2d + fp_2d) > 0 else 0
f1_neg_2d = 2 * (prec_neg_2d * rec_neg_2d) / (prec_neg_2d + rec_neg_2d) if (prec_neg_2d + rec_neg_2d) > 0 else 0
acc_neg_3d = (tn_3d) / (tn_3d + fp_3d) if (tn_3d + fp_3d) > 0 else 0
prec_neg_3d = tn_3d / (tn_3d + fn_3d) if (tn_3d + fn_3d) > 0 else 0
rec_neg_3d = tn_3d / (tn_3d + fp_3d) if (tn_3d + fp_3d) > 0 else 0
f1_neg_3d = 2 * (prec_neg_3d * rec_neg_3d) / (prec_neg_3d + rec_neg_3d) if (prec_neg_3d + rec_neg_3d) > 0 else 0
# Record mispredicted samples
error_2d = [(label_id, z_index, rotate_count, label, pred_prob.numpy(), label_id_patient_id_dict[label_id]) for label_id, z_index, rotate_count, label, pred_prob, pred_binary in zip(epoch_df_2d_label_id, epoch_df_2d_z_index, epoch_df_2d_rotate_count, idx_epoch_df_2d_labels, idx_epoch_df_2d_preds, idx_epoch_df_2d_preds_binary) if label != pred_binary]
error_3d = [(label_id, z_index, rotate_count, label, pred_prob.numpy(), label_id_patient_id_dict[label_id]) for label_id, z_index, rotate_count, label, pred_prob, pred_binary in zip(epoch_df_3d_label_id, epoch_df_3d_z_index, epoch_df_3d_rotate_count, idx_epoch_df_3d_labels, idx_epoch_df_3d_preds, idx_epoch_df_3d_preds_binary) if label != pred_binary]
df_epoch_list.append(idx_epoch)
df_threshold_list.append(idx_positive_threshold)
epoch_accuracy_2d_list.append(acc_2d)
epoch_precision_2d_list.append(prec_2d)
epoch_recall_2d_list.append(rec_2d)
epoch_f1_2d_list.append(f1_2d)
epoch_error_2d_list.append(error_2d)
epoch_accuracy_3d_list.append(acc_3d)
epoch_precision_3d_list.append(prec_3d)
epoch_recall_3d_list.append(rec_3d)
epoch_f1_3d_list.append(f1_3d)
epoch_error_3d_list.append(error_3d)
epoch_positive_accuracy_2d_list.append(acc_pos_2d)
epoch_negative_accuracy_2d_list.append(acc_neg_2d)
epoch_positive_accuracy_3d_list.append(acc_pos_3d)
epoch_negative_accuracy_3d_list.append(acc_neg_3d)
epoch_positive_precision_2d_list.append(prec_pos_2d)
epoch_negative_precision_2d_list.append(prec_neg_2d)
epoch_positive_precision_3d_list.append(prec_pos_3d)
epoch_negative_precision_3d_list.append(prec_neg_3d)
epoch_positive_recall_2d_list.append(rec_pos_2d)
epoch_negative_recall_2d_list.append(rec_neg_2d)
epoch_positive_recall_3d_list.append(rec_pos_3d)
epoch_negative_recall_3d_list.append(rec_neg_3d)
epoch_positive_f1_2d_list.append(f1_pos_2d)
epoch_negative_f1_2d_list.append(f1_neg_2d)
epoch_positive_f1_3d_list.append(f1_pos_3d)
epoch_negative_f1_3d_list.append(f1_neg_3d)
else:
epoch_df_preds = torch.tensor(df_epoch['pred'].tolist(), dtype=torch.float32)
epoch_df_labels = torch.tensor(df_epoch['label'].tolist())
epoch_df_label_id = df_epoch['label_id'].tolist()
epoch_df_z_index = df_epoch['z_index'].tolist()
epoch_df_rotate_count = df_epoch['rotate_count'].tolist()
for idx_positive_threshold in positive_threshold_list:
idx_epoch_df_preds = copy.deepcopy(epoch_df_preds)
idx_epoch_df_labels = copy.deepcopy(epoch_df_labels)
# Convert predicted probabilities to binary labels using the threshold
idx_epoch_df_preds_binary = (idx_epoch_df_preds >= idx_positive_threshold).to(torch.int)
# Convert tensors to numpy arrays for metric computation
idx_epoch_df_preds_binary = idx_epoch_df_preds_binary.numpy()
idx_epoch_df_labels = idx_epoch_df_labels.numpy()
# Overall metrics (accuracy, precision, recall, F1)
acc = accuracy_score(idx_epoch_df_labels, idx_epoch_df_preds_binary)
prec = precision_score(idx_epoch_df_labels, idx_epoch_df_preds_binary)
rec = recall_score(idx_epoch_df_labels, idx_epoch_df_preds_binary)
f1 = f1_score(idx_epoch_df_labels, idx_epoch_df_preds_binary)
# Confusion matrix
c = confusion_matrix(idx_epoch_df_labels, idx_epoch_df_preds_binary)
tp, fp, fn, tn = c[1, 1], c[0, 1], c[1, 0], c[0, 0]
# Positive-class metrics
acc_pos = (tp) / (tp + fn) if (tp + fn) > 0 else 0
prec_pos = tp / (tp + fp) if (tp + fp) > 0 else 0
rec_pos = tp / (tp + fn) if (tp + fn) > 0 else 0
f1_pos = 2 * (prec_pos * rec_pos) / (prec_pos + rec_pos) if (prec_pos + rec_pos) > 0 else 0
# Negative-class metrics
acc_neg = (tn) / (tn + fp) if (tn + fp) > 0 else 0
prec_neg = tn / (tn + fn) if (tn + fn) > 0 else 0
rec_neg = tn / (tn + fp) if (tn + fp) > 0 else 0
f1_neg = 2 * (prec_neg * rec_neg) / (prec_neg + rec_neg) if (prec_neg + rec_neg) > 0 else 0
# Record mispredicted samples
error = [(label_id, z_index, rotate_count, label, pred_prob.numpy(), label_id_patient_id_dict[label_id]) for label_id, z_index, rotate_count, label, pred_prob, pred_binary in zip(epoch_df_label_id, epoch_df_z_index, epoch_df_rotate_count, idx_epoch_df_labels, idx_epoch_df_preds, idx_epoch_df_preds_binary) if label != pred_binary]
df_epoch_list.append(idx_epoch)
df_threshold_list.append(idx_positive_threshold)
epoch_accuracy_list.append(acc)
epoch_precision_list.append(prec)
epoch_recall_list.append(rec)
epoch_f1_list.append(f1)
epoch_error_list.append(error)
epoch_positive_accuracy_list.append(acc_pos)
epoch_negative_accuracy_list.append(acc_neg)
epoch_positive_precision_list.append(prec_pos)
epoch_negative_precision_list.append(prec_neg)
epoch_positive_recall_list.append(rec_pos)
epoch_negative_recall_list.append(rec_neg)
epoch_positive_f1_list.append(f1_pos)
epoch_negative_f1_list.append(f1_neg)
if net_id == "2d3d":
df_epoch_result = pd.DataFrame({
"test_type": [test_type] * len(df_epoch_list),
"epoch": df_epoch_list,
"threshold": df_threshold_list,
"acc_2d": epoch_accuracy_2d_list,
"prec_2d": epoch_precision_2d_list,
"rec_2d": epoch_recall_2d_list,
"f1_2d": epoch_f1_2d_list,
"acc_3d": epoch_accuracy_3d_list,
"prec_3d": epoch_precision_3d_list,
"rec_3d": epoch_recall_3d_list,
"f1_3d": epoch_f1_3d_list,
"acc_pos_2d": epoch_positive_accuracy_2d_list,
"prec_pos_2d": epoch_positive_precision_2d_list,
"rec_pos_2d": epoch_positive_recall_2d_list,
"f1_pos_2d": epoch_positive_f1_2d_list,
"acc_neg_2d": epoch_negative_accuracy_2d_list,
"prec_neg_2d": epoch_negative_precision_2d_list,
"rec_neg_2d": epoch_negative_recall_2d_list,
"f1_neg_2d": epoch_negative_f1_2d_list,
"acc_pos_3d": epoch_positive_accuracy_3d_list,
"prec_pos_3d": epoch_positive_precision_3d_list,
"rec_pos_3d": epoch_positive_recall_3d_list,
"f1_pos_3d": epoch_positive_f1_3d_list,
"acc_neg_3d": epoch_negative_accuracy_3d_list,
"prec_neg_3d": epoch_negative_precision_3d_list,
"rec_neg_3d": epoch_negative_recall_3d_list,
"f1_neg_3d": epoch_negative_f1_3d_list,
"error_data_2d": epoch_error_2d_list,
"error_data_3d": epoch_error_3d_list,
})
else:
df_epoch_result = pd.DataFrame({
"test_type": [test_type] * len(df_epoch_list),
"epoch": df_epoch_list,
"threshold": df_threshold_list,
"acc": epoch_accuracy_list,
"prec": epoch_precision_list,
"rec": epoch_recall_list,
"f1": epoch_f1_list,
"acc_pos": epoch_positive_accuracy_list,
"prec_pos": epoch_positive_precision_list,
"rec_pos": epoch_positive_recall_list,
"f1_pos": epoch_positive_f1_list,
"acc_neg": epoch_negative_accuracy_list,
"prec_neg": epoch_negative_precision_list,
"rec_neg": epoch_negative_recall_list,
"f1_neg": epoch_negative_f1_list,
"error_data": epoch_error_list,
})
df_epoch_result_file = f"{gather_all_epoch_file.replace('.csv', '_不同阈值_评测结果.csv')}"
df_epoch_result_file = os.path.join(save_dir, df_epoch_result_file)
df_epoch_result.to_csv(df_epoch_result_file, index=False, encoding="utf-8")
print(f"gather result finished, 不同阈值, save to {df_epoch_result_file}")
def get_test_config():
from pytorch_train.train_2d3d_config import node_net_train_file_dict
config = {}
net_id = "2d3d"
epochs = 1
node_id = "2021_2031"
device_index = [0]
print(f"node_id: {node_id}")
date_id = "20241217"
task_info = f"test_{net_id}_{node_id}_{date_id}_rotate_10_ddp_count_00000005"
train_log_dir = "/df_lung/ai-project/cls_train/log/train"
save_dir = "/df_lung/ai-project/cls_train/cls_ckpt"
save_folder = f"{net_id}_{node_id}_{date_id}"
log_file = os.path.join(train_log_dir, f"{task_info}.log")
save_dir = os.path.join(save_dir, save_folder)
train_data_dir = "/df_lung/cls_train_data/train_csv_data/"
train_csv_file = None
val_csv_file = None
test_csv_file = None
train_2d3d_data_2d_csv_file = None
val_2d3d_data_2d_csv_file = None
test_2d3d_data_2d_csv_file = None
train_2d3d_data_3d_csv_file = None
val_2d3d_data_3d_csv_file = None
test_2d3d_data_3d_csv_file = None
if net_id == "2d3d":
train_2d3d_data_2d_csv_file = node_net_train_file_dict[(node_id, net_id)]["2d_train_file"]
val_2d3d_data_2d_csv_file = node_net_train_file_dict[(node_id, net_id)]["2d_val_file"]
test_2d3d_data_2d_csv_file = node_net_train_file_dict[(node_id, net_id)]["2d_test_file"]
train_2d3d_data_3d_csv_file = node_net_train_file_dict[(node_id, net_id)]["3d_train_file"]
val_2d3d_data_3d_csv_file = node_net_train_file_dict[(node_id, net_id)]["3d_val_file"]
test_2d3d_data_3d_csv_file = node_net_train_file_dict[(node_id, net_id)]["3d_test_file"]
else:
train_csv_file = node_net_train_file_dict[(node_id, net_id)]["train_file"]
val_csv_file = node_net_train_file_dict[(node_id, net_id)]["val_file"]
test_csv_file = node_net_train_file_dict[(node_id, net_id)]["test_file"]
logger = get_logger(log_file)
if net_id == "2d3d":
model = Net2d3d()
batch_size = 32
elif net_id == "2d":
model = Net2d()
batch_size = 200
elif net_id == "3d":
model = Net3d()
batch_size = 6
elif net_id == "s3d":
model = NetS3d()
batch_size = 6
elif net_id == "d2d":
model = NetD2d()
batch_size = 1
if net_id not in ["s3d", "d2d"]:
init_modules(model)
# Create the output directories
Path(train_log_dir).mkdir(parents=True, exist_ok=True)
Path(save_dir).mkdir(parents=True, exist_ok=True)
config['net_id'] = net_id
config['node_id'] = node_id
config['date_id'] = date_id
config['logger'] = logger
config['criterion'] = nn.BCELoss()
config['device_index'] = device_index
config['num_workers'] = 1
config['train_batch_size'] = batch_size
config['val_batch_size'] = batch_size
config['test_batch_size'] = batch_size
config['save_dir'] = save_dir
config['train_2d3d_data_2d_csv_file'] = train_2d3d_data_2d_csv_file
config['val_2d3d_data_2d_csv_file'] = val_2d3d_data_2d_csv_file
config['test_2d3d_data_2d_csv_file'] = test_2d3d_data_2d_csv_file
config['train_2d3d_data_3d_csv_file'] = train_2d3d_data_3d_csv_file
config['val_2d3d_data_3d_csv_file'] = val_2d3d_data_3d_csv_file
config['test_2d3d_data_3d_csv_file'] = test_2d3d_data_3d_csv_file
config['train_csv_file'] = train_csv_file
config['val_csv_file'] = val_csv_file
config['test_csv_file'] = test_csv_file
config['num_epochs'] = epochs
config['test_train_dataset_info_dict'] = {
"test_train_dataset_flag": True,
"test_train_epoch_list": [
1
],
"positive_threshold_list": [0.7],
"negative_threshold_list": [0.3]
}
config['test_val_dataset_info_dict'] = {
"test_val_dataset_flag": True,
"test_val_epoch_list": [
idx_test_val_epoch+1 for idx_test_val_epoch in range(config['num_epochs'])
]
}
config['test_test_dataset_info_dict'] = {
"test_test_dataset_flag": True,
"test_test_epoch_list": [
idx_test_test_epoch+1 for idx_test_test_epoch in range(config['num_epochs'])
]
}
# train_data_loader = DataLoader(train_dataset, batch_size=config['train_batch_size'],
# num_workers=config['num_workers'], shuffle=True)
# val_data_loader = DataLoader(val_dataset, batch_size=config['val_batch_size'],
# num_workers=config['num_workers'], shuffle=True)
return model, config
# if __name__ == "__main__":
# try:
# model, config = get_test_config()
# test_ddp(model, config)
# except Exception as e:
# print(f"error info: {traceback.format_exc()}")
# raise e
# finally:
# print("多卡评测 finished")
# python pytorch_train/_test_2d3d.py
if __name__ == "__main__":
_, config = get_test_config()
# Evaluation results for all epochs
config['test_train_dataset_info_dict']['test_train_dataset_flag'] = False
config['test_val_dataset_info_dict']['test_val_dataset_flag'] = True
config['test_test_dataset_info_dict']['test_test_dataset_flag'] = True
# test_train_all_epoch_file, test_val_all_epoch_file, test_test_all_epoch_file = gather_test_result(config)
# print(f"test_train_all_epoch_file: {test_train_all_epoch_file}\ntest_val_all_epoch_file: {test_val_all_epoch_file}\ntest_test_all_epoch_file: {test_test_all_epoch_file}")
# # # # /df_lung/ai-project/cls_train/cls_ckpt/2d3d_2021_2031_20241124/gather_test_[train]_dataset_net_2d3d_node_2021_2031_date_20241124.csv
# # # # /df_lung/ai-project/cls_train/cls_ckpt/2d3d_2021_2031_20241124/gather_test_[val]_dataset_net_2d3d_node_2021_2031_date_20241124.csv
# # # # /df_lung/ai-project/cls_train/cls_ckpt/2d3d_2021_2031_20241124/gather_test_[test]_dataset_net_2d3d_node_2021_2031_date_20241124.csv
# # # # /df_lung/ai-project/cls_train/cls_ckpt/3d_2021_2031_20241130/gather_test_[val]_dataset_net_3d_node_2021_2031_date_20241130.csv
# # # # /df_lung/ai-project/cls_train/cls_ckpt/3d_2021_2031_20241130/gather_test_[test]_dataset_net_3d_node_2021_2031_date_20241130.csv
# # # # /df_lung/ai-project/cls_train/cls_ckpt/encoder3d_2021_2031_20241131/gather_test_[val]_dataset_net_encoder3d_node_2021_2031_date_20241131.csv
# # # # /df_lung/ai-project/cls_train/cls_ckpt/encoder3d_2021_2031_20241131/gather_test_[test]_dataset_net_encoder3d_node_2021_2031_date_20241131.csv
# # # # /df_lung/ai-project/cls_train/cls_ckpt/2d_2021_2031_20241201/gather_test_[val]_dataset_net_2d_node_2021_2031_date_20241201.csv
# # # # /df_lung/ai-project/cls_train/cls_ckpt/2d_2021_2031_20241201/gather_test_[test]_dataset_net_2d_node_2021_2031_date_20241201.csv
# # # # /df_lung/ai-project/cls_train/cls_ckpt/3d_2041_2031_20241201/gather_test_[val]_dataset_net_3d_node_2041_2031_date_20241201.csv
# # # # /df_lung/ai-project/cls_train/cls_ckpt/3d_2041_2031_20241201/gather_test_[test]_dataset_net_3d_node_2041_2031_date_20241201.csv
# Aggregate the evaluation results
test_train_all_epoch_file = "/df_lung/ai-project/cls_train/cls_ckpt/2d3d_2021_2031_20241217/gather_test_[train]_dataset_net_2d3d_node_2021_2031_date_20241217.csv"
test_val_all_epoch_file = "/df_lung/ai-project/cls_train/cls_ckpt/2d3d_2021_2031_20241217/gather_test_[val]_dataset_net_2d3d_node_2021_2031_date_20241217.csv"
test_test_all_epoch_file = "/df_lung/ai-project/cls_train/cls_ckpt/2d3d_2021_2031_20241217/gather_test_[test]_dataset_net_2d3d_node_2021_2031_date_20241217.csv"
# # # test_train_all_epoch_file = None
# # # test_val_all_epoch_file = "/df_lung/ai-project/cls_train/cls_ckpt/3d_2041_2031_20241201/gather_test_[val]_dataset_net_3d_node_2041_2031_date_20241201.csv"
# # # test_test_all_epoch_file = "/df_lung/ai-project/cls_train/cls_ckpt/3d_2041_2031_20241201/gather_test_[test]_dataset_net_3d_node_2041_2031_date_20241201.csv"
# # # positive_threshold_list = config['test_train_dataset_info_dict']['positive_threshold_list']
# # # negative_threshold_list = config['test_train_dataset_info_dict']['negative_threshold_list']
# # # assert len(positive_threshold_list) == len(negative_threshold_list)
# test_test_all_epoch_file = "/df_lung/ai-project/cls_train/cls_ckpt/2d3d_2021_2031_20241207/gather_test_[test]_dataset_net_2d3d_node_2021_2031_date_20241207.csv"
# # # # Inspect out-of-threshold samples for the specified epochs
# # # get_gather_result_train_dataset(
# # # test_train_all_epoch_file,
# # # config,
# # # test_type="train_dataset",
# # # positive_threshold_list=positive_threshold_list,
# # # negative_threshold_list=negative_threshold_list
# # # )
# Evaluation metrics at different thresholds
get_gather_result(test_train_all_epoch_file, config, test_type="train_dataset")
get_gather_result(test_val_all_epoch_file, config, test_type="val_dataset")
get_gather_result(test_test_all_epoch_file, config, test_type="test_dataset")
import torch
import torch.nn as nn
import torch.nn.functional as F
def init_modules(model):
for name, param in model.named_parameters():
if 'weight' in name:
if isinstance(param, torch.nn.Parameter):
nn.init.kaiming_normal_(param.data, mode='fan_out', nonlinearity='relu')
elif 'bias' in name:
if isinstance(param, torch.nn.Parameter):
nn.init.constant_(param.data, 0.04)
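# Note: nn.init.kaiming_normal_ requires tensors with at least 2 dimensions,
# so init_modules above would raise on 1-D weights (e.g. normalization-layer
# weights) if the model contained any.  A more defensive variant, offered only
# as a sketch and not as this project's actual initialization:
def _sketch_init_modules_defensive(model):
    for name, param in model.named_parameters():
        if 'weight' in name and param.dim() >= 2:
            nn.init.kaiming_normal_(param.data, mode='fan_out', nonlinearity='relu')
        elif 'bias' in name:
            nn.init.constant_(param.data, 0.04)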
def create_conv_block_3d(in_planes, out_planes, kernel_size=1, stride=1, padding=0,
dilation=1, groups=1, bias=True, bn=True, activation=True):
layers = []
layers.append(nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias))
if bn:
layers.append(nn.InstanceNorm3d(out_planes))
if activation:
layers.append(nn.LeakyReLU(inplace=True))
return nn.Sequential(*layers)
def create_conv_block_k1_3d(in_planes, out_planes, kernel_size=1, stride=1, padding=0,
dilation=1, groups=1, bias=False, bn=True, activation=True):
return create_conv_block_3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, bn=bn, activation=activation)
def create_conv_block_k2_3d(in_planes, out_planes, kernel_size=2, stride=1, padding=0,
dilation=1, groups=1, bias=False, bn=True, activation=True):
return create_conv_block_3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, bn=bn, activation=activation)
def create_conv_block_k3_3d(in_planes, out_planes, kernel_size=3, stride=1, padding=1,
dilation=1, groups=1, bias=False, bn=True, activation=True):
return create_conv_block_3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, bn=bn, activation=activation)
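# Quick shape check for the conv-block helpers above: a kernel-3 / padding-1
# block keeps the spatial size, so a (N, C, D, H, W) input comes out with the
# same D, H, W and out_planes channels.  The tensor sizes here are arbitrary
# example values.
def _sketch_conv_block_shape_check():
    block = create_conv_block_k3_3d(1, 8)      # Conv3d + InstanceNorm3d + LeakyReLU
    x = torch.randn(1, 1, 8, 32, 32)           # (N, C, D, H, W)
    y = block(x)
    assert y.shape == (1, 8, 8, 32, 32)
    return y.shape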
class RfbfBlock3d(nn.Module):
def __init__(self, in_planes, out_planes, stride=1, groups=1, droprate=0):
super(RfbfBlock3d, self).__init__()
inter_planes = max(int(torch.ceil(torch.tensor(out_planes / 8))), 1)
self.groups = groups
self.group_num = inter_planes // groups
self.droprate = droprate
self.rotation_attn = nn.Sequential(
nn.AdaptiveAvgPool3d(1),
nn.Conv3d(in_planes, max(inter_planes//2, 1), 1),
nn.LeakyReLU(inplace=True),
nn.Conv3d(max(inter_planes//2, 1), inter_planes, 1),
nn.Sigmoid()
)
self.branch1 = nn.Sequential(
create_conv_block_3d(in_planes, inter_planes, kernel_size=(3,3,3), padding=(1,1,1),
stride=stride, groups=groups),
create_conv_block_3d(inter_planes, inter_planes, kernel_size=(3,3,3), padding=(1,1,1),
dilation=(1,1,1), groups=groups)
)
self.branch2 = nn.Sequential(
create_conv_block_3d(in_planes, inter_planes, kernel_size=(3,5,5), padding=(1,2,2),
stride=stride, groups=groups),
create_conv_block_3d(inter_planes, inter_planes, kernel_size=(3,5,5), padding=(2,4,4),
dilation=(2,2,2), groups=groups)
)
self.branch3 = nn.Sequential(
create_conv_block_3d(in_planes, inter_planes, kernel_size=(3,7,7), padding=(1,3,3),
stride=stride, groups=groups),
create_conv_block_3d(inter_planes, inter_planes, kernel_size=(3,7,7), padding=(3,3,3),
dilation=(3,1,1), groups=groups)
)
self.branch4 = nn.Sequential(
create_conv_block_3d(in_planes, inter_planes, kernel_size=(3,5,5), padding=(1,2,2),
stride=stride, groups=groups),
create_conv_block_3d(inter_planes, inter_planes, kernel_size=(1,3,3), padding=(0,1,1),
dilation=(1,1,1), groups=groups)
)
self.branch5 = nn.Sequential(
create_conv_block_3d(in_planes, inter_planes, kernel_size=(5,5,5), padding=(2,2,2),
stride=stride, groups=groups),
create_conv_block_3d(inter_planes, inter_planes, kernel_size=(3,3,3), padding=(2,2,2),
dilation=(2,2,2), groups=groups)
)
self.branch6 = nn.Sequential(
create_conv_block_3d(in_planes, inter_planes, kernel_size=(3,9,9), padding=(1,4,4),
stride=stride, groups=groups),
create_conv_block_3d(inter_planes, inter_planes, kernel_size=(1,5,5), padding=(0,2,2),
dilation=(1,1,1), groups=groups)
)
self.branch7 = nn.Sequential(
create_conv_block_3d(in_planes, inter_planes, kernel_size=(5,11,11), padding=(2,5,5),
stride=stride, groups=groups),
create_conv_block_3d(inter_planes, inter_planes, kernel_size=(3,5,5), padding=(1,4,4),
dilation=(1,2,2), groups=groups)
)
self.branch8 = nn.Sequential(
create_conv_block_3d(in_planes, inter_planes, kernel_size=(7,7,7), padding=(3,3,3),
stride=stride, groups=groups),
create_conv_block_3d(inter_planes, inter_planes, kernel_size=(3,3,3), padding=(2,2,2),
dilation=(2,2,2), groups=groups)
)
def forward(self, x):
attn = self.rotation_attn(x)
x = x * attn
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
x4 = self.branch4(x)
x5 = self.branch5(x)
x6 = self.branch6(x)
x7 = self.branch7(x)
x8 = self.branch8(x)
        # Debug shape log: x1, x2, x4-x8 -> [1, 1, 48, 256, 256]; x3 -> [1, 1, 48, 250, 250]
if self.groups == 1:
out = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), 1)
else:
for group in range(self.groups):
group_out = torch.cat((x1[:, group * self.group_num:(group + 1) * self.group_num],
x2[:, group * self.group_num:(group + 1) * self.group_num],
x3[:, group * self.group_num:(group + 1) * self.group_num],
x4[:, group * self.group_num:(group + 1) * self.group_num],
x5[:, group * self.group_num:(group + 1) * self.group_num],
x6[:, group * self.group_num:(group + 1) * self.group_num],
x7[:, group * self.group_num:(group + 1) * self.group_num],
x8[:, group * self.group_num:(group + 1) * self.group_num]), 1)
if group == 0:
out = group_out
else:
out = torch.cat((out, group_out), 1)
if self.droprate > 0:
out = F.dropout3d(out, p=self.droprate, training=self.training)
return out
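# Minimal usage sketch (assumes RfbfBlock3d above; not used by the training
# code): with groups=1 the eight branches are concatenated, so the output has
# 8 * inter_planes channels, where inter_planes = ceil(out_planes / 8).
def _demo_rfbf_block_3d():
    block = RfbfBlock3d(in_planes=1, out_planes=8)
    x = torch.randn(1, 1, 8, 32, 32)
    y = block(x)
    assert y.shape == (1, 8, 8, 32, 32)
    return y.shape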
class RfbeBlock3d(nn.Module):
def __init__(self, in_planes, out_planes, stride=1, groups=1):
super(RfbeBlock3d, self).__init__()
inter_planes = max(int(torch.ceil(torch.tensor(out_planes / 8))), 2)
self.groups = groups
self.group_num = 2 * inter_planes // groups
self.branch1 = nn.Sequential(
create_conv_block_3d(in_planes, 2 * inter_planes, kernel_size=3, stride=stride, padding=1, groups=groups),
create_conv_block_k3_3d(2 * inter_planes, 2 * inter_planes, padding=1, dilation=1, groups=groups)
)
self.branch2 = nn.Sequential(
create_conv_block_3d(in_planes, 2 * inter_planes, kernel_size=3, stride=stride, padding=1, groups=groups),
create_conv_block_k3_3d(2 * inter_planes, 2 * inter_planes, padding=2, dilation=2, groups=groups)
)
self.branch3 = nn.Sequential(
create_conv_block_3d(in_planes, 2 * inter_planes, kernel_size=5, stride=stride, padding=2, groups=groups),
create_conv_block_k3_3d(2 * inter_planes, 2 * inter_planes, padding=3, dilation=3, groups=groups)
)
self.branch4 = nn.Sequential(
create_conv_block_3d(in_planes, 2 * inter_planes, kernel_size=7, stride=stride, padding=3, groups=groups),
create_conv_block_k3_3d(2 * inter_planes, 2 * inter_planes, padding=4, dilation=4, groups=groups)
)
self.concat_conv = create_conv_block_k1_3d(8 * inter_planes, out_planes, groups=groups, activation=False)
self.shortcut = create_conv_block_k1_3d(in_planes, out_planes, stride=stride, groups=groups, activation=False)
def forward(self, x):
identity = self.shortcut(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
x4 = self.branch4(x)
if self.groups == 1:
out = torch.cat((x1, x2, x3, x4), 1)
else:
for group in range(self.groups):
group_out = torch.cat((x1[:, group * self.group_num:(group + 1) * self.group_num],
x2[:, group * self.group_num:(group + 1) * self.group_num],
x3[:, group * self.group_num:(group + 1) * self.group_num],
x4[:, group * self.group_num:(group + 1) * self.group_num]), 1)
if group == 0:
out = group_out
else:
out = torch.cat((out, group_out), 1)
out = self.concat_conv(out)
out = out + identity
out = F.leaky_relu(out, inplace=True)
return out
class RfbBlock3d(nn.Module):
def __init__(self, in_planes, out_planes, stride=1, groups=1):
super(RfbBlock3d, self).__init__()
inter_planes = max(int(torch.ceil(torch.tensor(out_planes / 8))), 2)
self.groups = groups
self.group_num = 2 * inter_planes // groups
self.branch1 = nn.Sequential(
create_conv_block_k1_3d(in_planes, 2 * inter_planes, stride=stride, groups=groups),
create_conv_block_k3_3d(2 * inter_planes, 2 * inter_planes, padding=1, dilation=1, groups=groups)
)
self.branch2 = nn.Sequential(
create_conv_block_k1_3d(in_planes, inter_planes, groups=groups),
create_conv_block_3d(inter_planes, 2 * inter_planes, kernel_size=3, stride=stride, padding=1, groups=groups),
create_conv_block_k3_3d(2 * inter_planes, 2 * inter_planes, padding=2, dilation=2, groups=groups)
)
self.branch3 = nn.Sequential(
create_conv_block_k1_3d(in_planes, inter_planes, groups=groups),
create_conv_block_3d(inter_planes, 2 * inter_planes, kernel_size=5, stride=stride, padding=2, groups=groups),
create_conv_block_k3_3d(2 * inter_planes, 2 * inter_planes, padding=3, dilation=3, groups=groups)
)
self.concat_conv = create_conv_block_k1_3d(6 * inter_planes, out_planes, groups=groups, activation=False)
self.shortcut = create_conv_block_k1_3d(in_planes, out_planes, stride=stride, groups=groups, activation=False)
def forward(self, x):
identity = self.shortcut(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
if self.groups == 1:
out = torch.cat((x1, x2, x3), 1)
else:
for group in range(self.groups):
group_out = torch.cat((x1[:, group * self.group_num:(group + 1) * self.group_num],
x2[:, group * self.group_num:(group + 1) * self.group_num],
x3[:, group * self.group_num:(group + 1) * self.group_num]), 1)
if group == 0:
out = group_out
else:
out = torch.cat((out, group_out), 1)
out = self.concat_conv(out)
out = out + identity
out = F.leaky_relu(out, inplace=True)
return out
class ResBasicBlock3d(nn.Module):
def __init__(self, in_planes, out_planes, stride=1, groups=1, droprate=0, se=False):
super(ResBasicBlock3d, self).__init__()
self.conv1 = create_conv_block_k3_3d(in_planes, out_planes, stride=stride, groups=groups)
self.conv2 = create_conv_block_k3_3d(out_planes, out_planes, groups=groups, activation=False)
self.shortcut = None
if stride != 1 or in_planes != out_planes:
self.shortcut = create_conv_block_k1_3d(in_planes, out_planes, stride=stride, groups=groups, activation=False)
self.droprate = droprate
self.se = se
if se:
self.fc1 = nn.Linear(in_features=out_planes, out_features=out_planes // 4)
self.fc2 = nn.Linear(in_features=out_planes // 4, out_features=out_planes)
def forward(self, x):
identity = self.shortcut(x) if self.shortcut is not None else x
out = self.conv1(x)
if self.droprate > 0:
out = F.dropout3d(out, p=self.droprate, training=self.training)
out = self.conv2(out)
if self.se:
original_out = out
out = F.adaptive_avg_pool3d(out, (1, 1, 1))
out = torch.flatten(out, 1)
out = self.fc1(out)
out = F.leaky_relu(out, inplace=True)
out = self.fc2(out)
out = out.sigmoid()
out = out.view(out.size(0), out.size(1), 1, 1, 1)
out = out * original_out
out = out + identity
out = F.leaky_relu(out, inplace=True)
return out
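# Minimal usage sketch (assumes ResBasicBlock3d above; not used by the training
# code): a strided residual block with the squeeze-and-excitation branch enabled.
def _demo_res_basic_block_3d():
    block = ResBasicBlock3d(in_planes=8, out_planes=16, stride=2, se=True)
    x = torch.randn(1, 8, 8, 16, 16)
    y = block(x)
    assert y.shape == (1, 16, 4, 8, 8)
    return y.shape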
class UpBlock3d(nn.Module):
def __init__(self, in_planes1, in_planes2, out_planes, groups=1, scale_factor=2):
super(UpBlock3d, self).__init__()
self.scale_factor = scale_factor
self.conv1 = create_conv_block_k3_3d(in_planes1, out_planes, groups=groups, activation=False)
self.conv2 = create_conv_block_k3_3d(in_planes2, out_planes, groups=groups, activation=False)
def forward(self, x1, x2):
if self.scale_factor != 1 and self.scale_factor != (1, 1, 1):
x1 = F.interpolate(x1, scale_factor=self.scale_factor, mode='nearest')
out1 = self.conv1(x1)
out2 = self.conv2(x2)
out = out1 + out2
out = F.leaky_relu(out, inplace=True)
return out
def create_conv_block_2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0,
dilation=1, groups=1, bias=True, bn=True, activation=True):
layers = []
layers.append(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias))
if bn:
layers.append(nn.InstanceNorm2d(out_planes))
if activation:
layers.append(nn.LeakyReLU(inplace=True))
return nn.Sequential(*layers)
def create_conv_block_k1_2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0,
dilation=1, groups=1, bias=False, bn=True, activation=True):
return create_conv_block_2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, bn=bn, activation=activation)
def create_conv_block_k2_2d(in_planes, out_planes, kernel_size=2, stride=1, padding=0,
dilation=1, groups=1, bias=False, bn=True, activation=True):
return create_conv_block_2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, bn=bn, activation=activation)
def create_conv_block_k3_2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1,
dilation=1, groups=1, bias=False, bn=True, activation=True):
return create_conv_block_2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, bn=bn, activation=activation)
class RfbfBlock2d(nn.Module):
def __init__(self, in_planes, out_planes, stride=1, groups=1, droprate=0):
super(RfbfBlock2d, self).__init__()
inter_planes = max(int(torch.ceil(torch.tensor(out_planes / 8))), 1)
self.groups = groups
self.group_num = inter_planes // groups
self.droprate = droprate
self.rotation_attn = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_planes, max(inter_planes//2, 1), 1),
nn.LeakyReLU(inplace=True),
nn.Conv2d(max(inter_planes//2, 1), inter_planes, 1),
nn.Sigmoid()
)
self.branch1 = nn.Sequential(
create_conv_block_2d(in_planes, inter_planes, kernel_size=(5, 5), padding=(2, 2),
stride=stride, groups=groups),
create_conv_block_2d(inter_planes, inter_planes, kernel_size=(3, 3), padding=(1, 1),
dilation=(1, 1), groups=groups)
)
self.branch2 = nn.Sequential(
create_conv_block_2d(in_planes, inter_planes, kernel_size=(5, 5), padding=(2, 2),
stride=stride, groups=groups),
create_conv_block_2d(inter_planes, inter_planes, kernel_size=(3, 3), padding=(2, 2),
dilation=(2, 2), groups=groups)
)
self.branch3 = nn.Sequential(
create_conv_block_2d(in_planes, inter_planes, kernel_size=(7, 7), padding=(3, 3),
stride=stride, groups=groups),
create_conv_block_2d(inter_planes, inter_planes, kernel_size=(3, 3), padding=(3, 3),
dilation=(3, 3), groups=groups)
)
self.branch4 = nn.Sequential(
create_conv_block_2d(in_planes, inter_planes, kernel_size=(5, 5), padding=(2, 2),
stride=stride, groups=groups),
create_conv_block_2d(inter_planes, inter_planes, kernel_size=(5, 5), padding=(4, 4),
dilation=(2, 2), groups=groups)
)
self.branch5 = nn.Sequential(
create_conv_block_2d(in_planes, inter_planes, kernel_size=(7, 7), padding=(3, 3),
stride=stride, groups=groups),
create_conv_block_2d(inter_planes, inter_planes, kernel_size=(5, 5), padding=(4, 4),
dilation=(2, 2), groups=groups)
)
self.branch6 = nn.Sequential(
create_conv_block_2d(in_planes, inter_planes, kernel_size=(9, 9), padding=(4, 4),
stride=stride, groups=groups),
create_conv_block_2d(inter_planes, inter_planes, kernel_size=(3, 3), padding=(2, 2),
dilation=(2, 2), groups=groups)
)
self.branch7 = nn.Sequential(
create_conv_block_2d(in_planes, inter_planes, kernel_size=(7, 7), padding=(3, 3),
stride=stride, groups=groups),
create_conv_block_2d(inter_planes, inter_planes, kernel_size=(7, 7), padding=(6, 6),
dilation=(2, 2), groups=groups)
)
self.branch8 = nn.Sequential(
create_conv_block_2d(in_planes, inter_planes, kernel_size=(11, 11), padding=(5, 5),
stride=stride, groups=groups),
create_conv_block_2d(inter_planes, inter_planes, kernel_size=(5, 5), padding=(4, 4),
dilation=(2, 2), groups=groups)
)
def forward(self, x):
attn = self.rotation_attn(x)
x = x * attn
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
x4 = self.branch4(x)
x5 = self.branch5(x)
x6 = self.branch6(x)
x7 = self.branch7(x)
x8 = self.branch8(x)
if self.groups == 1:
out = torch.cat((x1, x2, x3, x4, x5, x6, x7, x8), 1)
else:
for group in range(self.groups):
group_out = torch.cat((x1[:, group * self.group_num:(group + 1) * self.group_num],
x2[:, group * self.group_num:(group + 1) * self.group_num],
x3[:, group * self.group_num:(group + 1) * self.group_num],
x4[:, group * self.group_num:(group + 1) * self.group_num],
x5[:, group * self.group_num:(group + 1) * self.group_num],
x6[:, group * self.group_num:(group + 1) * self.group_num],
x7[:, group * self.group_num:(group + 1) * self.group_num],
x8[:, group * self.group_num:(group + 1) * self.group_num]), 1)
if group == 0:
out = group_out
else:
out = torch.cat((out, group_out), 1)
if self.droprate > 0:
out = F.dropout2d(out, p=self.droprate, training=self.training)
return out
class ResBasicBlock2d(nn.Module):
def __init__(self, in_planes, out_planes, stride=1, groups=1, droprate=0, se=False):
super(ResBasicBlock2d, self).__init__()
self.conv1 = create_conv_block_k3_2d(in_planes, out_planes, stride=stride, groups=groups)
self.conv2 = create_conv_block_k3_2d(out_planes, out_planes, groups=groups, activation=False)
self.shortcut = None
if stride != 1 or in_planes != out_planes:
self.shortcut = create_conv_block_k1_2d(in_planes, out_planes, stride=stride, groups=groups, activation=False)
self.droprate = droprate
self.se = se
if se:
self.fc1 = nn.Linear(in_features=out_planes, out_features=out_planes // 4)
self.fc2 = nn.Linear(in_features=out_planes // 4, out_features=out_planes)
def forward(self, x):
identity = self.shortcut(x) if self.shortcut is not None else x
out = self.conv1(x)
if self.droprate > 0:
out = F.dropout2d(out, p=self.droprate, training=self.training)
out = self.conv2(out)
if self.se:
original_out = out
out = F.adaptive_avg_pool2d(out, (1, 1))
out = torch.flatten(out, 1)
out = self.fc1(out)
out = F.leaky_relu(out, inplace=True)
out = self.fc2(out)
out = out.sigmoid()
            # TODO: this reshape to (B, C, 1, 1) still needs to be tested
out = out.view(out.size(0), out.size(1), 1, 1)
out = out * original_out
out = out + identity
out = F.leaky_relu(out, inplace=True)
return out
class FPN3d(nn.Module):
def __init__(self, n_channels, n_base_filters, groups=1):
super(FPN3d, self).__init__()
self.down_level1 = nn.Sequential(
RfbfBlock3d(n_channels, n_base_filters, groups=groups)
)
self.down_level2 = nn.Sequential(
ResBasicBlock3d(1 * n_base_filters, 2 * n_base_filters, stride=(1, 2, 2), groups=groups),
ResBasicBlock3d(2 * n_base_filters, 2 * n_base_filters, groups=groups)
)
self.down_level3 = nn.Sequential(
ResBasicBlock3d(2 * n_base_filters, 4 * n_base_filters, stride=(1, 2, 2), groups=groups),
ResBasicBlock3d(4 * n_base_filters, 4 * n_base_filters, groups=groups)
)
self.down_level4 = nn.Sequential(
ResBasicBlock3d(4 * n_base_filters, 8 * n_base_filters, stride=2, groups=groups, se=True),
ResBasicBlock3d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
ResBasicBlock3d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True)
)
self.down_level5 = nn.Sequential(
ResBasicBlock3d(8 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True),
ResBasicBlock3d(16 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
ResBasicBlock3d(8 * n_base_filters, 16 * n_base_filters, groups=groups, se=True)
)
self.down_level6 = nn.Sequential(
ResBasicBlock3d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
)
self.down_level7 = nn.Sequential(
ResBasicBlock3d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
)
self.down_level8 = nn.Sequential(
create_conv_block_k3_3d(16 * n_base_filters, 16 * n_base_filters, padding=0, bn=False)
)
def forward(self, x):
down_out1 = self.down_level1(x)
down_out2 = self.down_level2(down_out1)
down_out3 = self.down_level3(down_out2)
down_out4 = self.down_level4(down_out3)
down_out5 = self.down_level5(down_out4)
down_out6 = self.down_level6(down_out5)
down_out7 = self.down_level7(down_out6)
down_out8 = self.down_level8(down_out7)
return down_out8
class Net3d(nn.Module):
def __init__(self, n_channels=1, n_diff_classes=1, n_base_filters=8):
super(Net3d, self).__init__()
self.fpn = FPN3d(n_channels, n_base_filters)
self.channel_attn = nn.Sequential(
nn.Linear(32 * n_base_filters, 32 * n_base_filters // 4),
nn.LeakyReLU(inplace=True),
nn.Linear(32 * n_base_filters // 4, 32 * n_base_filters),
nn.Sigmoid()
)
self.diff_classifier = nn.Linear(32 * n_base_filters, n_diff_classes)
def forward(self, data):
out = self.fpn(data)
        # Switched to max pooling (used together with average pooling below)
out_max = F.adaptive_max_pool3d(out, (1, 1, 1))
out_avg = F.adaptive_avg_pool3d(out, (1, 1, 1))
out = torch.cat([torch.flatten(out_avg, 1), torch.flatten(out_max, 1)], 1)
attn = self.channel_attn(out)
out = out * attn
diff_output = self.diff_classifier(out)
        diff_output = torch.sigmoid(diff_output)
diff_output = diff_output.squeeze(1)
return diff_output
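# Minimal smoke-test sketch (assumes Net3d above; not used by the training
# code). The input shape follows the debug shape log in RfbfBlock3d.forward;
# smaller spatial sizes may not survive the six downsampling stages plus the
# final unpadded 3x3x3 convolution. Memory-heavy on CPU.
def _demo_net3d():
    net = Net3d(n_channels=1, n_diff_classes=1, n_base_filters=8)
    x = torch.randn(1, 1, 48, 256, 256)
    with torch.no_grad():
        y = net(x)
    assert y.shape == (1,)   # one sigmoid probability per sample
    return y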
def net_test_3d():
pass
class FPN2d(nn.Module):
def __init__(self, n_channels, n_base_filters, groups=1):
super(FPN2d, self).__init__()
self.down_level1 = nn.Sequential(
RfbfBlock2d(n_channels, n_base_filters, groups=groups)
)
self.down_level2 = nn.Sequential(
ResBasicBlock2d(1 * n_base_filters, 2 * n_base_filters, stride=(2, 2), groups=groups),
ResBasicBlock2d(2 * n_base_filters, 2 * n_base_filters, groups=groups)
)
self.down_level3 = nn.Sequential(
ResBasicBlock2d(2 * n_base_filters, 4 * n_base_filters, stride=(2, 2), groups=groups),
ResBasicBlock2d(4 * n_base_filters, 4 * n_base_filters, groups=groups)
)
self.down_level4 = nn.Sequential(
ResBasicBlock2d(4 * n_base_filters, 8 * n_base_filters, stride=2, groups=groups, se=True),
ResBasicBlock2d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
ResBasicBlock2d(8 * n_base_filters, 8 * n_base_filters, groups=groups, se=True)
)
self.down_level5 = nn.Sequential(
ResBasicBlock2d(8 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True),
ResBasicBlock2d(16 * n_base_filters, 8 * n_base_filters, groups=groups, se=True),
ResBasicBlock2d(8 * n_base_filters, 16 * n_base_filters, groups=groups, se=True)
)
self.down_level6 = nn.Sequential(
ResBasicBlock2d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
)
self.down_level7 = nn.Sequential(
ResBasicBlock2d(16 * n_base_filters, 16 * n_base_filters, stride=2, groups=groups, se=True)
)
self.down_level8 = nn.Sequential(
create_conv_block_k3_2d(16 * n_base_filters, 16 * n_base_filters, padding=0, bn=False)
)
def forward(self, x):
down_out1 = self.down_level1(x)
down_out2 = self.down_level2(down_out1)
down_out3 = self.down_level3(down_out2)
down_out4 = self.down_level4(down_out3)
down_out5 = self.down_level5(down_out4)
down_out6 = self.down_level6(down_out5)
down_out7 = self.down_level7(down_out6)
down_out8 = self.down_level8(down_out7)
return down_out8
class Net2d(nn.Module):
def __init__(self, n_channels=1, n_diff_classes=1, n_base_filters=8):
super(Net2d, self).__init__()
self.fpn = FPN2d(n_channels, n_base_filters)
self.channel_attn = nn.Sequential(
nn.Linear(32 * n_base_filters, 32 * n_base_filters // 4),
nn.LeakyReLU(inplace=True),
nn.Linear(32 * n_base_filters // 4, 32 * n_base_filters),
nn.Sigmoid()
)
self.diff_classifier = nn.Linear(32 * n_base_filters, n_diff_classes)
def forward(self, data):
batch_size = data.shape[0]
out_feat = self.fpn(data)
out_avg = F.adaptive_avg_pool2d(out_feat, (1, 1))
out_max = F.adaptive_max_pool2d(out_feat, (1, 1))
out_flat = torch.cat([torch.flatten(out_avg, 1), torch.flatten(out_max, 1)], 1)
out_attn = self.channel_attn(out_flat)
out = out_flat * out_attn
diff_output = self.diff_classifier(out)
        diff_output = torch.sigmoid(diff_output)
diff_output = diff_output.squeeze(1)
return diff_output
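# Minimal smoke-test sketch (assumes Net2d above; not used by the training
# code): the 2D analogue of Net3d, so a 256x256 input reaches the final
# unpadded 3x3 convolution with a 4x4 feature map.
def _demo_net2d():
    net = Net2d(n_channels=1, n_diff_classes=1, n_base_filters=8)
    x = torch.randn(2, 1, 256, 256)
    with torch.no_grad():
        y = net(x)
    assert y.shape == (2,)   # one sigmoid probability per sample
    return y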
def net_test_2d():
pass
class Net2d3d(nn.Module):
def __init__(self):
super(Net2d3d, self).__init__()
self.net2d = Net2d()
self.net3d = Net3d()
def forward(self, data_2d, data_3d):
out2d = self.net2d(data_2d)
out3d = self.net3d(data_3d)
return out2d, out3d
def net_test_2d3d():
pass
def test_initialization():
pass
if __name__ == '__main__':
test_initialization()
import torch
from torch import nn
from torch.nn import functional as F
from typing import Any, Dict, List, Tuple, Type, Optional
from functools import partial
import numpy as np
import math
from transformers import AutoImageProcessor, AutoModel
# from ._registry import register_model
# from ._pretrain import load_pretrained_weights
def window_partition3D(x: torch.Tensor,
window_size: int) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, D, H, W, C = x.shape
pad_d = (window_size - D % window_size) % window_size
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0 or pad_d > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h, 0, pad_d))
Hp, Wp, Dp = H + pad_h, W + pad_w, D + pad_d
x = x.view(B, Dp // window_size, window_size, Hp // window_size, window_size,
Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size,
window_size, C)
return windows, (Dp, Hp, Wp)
def window_unpartition3D(windows: torch.Tensor, window_size: int, pad_dhw: Tuple[int, int, int],
dhw: Tuple[int, int, int]) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Dp, Hp, Wp = pad_dhw
D, H, W = dhw
B = windows.shape[0] // (Dp * Hp * Wp // window_size // window_size // window_size)
x = windows.view(B, Dp // window_size, Hp // window_size, Wp // window_size, window_size,
window_size, window_size, -1)
x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, Dp, Hp, Wp, -1)
if Hp > H or Wp > W or Dp > D:
x = x[:, :D, :H, :W, :].contiguous()
return x
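# Minimal round-trip sketch (assumes the two window helpers above): padding is
# added so every spatial dim becomes a multiple of window_size, and
# window_unpartition3D crops it away again.
def _demo_window_partition_roundtrip():
    x = torch.randn(1, 5, 6, 7, 8)                # (B, D, H, W, C)
    windows, pad_dhw = window_partition3D(x, window_size=4)
    assert windows.shape == (8, 4, 4, 4, 8)       # padded (8, 8, 8) volume -> 8 windows
    y = window_unpartition3D(windows, 4, pad_dhw, (5, 6, 7))
    assert torch.allclose(x, y)
    return y.shape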
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
rel_pos_d: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int, int],
k_size: Tuple[int, int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_d, q_h, q_w = q_size
k_d, k_h, k_w = k_size
Rd = get_rel_pos(q_d, k_d, rel_pos_d)
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_d, q_h, q_w, dim)
rel_d = torch.einsum("bdhwc,dkc->bdhwk", r_q, Rd)
rel_h = torch.einsum("bdhwc,hkc->bdhwk", r_q, Rh)
rel_w = torch.einsum("bdhwc,wkc->bdhwk", r_q, Rw)
attn = (attn.view(B, q_d, q_h, q_w, k_d, k_h, k_w) + rel_d[:, :, :, :, None, None] +
rel_h[:, :, :, None, :, None] + rel_w[:, :, :, None, None, :]).view(
B, q_d * q_h * q_w, k_d * k_h * k_w)
return attn
class PatchEmbed3D(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
        kernel_size: Tuple[int, int, int] = (16, 16, 16),
        stride: Tuple[int, int, int] = (16, 16, 16),
        padding: Tuple[int, int, int] = (0, 0, 0),
in_chans: int = 1,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv3d(in_chans,
embed_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C X Y Z -> B X Y Z C
x = x.permute(0, 2, 3, 4, 1)
return x
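# Minimal usage sketch (assumes PatchEmbed3D above): the projection maps
# (B, C, D, H, W) voxels to a (B, D/ps, H/ps, W/ps, embed_dim) token grid.
def _demo_patch_embed_3d():
    embed = PatchEmbed3D(kernel_size=(4, 4, 4), stride=(4, 4, 4), in_chans=1, embed_dim=16)
    x = torch.randn(1, 1, 8, 8, 8)
    tokens = embed(x)
    assert tokens.shape == (1, 2, 2, 2, 16)
    return tokens.shape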
class MLPBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
mlp_dim: int,
act: Type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = act()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_layers: int,
sigmoid_output: bool = False,
) -> None:
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
self.sigmoid_output = sigmoid_output
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
if self.sigmoid_output:
x = F.sigmoid(x)
return x
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: Optional[Tuple[int, int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int, int) or None): Input resolution for calculating the relative
                positional parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert (
input_size
is not None), "Input size must be provided if using relative positional encoding."
# initialize relative positional embeddings
self.rel_pos_d = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[2] - 1, head_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, D, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, D * H * W, C)
        qkv = self.qkv(x).reshape(B, D * H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, D * H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, D * H * W, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_d, self.rel_pos_h, self.rel_pos_w,
(D, H, W), (D, H, W))
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, D, H, W, -1).permute(0, 2, 3, 4, 1,
5).reshape(B, D, H, W, -1)
x = self.proj(x)
return x
class LayerNorm3d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
return x
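# Minimal usage sketch (assumes LayerNorm3d above): normalization runs over the
# channel dimension of a (B, C, D, H, W) tensor, unlike nn.LayerNorm, which
# normalizes trailing dimensions.
def _demo_layer_norm_3d():
    ln = LayerNorm3d(num_channels=8)
    x = torch.randn(2, 8, 4, 4, 4)
    y = ln(x)
    assert y.shape == x.shape
    return y.mean(dim=1).abs().max()   # per-voxel channel mean is ~0 at init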
class Block3D(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: Optional[Tuple[int, int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
            input_size (tuple(int, int, int) or None): Input resolution for calculating the relative
                positional parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size, window_size),
)
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.window_size = window_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
D, H, W = x.shape[1], x.shape[2], x.shape[3]
x, pad_dhw = window_partition3D(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition3D(x, self.window_size, pad_dhw, (D, H, W))
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT3D(nn.Module):
def __init__(
self,
img_size: int = 256,
patch_size: int = 16,
in_chans: int = 1,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
"""
super().__init__()
self.img_size = img_size
self.patch_embed = PatchEmbed3D(
kernel_size=(patch_size, patch_size, patch_size),
stride=(patch_size, patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: Optional[nn.Parameter] = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(
torch.zeros(1, img_size // patch_size, img_size // patch_size,
img_size // patch_size, embed_dim))
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block3D(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size,
img_size // patch_size),
)
self.blocks.append(block)
self.neck = nn.Sequential(
nn.Conv3d(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
# nn.LayerNorm(out_chans),
LayerNorm3d(out_chans),
nn.Conv3d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm3d(out_chans),
# nn.LayerNorm(out_chans),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# input_size = [1,1,256,256,256]
# import IPython; IPython.embed()
x = self.patch_embed(x)
# x = [1,16,16,16,768]
# import pdb; pdb.set_trace()
if self.pos_embed is not None:
x = x + self.pos_embed
for blk in self.blocks:
x = blk(x)
# x = [1,16,16,16,768]
# x = self.neck(x.permute(0, 4, 1, 2, 3))
# output_size = [1,256,16,16,16]
return x
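# Minimal smoke-test sketch (assumes ImageEncoderViT3D above) with a
# deliberately tiny configuration; note that forward currently returns the
# token grid in (B, D/ps, H/ps, W/ps, embed_dim) layout because the neck call
# is commented out.
def _demo_image_encoder_vit3d():
    enc = ImageEncoderViT3D(img_size=32, patch_size=16, in_chans=1,
                            embed_dim=32, depth=1, num_heads=2, mlp_ratio=2.0)
    x = torch.randn(1, 1, 32, 32, 32)
    with torch.no_grad():
        tokens = enc(x)
    assert tokens.shape == (1, 2, 2, 2, 32)
    return tokens.shape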
class TwoWayTransformer3D(nn.Module):
def __init__(
self,
depth: int,
embedding_dim: int,
num_heads: int,
mlp_dim: int,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
) -> None:
"""
A transformer decoder that attends to an input image using
queries whose positional embedding is supplied.
Args:
depth (int): number of layers in the transformer
embedding_dim (int): the channel dimension for the input embeddings
num_heads (int): the number of heads for multihead attention. Must
divide embedding_dim
mlp_dim (int): the channel dimension internal to the MLP block
activation (nn.Module): the activation to use in the MLP block
"""
super().__init__()
self.depth = depth
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.mlp_dim = mlp_dim
self.layers = nn.ModuleList()
for i in range(depth):
self.layers.append(
TwoWayAttentionBlock3D(
embedding_dim=embedding_dim,
num_heads=num_heads,
mlp_dim=mlp_dim,
activation=activation,
attention_downsample_rate=attention_downsample_rate,
skip_first_layer_pe=(i == 0),
))
self.final_attn_token_to_image = DownscaleAttention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
self.norm_final_attn = nn.LayerNorm(embedding_dim)
def forward(
self,
image_embedding: torch.Tensor,
image_pe: torch.Tensor,
point_embedding: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
            image_embedding (torch.Tensor): image to attend to. Should be shape
                B x embedding_dim x d x h x w for any d, h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must
have the same shape as image_embedding.
point_embedding (torch.Tensor): the embedding to add to the query points.
Must have shape B x N_points x embedding_dim for any N_points.
Returns:
torch.Tensor: the processed point_embedding
torch.Tensor: the processed image_embedding
"""
        # BxCxDxHxW -> BxDHWxC == B x N_image_tokens x C
bs, c, x, y, z = image_embedding.shape
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
image_pe = image_pe.flatten(2).permute(0, 2, 1)
# Prepare queries
queries = point_embedding
keys = image_embedding
# Apply transformer blocks and final layernorm
for layer in self.layers:
queries, keys = layer(
queries=queries,
keys=keys,
query_pe=point_embedding,
key_pe=image_pe,
)
# Apply the final attention layer from the points to the image
q = queries + point_embedding
k = keys + image_pe
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm_final_attn(queries)
return queries, keys
class DownscaleAttention(nn.Module):
"""
An attention layer that allows for downscaling the size of the embedding
after projection to queries, keys, and values.
"""
def __init__(
self,
embedding_dim: int,
num_heads: int,
downsample_rate: int = 1,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads
assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
def _separate_heads(self, x: torch.Tensor, num_heads: int) -> torch.Tensor:
b, n, c = x.shape
x = x.reshape(b, n, num_heads, c // num_heads)
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
def _recombine_heads(self, x: torch.Tensor) -> torch.Tensor:
b, n_heads, n_tokens, c_per_head = x.shape
x = x.transpose(1, 2)
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
# Input projections
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
# Separate into heads
q = self._separate_heads(q, self.num_heads)
k = self._separate_heads(k, self.num_heads)
v = self._separate_heads(v, self.num_heads)
# Attention
_, _, _, c_per_head = q.shape
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
attn = attn / math.sqrt(c_per_head)
attn = torch.softmax(attn, dim=-1)
# Get output
out = attn @ v
out = self._recombine_heads(out)
out = self.out_proj(out)
return out
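# Minimal usage sketch (assumes DownscaleAttention above): queries, keys and
# values are projected down to embedding_dim // downsample_rate, attended over,
# and projected back to embedding_dim.
def _demo_downscale_attention():
    attn = DownscaleAttention(embedding_dim=64, num_heads=4, downsample_rate=2)
    q = torch.randn(2, 5, 64)
    out = attn(q=q, k=q, v=q)
    assert out.shape == (2, 5, 64)
    return out.shape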
class TwoWayAttentionBlock3D(nn.Module):
def __init__(
self,
embedding_dim: int,
num_heads: int,
mlp_dim: int = 2048,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
skip_first_layer_pe: bool = False,
) -> None:
"""
A transformer block with four layers: (1) self-attention of sparse
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
block on sparse inputs, and (4) cross attention of dense inputs to sparse
inputs.
Arguments:
embedding_dim (int): the channel dimension of the embeddings
num_heads (int): the number of heads in the attention layers
mlp_dim (int): the hidden dimension of the mlp block
activation (nn.Module): the activation of the mlp block
skip_first_layer_pe (bool): skip the PE on the first layer
"""
super().__init__()
self.self_attn = DownscaleAttention(embedding_dim, num_heads)
self.norm1 = nn.LayerNorm(embedding_dim)
self.cross_attn_token_to_image = DownscaleAttention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
self.norm2 = nn.LayerNorm(embedding_dim)
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
self.norm3 = nn.LayerNorm(embedding_dim)
self.norm4 = nn.LayerNorm(embedding_dim)
self.cross_attn_image_to_token = DownscaleAttention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
self.skip_first_layer_pe = skip_first_layer_pe
def forward(self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor,
key_pe: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# Self attention block
if self.skip_first_layer_pe:
queries = self.self_attn(q=queries, k=queries, v=queries)
else:
q = queries + query_pe
attn_out = self.self_attn(q=q, k=q, v=queries)
queries = queries + attn_out
queries = self.norm1(queries)
# Cross attention block, tokens attending to image embedding
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm2(queries)
# MLP block
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.norm3(queries)
# Cross attention block, image embedding attending to tokens
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
keys = keys + attn_out
keys = self.norm4(keys)
return queries, keys
class MaskDecoder3D(nn.Module):
def __init__(
self,
*,
transformer_dim: int,
# transformer: nn.Module ,
num_multimask_outputs: int = 3,
activation: Type[nn.Module] = nn.GELU,
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
) -> None:
"""
Predicts masks given an image and prompt embeddings, using a
transformer architecture.
Arguments:
            transformer_dim (int): the channel dimension of the transformer
                (a TwoWayTransformer3D is built internally from this dimension)
num_multimask_outputs (int): the number of masks to predict
when disambiguating masks
activation (nn.Module): the type of activation to use when
upscaling masks
iou_head_depth (int): the depth of the MLP used to predict
mask quality
iou_head_hidden_dim (int): the hidden dimension of the MLP
used to predict mask quality
"""
super().__init__()
self.transformer_dim = transformer_dim
# self.transformer = transformer
self.transformer = TwoWayTransformer3D(
depth=2,
embedding_dim=self.transformer_dim,
mlp_dim=2048,
num_heads=8,
)
self.num_multimask_outputs = num_multimask_outputs
self.iou_token = nn.Embedding(1, transformer_dim)
self.num_mask_tokens = num_multimask_outputs + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
self.output_upscaling = nn.Sequential(
nn.ConvTranspose3d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
LayerNorm3d(transformer_dim // 4),
activation(),
nn.ConvTranspose3d(transformer_dim // 4, transformer_dim // 8, kernel_size=2,
stride=2),
activation(),
)
self.output_hypernetworks_mlps = nn.ModuleList([
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
for i in range(self.num_mask_tokens)
])
self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens,
iou_head_depth)
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Arguments:
image_embeddings (torch.Tensor): the embeddings from the image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
multimask_output (bool): Whether to return multiple masks or a single
mask.
Returns:
torch.Tensor: batched predicted masks
torch.Tensor: batched predictions of mask quality
"""
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
)
# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
else:
mask_slice = slice(0, 1)
masks = masks[:, mask_slice, :, :]
iou_pred = iou_pred[:, mask_slice]
# Prepare output
return masks, iou_pred
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts masks. See 'forward' for more details."""
# Concatenate output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
if image_embeddings.shape[0] != tokens.shape[0]:
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
else:
src = image_embeddings
src = src + dense_prompt_embeddings
if image_pe.shape[0] != tokens.shape[0]:
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
else:
pos_src = image_pe
b, c, x, y, z = src.shape
# Run the transformer
# import IPython; IPython.embed()
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, x, y, z)
upscaled_embedding = self.output_upscaling(src)
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, x, y, z = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, x * y * z)).view(b, -1, x, y, z)
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
class PositionEmbeddingRandom3D(nn.Module):
"""
Positional encoding using random spatial frequencies.
"""
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
super().__init__()
if scale is None or scale <= 0.0:
scale = 1.0
self.register_buffer(
"positional_encoding_gaussian_matrix",
scale * torch.randn((3, num_pos_feats)),
)
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
"""Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in the [0, 1]^3 cube and have d_1 x ... x d_n x 3 shape
coords = 2 * coords - 1
coords = coords @ self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape, with C = 3 * num_pos_feats
        # (three concatenated blocks; the third reuses the sine component)
        return torch.cat([torch.sin(coords), torch.cos(coords), torch.sin(coords)], dim=-1)
def forward(self, size: Tuple[int, int, int]) -> torch.Tensor:
"""Generate positional encoding for a grid of the specified size."""
x, y, z = size
device: Any = self.positional_encoding_gaussian_matrix.device
grid = torch.ones((x, y, z), device=device, dtype=torch.float32)
y_embed = grid.cumsum(dim=0) - 0.5
x_embed = grid.cumsum(dim=1) - 0.5
z_embed = grid.cumsum(dim=2) - 0.5
y_embed = y_embed / y
x_embed = x_embed / x
z_embed = z_embed / z
pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1))
return pe.permute(3, 0, 1, 2) # C x X x Y x Z
def forward_with_coords(self, coords_input: torch.Tensor,
image_size: Tuple[int, int, int]) -> torch.Tensor:
"""Positionally encode points that are not normalized to [0,1]."""
coords = coords_input.clone()
coords[:, :, 0] = coords[:, :, 0] / image_size[0]
coords[:, :, 1] = coords[:, :, 1] / image_size[1]
coords[:, :, 2] = coords[:, :, 2] / image_size[2]
return self._pe_encoding(coords.to(torch.float)) # B x N x C
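# Minimal usage sketch (assumes PositionEmbeddingRandom3D above): the dense grid
# encoding has 3 * num_pos_feats channels, which is why PromptEncoder3D below
# constructs it with embed_dim // 3.
def _demo_position_embedding_random_3d():
    pe = PositionEmbeddingRandom3D(num_pos_feats=16)
    grid_pe = pe((4, 4, 4))
    assert grid_pe.shape == (48, 4, 4, 4)   # C x X x Y x Z
    return grid_pe.shape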
class PromptEncoder3D(nn.Module):
def __init__(
self,
embed_dim: int,
image_embedding_size: Tuple[int, int, int],
input_image_size: Tuple[int, int, int],
mask_in_chans: int,
activation: Type[nn.Module] = nn.GELU,
) -> None:
"""
Encodes prompts for input to SAM's mask decoder.
Arguments:
embed_dim (int): The prompts' embedding dimension
            image_embedding_size (tuple(int, int, int)): The spatial size of the
                image embedding, as (D, H, W).
            input_image_size (tuple(int, int, int)): The padded size of the image as input
                to the image encoder, as (D, H, W).
mask_in_chans (int): The number of hidden channels used for
encoding input masks.
activation (nn.Module): The activation to use when encoding
input masks.
"""
super().__init__()
self.embed_dim = embed_dim
self.input_image_size = input_image_size
self.image_embedding_size = image_embedding_size
self.pe_layer = PositionEmbeddingRandom3D(embed_dim // 3)
        self.num_point_embeddings: int = 2  # pos/neg point (box corners would require 2 more entries)
point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
self.point_embeddings = nn.ModuleList(point_embeddings)
self.not_a_point_embed = nn.Embedding(1, embed_dim)
self.mask_input_size = (image_embedding_size[0], image_embedding_size[1],
image_embedding_size[2])
self.mask_downscaling = nn.Sequential(
nn.Conv3d(1, mask_in_chans // 4, kernel_size=2, stride=2),
LayerNorm3d(mask_in_chans // 4),
activation(),
nn.Conv3d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
LayerNorm3d(mask_in_chans),
activation(),
nn.Conv3d(mask_in_chans, embed_dim, kernel_size=1),
)
self.no_mask_embed = nn.Embedding(1, embed_dim)
def get_dense_pe(self) -> torch.Tensor:
"""
Returns the positional encoding used to encode point prompts,
applied to a dense set of points the shape of the image encoding.
Returns:
            torch.Tensor: Positional encoding with shape
                1x(embed_dim)x(embed_d)x(embed_h)x(embed_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)  # 1 x embed_dim x X x Y x Z
def _embed_points(
self,
points: torch.Tensor,
labels: torch.Tensor,
pad: bool,
) -> torch.Tensor:
"""Embeds point prompts."""
points = points + 0.5 # Shift to center of pixel
if pad:
padding_point = torch.zeros((points.shape[0], 1, 3), device=points.device)
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1)
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight
point_embedding[labels == 1] += self.point_embeddings[1].weight
return point_embedding
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel
        # 3D boxes are (x1, y1, z1, x2, y2, z2); note the corner embeddings below
        # assume num_point_embeddings >= 4 (pos/neg point + 2 box corners).
        coords = boxes.reshape(-1, 2, 3)
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
return corner_embedding
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
"""Embeds mask inputs."""
mask_embedding = self.mask_downscaling(masks)
return mask_embedding
def _get_batch_size(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> int:
"""
Gets the batch size of the output given the batch size of the input prompts.
"""
if points is not None:
return points[0].shape[0]
elif boxes is not None:
return boxes.shape[0]
elif masks is not None:
return masks.shape[0]
else:
return 1
def _get_device(self) -> torch.device:
return self.point_embeddings[0].weight.device
def forward(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Embeds different types of prompts, returning both sparse and dense
embeddings.
Arguments:
points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
and labels to embed.
boxes (torch.Tensor or none): boxes to embed
masks (torch.Tensor or none): masks to embed
Returns:
torch.Tensor: sparse embeddings for the points and boxes, with shape
BxNx(embed_dim), where N is determined by the number of input points
and boxes.
torch.Tensor: dense embeddings for the masks, in the shape
Bx(embed_dim)x(embed_H)x(embed_W)
"""
bs = self._get_batch_size(points, boxes, masks)
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
if points is not None:
coords, labels = points
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
if boxes is not None:
box_embeddings = self._embed_boxes(boxes)
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
if masks is not None:
dense_embeddings = self._embed_masks(masks)
else:
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1],
self.image_embedding_size[2])
return sparse_embeddings, dense_embeddings
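# Minimal usage sketch (assumes PromptEncoder3D above): point prompts only, with
# embed_dim divisible by 3 so the random positional encoding matches embed_dim.
def _demo_prompt_encoder_3d():
    enc = PromptEncoder3D(embed_dim=48, image_embedding_size=(8, 8, 8),
                          input_image_size=(128, 128, 128), mask_in_chans=16)
    coords = torch.rand(1, 2, 3) * 128        # two 3D point prompts
    labels = torch.tensor([[1.0, 0.0]])       # foreground / background labels
    sparse, dense = enc(points=(coords, labels), boxes=None, masks=None)
    assert sparse.shape == (1, 3, 48)         # 2 points + 1 padding point
    assert dense.shape == (1, 48, 8, 8, 8)    # no-mask embedding broadcast
    return sparse.shape, dense.shape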
class E3D(nn.Module):
mask_threshold: float = 0.0
image_format: str = "L"
def __init__(
self,
image_encoder: ImageEncoderViT3D,
prompt_encoder: PromptEncoder3D,
mask_decoder: MaskDecoder3D,
pixel_mean: List[float] = [123.675],
pixel_std: List[float] = [58.395],
) -> None:
"""
E3D predicts object masks from an image and input prompts.
Arguments:
image_encoder (ImageEncoderViT): The backbone used to encode the
image into image embeddings that allow for efficient mask prediction.
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
mask_decoder (MaskDecoder): Predicts masks from the image embeddings
and encoded prompts.
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
pixel_std (list(float)): Std values for normalizing pixels in the input image.
"""
super().__init__()
self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder
self.mask_decoder = mask_decoder
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
@property
def device(self) -> Any:
return self.pixel_mean.device
@torch.no_grad()
def forward(
self,
batched_input: List[Dict[str, Any]],
) -> List[Dict[str, torch.Tensor]]:
"""
Predicts masks end-to-end from provided images and prompts.
If prompts are not known in advance, using SamPredictor is
recommended over calling the model directly.
Arguments:
batched_input (list(dict)): A list over input images, each a
dictionary with the following keys. A prompt key can be
excluded if it is not present.
                'image': The image as a torch tensor in 1xDxHxW format,
                  already transformed for input to the model.
                'original_size': (tuple(int, int, int)) The original size of
                  the image before transformation, as (D, H, W).
                'point_coords': (torch.Tensor) Batched point prompts for
                  this image, with shape BxNx3. Already transformed to the
                  input frame of the model.
                'point_labels': (torch.Tensor) Batched labels for point prompts,
                  with shape BxN.
                'boxes': (torch.Tensor) Batched box inputs, with shape Bx6.
                  Already transformed to the input frame of the model.
                'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
                  in the form Bx1xDxHxW.
Returns:
(list(dict)): A list over input images, where each element is
as dictionary with the following keys.
                'masks': (torch.Tensor) Batched binary mask predictions,
                  with shape BxCxDxHxW, where B is the number of input prompts,
                  C is determined by multimask_output, and (D, H, W) is the
                  original size of the image.
                'iou_predictions': (torch.Tensor) The model's predictions
                  of mask quality, in shape BxC.
                'low_res_logits': (torch.Tensor) Low resolution logits with
                  shape BxCxDxHxW. Can be passed as mask input
                  to subsequent iterations of prediction.
"""
input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
image_embeddings = self.image_encoder(input_images)
outputs = []
for image_record, curr_embedding in zip(batched_input, image_embeddings):
if "point_coords" in image_record:
points = (image_record["point_coords"], image_record["point_labels"])
else:
points = None
sparse_embeddings, dense_embeddings = self.prompt_encoder(
points=points,
boxes=image_record.get("boxes", None),
masks=image_record.get("mask_inputs", None),
)
low_res_masks, iou_predictions = self.mask_decoder(
image_embeddings=curr_embedding.unsqueeze(0),
image_pe=self.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=False,
)
masks = self.postprocess_masks(
low_res_masks,
input_size=image_record["image"].shape[-3:],
original_size=image_record["original_size"],
)
masks = masks > self.mask_threshold
outputs.append({
"masks": masks,
"iou_predictions": iou_predictions,
"low_res_logits": low_res_masks,
})
return outputs
def postprocess_masks(
self,
masks: torch.Tensor,
input_size: Tuple[int, ...],
original_size: Tuple[int, ...],
) -> torch.Tensor:
"""
Remove padding and upscale masks to the original image size.
Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxDxHxW format.
          input_size (tuple(int, int, int)): The size of the volume input to the
            model, in (D, H, W) format. Used to remove padding.
          original_size (tuple(int, int, int)): The original size of the volume
            before resizing for input to the model, in (D, H, W) format.
        Returns:
          (torch.Tensor): Batched masks in BxCxDxHxW format, where (D, H, W)
            is given by original_size.
"""
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size,
             self.image_encoder.img_size),
            mode="trilinear",
            align_corners=False,
        )
        masks = masks[..., :input_size[0], :input_size[1], :input_size[2]]
        masks = F.interpolate(masks, original_size, mode="trilinear", align_corners=False)
return masks
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
"""Normalize pixel values and pad to a square input."""
# Normalize colors
x = (x - self.pixel_mean) / self.pixel_std
# Pad
d, h, w = x.shape[-3:]
padd = self.image_encoder.img_size - d
padh = self.image_encoder.img_size - h
padw = self.image_encoder.img_size - w
x = F.pad(x, (0, padw, 0, padh, 0, padd))
return x
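# Illustrative usage sketch (not part of the original file): end-to-end mask
# prediction with E3D. It assumes `image_encoder`, `prompt_encoder` and
# `mask_decoder` instances were built elsewhere and that image_encoder.img_size
# is 128; the volume size, original_size and point prompt below are made-up
# example values, not settings taken from this repository.
def _example_e3d_usage(image_encoder, prompt_encoder, mask_decoder):
    model = E3D(image_encoder, prompt_encoder, mask_decoder).eval()
    batched_input = [{
        "image": torch.randn(1, 128, 128, 128),                  # CxDxHxW, already resized for the model
        "original_size": (96, 512, 512),                         # (D, H, W) before preprocessing
        "point_coords": torch.tensor([[[64.0, 128.0, 128.0]]]),  # BxNx3, in the model's input frame
        "point_labels": torch.tensor([[1]]),                     # BxN
    }]
    with torch.no_grad():
        outputs = model(batched_input)
    # Each output dict holds thresholded masks, IoU predictions and low-res logits.
    return outputs[0]["masks"].shape, outputs[0]["iou_predictions"].shape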
class S3dClassifier(nn.Module):
    """Binary classifier: a 3D ViT image encoder followed by an MLP head with a sigmoid output."""
    def __init__(self):
super(S3dClassifier, self).__init__()
self.encoder_3d = ImageEncoderViT3D(
img_size=128,
patch_size=16,
in_chans=1,
out_chans = 384,
num_heads = 12,
embed_dim=768,
depth=12,
qkv_bias=True,
use_rel_pos=True,
window_size=14,
mlp_ratio=4.0,
global_attn_indexes = [2, 5, 8, 11],
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
)
self.relu = nn.GELU()
self.linear_1 = nn.Linear(768*8*8*8, 768)
self.linear_2 = nn.Linear(768, 768*10)
self.linear_3 = nn.Linear(768*10, 300)
self.classifier = nn.Linear(300, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.encoder_3d(x)
x = x.reshape(x.size(0), -1)
x = self.linear_1(x)
x = self.relu(x)
x = self.linear_2(x)
x = self.relu(x)
x = self.linear_3(x)
x = self.relu(x)
x = self.classifier(x)
x = self.sigmoid(x)
x = x.squeeze(1)
return x
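# Illustrative usage sketch (not part of the original file): a forward pass of
# S3dClassifier on a random volume. The (B, 1, 128, 128, 128) input shape follows
# img_size=128 / in_chans=1 above, and the 768*8*8*8 flatten size expected by
# linear_1 implies the encoder returns a (B, 768, 8, 8, 8) feature map; neither
# assumption is verified against ImageEncoderViT3D here (note out_chans=384 above).
def _example_s3d_classifier_usage():
    model = S3dClassifier().eval()
    dummy = torch.randn(2, 1, 128, 128, 128)   # batch of two single-channel CT crops
    with torch.no_grad():
        prob = model(dummy)                    # sigmoid outputs, shape (2,)
    return prob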
class D2dClassifier(nn.Module):
    """Binary classifier: a pretrained 2D backbone loaded via AutoModel, followed by an MLP head with a sigmoid output."""
    def __init__(self, pretrain_dir=None):
super(D2dClassifier, self).__init__()
self.dino = AutoModel.from_pretrained(pretrain_dir)
self.relu = nn.GELU()
self.linear_1 = nn.Linear(401*768, 768)
self.linear_2 = nn.Linear(768, 768*10)
self.linear_3 = nn.Linear(768*10, 300)
self.classifier = nn.Linear(300, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.dino(x)[0]
x = x.reshape(x.size(0), -1)
x = self.linear_1(x)
x = self.relu(x)
x = self.linear_2(x)
x = self.relu(x)
x = self.linear_3(x)
x = self.relu(x)
x = self.classifier(x)
x = self.sigmoid(x)
x = x.squeeze(1)
return x
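# Illustrative usage sketch (not part of the original file): building D2dClassifier
# from a local checkpoint directory and running a forward pass. The path is a
# placeholder, and the 401*768 flatten size in linear_1 requires the backbone to
# emit 401 tokens of width 768 at the chosen resolution, so the (224, 224) input
# below is only an assumption and may need to match the actual pretrained model.
def _example_d2d_classifier_usage(pretrain_dir="/path/to/local/dino_checkpoint"):
    model = D2dClassifier(pretrain_dir=pretrain_dir).eval()
    dummy = torch.randn(2, 3, 224, 224)        # two RGB slices
    with torch.no_grad():
        prob = model(dummy)                    # sigmoid outputs, shape (2,)
    return prob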
import sys, os
import pathlib
current_dir = pathlib.Path(__file__).parent.resolve()
while "cls_train" != current_dir.name:
current_dir = current_dir.parent
sys.path.append(current_dir.as_posix())
import os
import pickle
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.neural_network import MLPClassifier
import xgboost as xgb
import torch
import torch.nn as nn
from sklearn.metrics import classification_report
import pandas as pd
import numpy as np
from pathlib import Path
from cls_utils.log_utils import get_logger
from sklearn.model_selection import train_test_split
logger = get_logger(log_file="/df_lung/ai-project/cls_train/log/data/train_test_data_3d_feat.log")
class LR:
def __init__(self, in_feature=None, n_class=None, cls_info=None, save_path=None, logger=None):
self.in_feature = in_feature
self.n_class = n_class
self.cls = LogisticRegression(max_iter=1000)
self.cls_info = cls_info
self.save_path = save_path
self.logger = logger
def train(self, train_np_data, train_np_class):
self.cls.fit(train_np_data, train_np_class)
file = self.save(100004)
return file
def predict(self, np_data):
result = self.cls.predict_proba(np_data)
return result
def save(self, epoch):
current_epoch_file = f"{self.cls_info}_epoch_{epoch}.pkl"
current_epoch_file = os.path.join(self.save_path, current_epoch_file)
with open(current_epoch_file, 'wb') as f:
pickle.dump(self.cls, f)
if self.logger:
self.logger.info(f"LR saved to {current_epoch_file}")
return current_epoch_file
class DT:
def __init__(self, in_feature=None, n_class=None, cls_info=None, save_path=None, logger=None):
self.in_feature = in_feature
self.n_class = n_class
self.cls = DecisionTreeClassifier()
self.cls_info = cls_info
self.save_path = save_path
self.logger = logger
def train(self, train_np_data, train_np_class):
self.cls.fit(train_np_data, train_np_class)
file = self.save(100004)
return file
def predict(self, np_data):
return self.cls.predict_proba(np_data)
def save(self, epoch):
current_epoch_file = f"{self.cls_info}_epoch_{epoch}.pkl"
current_epoch_file = os.path.join(self.save_path, current_epoch_file)
with open(current_epoch_file, 'wb') as f:
pickle.dump(self.cls, f)
if self.logger:
self.logger.info(f"DT saved to {current_epoch_file}")
return current_epoch_file
class NB:
def __init__(self, in_feature=None, n_class=None, cls_info=None, save_path=None, logger=None):
self.in_feature = in_feature
self.n_class = n_class
self.cls = GaussianNB()
self.cls_info = cls_info
self.save_path = save_path
self.logger = logger
def train(self, train_np_data, train_np_class):
self.cls.fit(train_np_data, train_np_class)
file = self.save(100004)
return file
def predict(self, np_data):
return self.cls.predict_proba(np_data)
def save(self, epoch):
current_epoch_file = f"{self.cls_info}_epoch_{epoch}.pkl"
current_epoch_file = os.path.join(self.save_path, current_epoch_file)
with open(current_epoch_file, 'wb') as f:
pickle.dump(self.cls, f)
if self.logger:
self.logger.info(f"NB saved to {current_epoch_file}")
return current_epoch_file
class RF:
def __init__(self, in_feature=None, n_class=None, cls_info=None, save_path=None, logger=None, n_estimators=100):
self.in_feature = in_feature
self.n_class = n_class
self.cls = RandomForestClassifier(n_estimators=n_estimators, random_state=42)
self.cls_info = cls_info
self.save_path = save_path
self.logger = logger
def train(self, train_np_data, train_np_class):
self.cls.fit(train_np_data, train_np_class)
file = self.save(100004)
return file
def predict(self, np_data):
return self.cls.predict_proba(np_data)
def save(self, epoch):
current_epoch_file = f"{self.cls_info}_epoch_{epoch}.pkl"
current_epoch_file = os.path.join(self.save_path, current_epoch_file)
with open(current_epoch_file, 'wb') as f:
pickle.dump(self.cls, f)
if self.logger:
self.logger.info(f"Random Forest saved to {current_epoch_file}")
return current_epoch_file
class GBDT:
def __init__(self, in_feature=None, n_class=None, cls_info=None, save_path=None, logger=None, n_estimators=100):
self.in_feature = in_feature
self.n_class = n_class
self.cls = GradientBoostingClassifier(n_estimators=n_estimators, random_state=42)
self.cls_info = cls_info
self.save_path = save_path
self.logger = logger
def train(self, train_np_data, train_np_class):
self.cls.fit(train_np_data, train_np_class)
file = self.save(100004)
return file
def predict(self, np_data):
return self.cls.predict_proba(np_data)
def save(self, epoch):
current_epoch_file = f"{self.cls_info}_epoch_{epoch}.pkl"
current_epoch_file = os.path.join(self.save_path, current_epoch_file)
with open(current_epoch_file, 'wb') as f:
pickle.dump(self.cls, f)
if self.logger:
self.logger.info(f"GBDT saved to {current_epoch_file}")
return current_epoch_file
class XGB:
def __init__(self, in_feature=None, n_class=None, cls_info=None, save_path=None, logger=None, n_estimators=100):
self.in_feature = in_feature
self.n_class = n_class
self.cls = xgb.XGBClassifier(n_estimators=n_estimators, use_label_encoder=False, eval_metric='logloss', random_state=42)
self.cls_info = cls_info
self.save_path = save_path
self.logger = logger
def train(self, train_np_data, train_np_class):
self.cls.fit(train_np_data, train_np_class)
file = self.save(100004)
return file
def predict(self, np_data):
return self.cls.predict_proba(np_data)
def save(self, epoch):
current_epoch_file = f"{self.cls_info}_epoch_{epoch}.pkl"
current_epoch_file = os.path.join(self.save_path, current_epoch_file)
self.cls.save_model(current_epoch_file)
if self.logger:
self.logger.info(f"XGBoost saved to {current_epoch_file}")
return current_epoch_file
def load_predict(self, pkl_file):
if pkl_file and os.path.exists(pkl_file):
self.cls.load_model(pkl_file)
else:
raise Exception(f'cannot load file: {pkl_file}')
class Linear:
def __init__(self, in_feature=None, n_class=None, cls_info=None, save_path=None, logger=None):
self.in_feature = in_feature
self.n_class = n_class
self.cls = MLPClassifier(
hidden_layer_sizes=(10 * in_feature, in_feature, 40),
activation='relu',
solver='adam',
max_iter=200,
random_state=42
)
self.cls_info = cls_info
self.save_path = save_path
self.logger = logger
def train(self, train_np_data, train_np_class):
self.cls.fit(train_np_data, train_np_class)
file = self.save(100004)
return file
def predict(self, np_data):
return self.cls.predict_proba(np_data)
def save(self, epoch):
current_epoch_file = f"{self.cls_info}_epoch_{epoch}.pkl"
current_epoch_file = os.path.join(self.save_path, current_epoch_file)
with open(current_epoch_file, 'wb') as f:
pickle.dump(self.cls, f)
if self.logger:
self.logger.info(f"MLP saved to {current_epoch_file}")
return current_epoch_file
def load_predict(self, pkl_file):
if pkl_file and os.path.exists(pkl_file):
with open(pkl_file, 'rb') as f:
self.cls = pickle.load(f)
else:
raise Exception(f'cannot load file: {pkl_file}')
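# Illustrative usage sketch (not part of the original file): the wrapper classes
# above (LR, DT, NB, RF, GBDT, XGB, Linear) share the same train/predict/save
# interface. The feature sizes, labels and save path below are made-up example
# values, not data from this project.
def _example_wrapper_usage():
    rng = np.random.default_rng(0)
    feats = rng.normal(size=(100, 8))
    labels = rng.integers(0, 2, size=100)
    cls = LR(in_feature=8, n_class=2, cls_info="demo_lr", save_path="/tmp", logger=None)
    pkl_file = cls.train(feats, labels)   # fits the sklearn estimator and pickles it
    probs = cls.predict(feats[:5])        # class probabilities, shape (5, 2)
    return pkl_file, probs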
class FeatCls:
    """Select one of the feature classifiers above by name and expose a uniform train/test/predict interface over feature CSV files."""
    def __init__(self, classification_name=None, n_estimators=100, in_feature_col_list=None, class_col=None, n_class=None, pkl_file=None, cls_task_info=None, usg=None, save_dir=None, log_dir=None):
self.classification_name = classification_name
self.n_estimators = n_estimators
self.in_feature_col_list = in_feature_col_list
self.in_feature = len(self.in_feature_col_list)
self.class_col = class_col
self.n_class = n_class
self.pkl_file = pkl_file
self.cls_task_info = cls_task_info
self.usg = usg
self.save_path = os.path.join(save_dir, self.cls_task_info)
self.log_file = os.path.join(log_dir, f"{self.usg}_{self.cls_task_info}.log")
Path(self.save_path).mkdir(parents=True, exist_ok=True)
self.logger = get_logger(self.log_file)
self.model = self.get_model()
self.predict_model = self.load_predict()
def get_model(self):
self.support_classification = {
"lr": LR(
in_feature=self.in_feature,
n_class=self.n_class,
cls_info=self.cls_task_info,
save_path=self.save_path,
logger=self.logger
),
"dt": DT(
in_feature=self.in_feature,
n_class=self.n_class,
cls_info=self.cls_task_info,
save_path=self.save_path,
logger=self.logger
),
"nb": NB(
in_feature=self.in_feature,
n_class=self.n_class,
cls_info=self.cls_task_info,
save_path=self.save_path,
logger=self.logger
),
"rf": RF(
in_feature=self.in_feature,
n_class=self.n_class,
cls_info=self.cls_task_info,
save_path=self.save_path,
logger=self.logger,
n_estimators=self.n_estimators
),
"gbdt": GBDT(
in_feature=self.in_feature,
n_class=self.n_class,
cls_info=self.cls_task_info,
save_path=self.save_path,
logger=self.logger,
n_estimators=self.n_estimators
),
"xgb": XGB(
in_feature=self.in_feature,
n_class=self.n_class,
cls_info=self.cls_task_info,
save_path=self.save_path,
logger=self.logger,
n_estimators=self.n_estimators
),
"linear": Linear(
in_feature=self.in_feature,
n_class=self.n_class,
cls_info=f"{self.cls_task_info}_mlp",
save_path=self.save_path,
logger=self.logger
),
}
if self.classification_name not in self.support_classification:
raise Exception(f"Classification name not supported: {self.classification_name}")
return self.support_classification[self.classification_name]
def load_predict(self):
if self.classification_name != 'xgb' and self.pkl_file and os.path.exists(self.pkl_file) and self.pkl_file.endswith('pkl'):
with open(self.pkl_file, 'rb') as f:
model = pickle.load(f)
return model
elif self.pkl_file and os.path.exists(self.pkl_file) and '.pt' in self.pkl_file:
return self.model.load_state_dict(torch.load(self.pkl_file))
elif self.classification_name == 'xgb' and self.pkl_file:
self.model.load_predict(self.pkl_file)
return self.model
elif self.pkl_file:
raise Exception(f'cannot load file: {self.pkl_file}')
def load_data(self, csv_file):
if isinstance(csv_file, str) and os.path.isfile(csv_file):
return pd.read_csv(csv_file, header=0)
return csv_file
def train(self, train_data):
self.train_df = self.load_data(train_data)
train_input_feat_df = self.train_df[self.in_feature_col_list]
train_input_class_df = self.train_df[self.class_col]
train_input_np_feat = train_input_feat_df.to_numpy()
train_input_np_class = train_input_class_df.to_numpy()
self.pkl_file = self.model.train(train_input_np_feat, train_input_np_class)
self.logger.info(f"save to {self.pkl_file}")
return self.pkl_file
def test(self, test_data):
test_df = self.load_data(test_data)
test_input_feat_df = test_df[self.in_feature_col_list]
test_input_class_df = test_df[self.class_col]
test_input_np_feat = test_input_feat_df.to_numpy()
test_input_np_class = test_input_class_df.to_numpy()
test_predict_prob = self.predict_model.predict(test_input_np_feat)
if len(test_predict_prob.shape) > 1 and test_predict_prob.shape[1] > 1:
test_predict_prob = np.argmax(test_predict_prob, axis=1)
        classification_report_result = classification_report(
            test_input_np_class, test_predict_prob,
            zero_division=0,
            output_dict=False
        )
        if self.logger:
            self.logger.info(f"{self.usg}_{self.cls_task_info}, classification_report:\n{classification_report_result}")
        return classification_report_result
def predict(self, np_feat):
result = self.predict_model.predict(np_feat)
return result.tolist()
n_features = 136
node_list = [2021, 2031, 2041, 1010, 1020, 2011, 2046, 2047, 2048, 2060, 2061, 2062, 3001, 4001, 5001, 6001, 1016]
save_dir = "/df_lung/ai-project/cls_train/cls_ckpt"
train_csv_dir = "/df_lung/cls_train_data/train_csv_data"
log_dir = "/df_lung/ai-project/cls_train/log/train"
for idx_pos_node in node_list:
for idx_neg_node in node_list:
if idx_pos_node == idx_neg_node:
continue
idx_train_file = f"{idx_neg_node}_{idx_pos_node}_data_3d_feature_train.csv"
idx_val_file = f"{idx_neg_node}_{idx_pos_node}_data_3d_feature_val.csv"
idx_test_file = f"{idx_neg_node}_{idx_pos_node}_data_3d_feature_test.csv"
idx_train_file = os.path.join(train_csv_dir, idx_train_file)
idx_val_file = os.path.join(train_csv_dir, idx_val_file)
idx_test_file = os.path.join(train_csv_dir, idx_test_file)
if os.path.exists(idx_train_file) and os.path.exists(idx_test_file):
idx_train_df = pd.read_csv(idx_train_file)
idx_test_df = pd.read_csv(idx_test_file)
else:
logger.info(f"{idx_pos_node}_{idx_neg_node} train_data_3d not exists")
continue
logger.info(f"{idx_neg_node}_{idx_pos_node}, train: {idx_train_df['class'].value_counts()}, test: {idx_test_df['class'].value_counts()}")
idx_feature_list = idx_train_df.columns.tolist()
idx_feature_list = [idx for idx in idx_feature_list if idx.endswith('zscore')]
for idx_classification_name in ['lr', 'dt', 'nb', 'rf', 'gbdt', 'xgb', 'linear']:
idx_train_cls = FeatCls(
classification_name=idx_classification_name,
in_feature_col_list=idx_feature_list,
class_col='class',
n_class=2,
pkl_file=None,
cls_task_info=f"{idx_classification_name}_cls_20241215",
usg="train",
save_dir=save_dir,
log_dir=log_dir
)
print(f"{idx_neg_node}_{idx_pos_node}_{idx_classification_name} start train")
idx_pkl_file = idx_train_cls.train(idx_train_df)
logger.info(f"{idx_neg_node}_{idx_pos_node}_{idx_classification_name}, idx_pkl_file: {idx_pkl_file}")
idx_test_cls = FeatCls(
classification_name=idx_classification_name,
in_feature_col_list=idx_feature_list,
class_col='class',
n_class=2,
pkl_file=idx_pkl_file,
cls_task_info=f"{idx_classification_name}_cls_20241215",
usg="test",
save_dir=save_dir,
log_dir=log_dir
)
print(f"{idx_neg_node}_{idx_pos_node}_{idx_classification_name} start test ")
idx_classification_report = idx_test_cls.test(idx_test_df)
print(f"classification_report:\n{idx_classification_report}")
with open("test_result.txt", 'a') as f:
f.write(f"{idx_neg_node}_{idx_pos_node}_{idx_classification_name}, classification_report:\n{idx_classification_report}\n\n\n\n")
logger.info(f"{idx_neg_node}_{idx_pos_node}_{idx_classification_name}, classification_report:\n{idx_classification_report}")
import os
import sys
import argparse
import logging
import numpy as np
import torch
import time
import re
import copy
import cv2
import random
from matplotlib import pyplot as plt
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
from data.db import read_series_dicom, process_single_ct
from pytorch_train.torch_model import TorchModel
from pytorch_train.torch_model_2d import TorchModel_2d
#from pytorch_train.train import cfg
from cls_utils.data_utils import crop_ct_data, get_crop_data_padding, get_crop_data_2d
from cls_utils.utils import hu_value_to_uint8, normalize, base64_to_list, hu_normalize
from data.db import select_signal_series, select_series_by_node_time, get_all_contours_by_labelId, extract_error_label
from cls_utils.data import load_json, train_all_label_id, get_csv_all_label_ids_bylabel
from cls_utils.augement import generate_general_permute_keys, generate_general_indexs, permute_data, augment_data
# cfg = load_json("/home/lung/ai-project/cls_train/config/predict.json")
cfg = load_json("/df_lung/ai-project/cls_train/config/predict.json")
parser = argparse.ArgumentParser(description='predict data')
parser.add_argument('--GPU', default='0', type=str, help='GPU index')
# Compute the accuracy of the specified data from a validation log file
def compute_accuracy(log_path, threshold=0.5, positive=False, train=True):
with open(log_path, 'r') as f:
log_contents = f.readlines()
error_num = 0
sum_num = 0
    # Scan every line and pick out the final averaged result of each nodule
for line in log_contents:
match = re.search(r"result: \[(.*?)\]\n", line)
if match:
sum_num += 1
result = match.group(1)
result = float(result)
#print(result)
if positive:
if result < threshold:
error_num += 1
else:
if result > threshold:
error_num += 1
label = '正类' if positive else '负类'
train_or_test = '训练' if train else '测试'
logging.info("{}{}, 阈值:{}, 总个数: {:3}, 出错个数: {:3}, summary result 准确率: {:5}%".format(train_or_test, label, round(threshold, 2), sum_num, error_num, round(1-error_num/ sum_num, 4)*100))
#print("总个数: {}, 出错个数: {}, summary result 准确率:{}%".format(sum_num, error_num, (1-round(error_num/ sum_num, 4))*100))
def load_all_pretrain_ckpt(base):
"""
加载指定文件夹下的所有的pretrain_ckpt文件
例如加载文件./cls_train/best_cls_0704/cls_1010_2046/cls_1010_2046_1/cls_1010_204620230708-1710.ckpt,
只需将base=‘./cls_train/best_cls_0704/cls_1010_2046’即可
"""
all_cakpt_path = {}
result = []
for root, dirs, names in os.walk(base):
for name in names:
index = root.split('/')[-1].split('_')[-1]
path = os.path.join(root, name)
all_cakpt_path[int(index)] = path
for key in sorted(all_cakpt_path):
result.append(all_cakpt_path[key])
return result
# Treat every slice where the nodule appears as the central slice, fill the rest with invalid values, and predict
# seg=True keeps only the segmented nodule in each slice; other regions are filled with invalid values
# is_2d=True prepares the data in the format required by the 2D model
def predict_all_series(model, folder_name, select_box=None, label_id=None, seg=False, is_2d=False, threshold=0):
dicom_folder = os.path.join(cfg['dicom_folder'], folder_name)
ct_data = read_series_dicom(dicom_folder=dicom_folder)
patient_id = folder_name.split('-')[0]
z_min = int(select_box[0, 0])
z_max = int(select_box[0, 1])
contours = get_all_contours_by_labelId(label_id) if seg else None
if contours is not None:
data = ct_data.get_raw_image()
img_np = np.zeros((data.shape[0], data.shape[1], data.shape[2]))
for i in range(z_max-z_min+1):
_, _, img = base64_to_list(contours[i])
img_np[z_min+i] = img
data[img_np == 0] = -1000
        ct_data.set_raw_image(data)
origin_data_2d = data
#sum_result = torch.zeros(1,1)
eliminate_num = int((threshold * (z_max - z_min + 1)) / 2)
    # Initialize the input tensor
in_data = torch.zeros((1, 8, 256, 256))
for z_index in range(z_min + eliminate_num, z_max - eliminate_num + 1):
temp_select_box = copy.deepcopy(select_box)
temp_select_box[0, 0], temp_select_box[0, 1] = z_index, z_index
if is_2d:
data_2d = origin_data_2d[z_index]
original_data = get_crop_data_2d(data=data_2d, select_box=temp_select_box,
crop_size=cfg['train_crop_size_2d'])
else:
original_data = get_crop_data_padding(ct_data=ct_data, select_box=temp_select_box,
crop_size=cfg['train_crop_size'])
data = hu_normalize(original_data)
#data = normalize(hu_value_to_uint8(original_data))
data = np.tile(data, (1, 8, 1, 1))
data = torch.from_numpy(data).type(torch.float32)
in_data = torch.cat((in_data, data), dim=0)
    # Concatenate the slices, feed them to the model in one batch, then average the results
result = model.predict(in_data[1:])
logging.info('time: {}, patiend_id: {}, result: {}'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), patient_id, result))
result = np.mean(result)
#mean_result = sum_result/(z_max - z_min + 1 - 2 * eliminate_num)
return patient_id, result
def predict_all_series_orignal(model, folder_name, select_box=None, label_id=None, seg=False, is_2d=False, threshold=0):
dicom_folder = os.path.join(cfg['dicom_folder'], folder_name)
ct_data = read_series_dicom(dicom_folder=dicom_folder)
patient_id = folder_name.split('-')[0]
z_min = int(select_box[0, 0])
z_max = int(select_box[0, 1])
contours = get_all_contours_by_labelId(label_id) if seg else None
if contours is not None:
data = ct_data.get_raw_image()
img_np = np.zeros((data.shape[0], data.shape[1], data.shape[2]))
for i in range(z_max-z_min+1):
_, _, img = base64_to_list(contours[i])
img_np[z_min+i] = img
data[img_np == 0] = -1000
        ct_data.set_raw_image(data)
origin_data_2d = data
sum_result = torch.zeros(1, 1)
#sum_result = torch.zeros(1,1)
eliminate_num = int((threshold * (z_max - z_min + 1)) / 2)
for z_index in range(z_min + eliminate_num, z_max - eliminate_num + 1):
temp_select_box = copy.deepcopy(select_box)
temp_select_box[0, 0], temp_select_box[0, 1] = z_index, z_index
if is_2d:
data_2d = origin_data_2d[z_index]
original_data = get_crop_data_2d(data=data_2d, select_box=temp_select_box,
crop_size=cfg['train_crop_size_2d'])
#plt.imsave(f'/home/lung/project/ai-project/cls_train/log/image/03/{z_index}.png', original_data, cmap='gray')
#np.save(f'/home/lung/project/ai-project/cls_train/log/npy/03/{z_index}.npy', original_data)
#cv2.imwrite(f'/home/lung/project/ai-project/cls_train/log/image/01/{z_index}.png', original_data)
else:
original_data = get_crop_data_padding(ct_data=ct_data, select_box=temp_select_box,
crop_size=cfg['train_crop_size'])
data = hu_normalize(original_data)
#data = normalize(hu_value_to_uint8(original_data))
data = np.tile(data, (1, 8, 1, 1))
data = torch.from_numpy(data).type(torch.float32)
        # Feed each slice to the model and accumulate the results for averaging
result = model.predict(data)
sum_result = sum_result + result
logging.info('time: {}, patiend_id: {}, z_index: {}, result: {}'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), patient_id, z_index, result))
mean_result = sum_result/(z_max - z_min + 1 - 2 * eliminate_num)
return patient_id, mean_result
# Run prediction with the 3D model
def predict_3d(model, label_id, folder_name, select_box=None, seg=False):
dicom_folder = os.path.join(cfg['dicom_folder'], folder_name)
ct_data = read_series_dicom(dicom_folder=dicom_folder)
z_min = int(select_box[0, 0])
z_max = int(select_box[0, 1])
contours = get_all_contours_by_labelId(label_id) if seg else None
if contours is not None:
data = ct_data.get_raw_image()
img_np = np.zeros((data.shape[0], data.shape[1], data.shape[2]))
for i in range(z_max-z_min+1):
_, _, img = base64_to_list(contours[i])
img_np[z_min+i] = img
data[img_np == 0] = -1000
        ct_data.set_raw_image(data)
result = predict(model, ct_data, select_box)
return result
# Model prediction
# Pass in the model, the CT data to be predicted and the corresponding select_box,
# then run the prediction and return the result
def predict(model, ct_data, select_box=None):
    # Process the data into shape (48, 256, 256)
if select_box[0, 0] == select_box[0, 1]:
original_data = get_crop_data_padding(ct_data=ct_data, select_box=select_box,
crop_size=cfg['train_crop_size'])
else:
print("预测")
original_data = crop_ct_data(ct_data=ct_data, select_box=select_box,
crop_size=cfg['train_crop_size'])
    # Alternatively, read a preprocessed npy file directly and run prediction on it
#npy_path = './cls_train/data/npy_data'
#original_data = get_data_from_file(npy_path=npy_path ,subject_id='cls_1010/395.npy')
#original_data = original_data[np.newaxis]
#print(original_data.shape)
original_data = augment_data(original_data)
original_data = original_data[np.newaxis]
keys = generate_general_permute_keys()
indexes_list = generate_general_indexs()
predict_num = len(indexes_list)
batch_size = 1
n_channels = 1
result = None
for start_size in range(0, predict_num, batch_size):
"""copy_data = original_data_3d.copy()
data = torch.from_numpy(copy_data)
#将数据平移
lower_bound, upper_bound = -10, 10
shift_x, shift_y = random.randint(lower_bound, upper_bound), random.randint(lower_bound, upper_bound)
new_array = np.full(data.shape, -1000, dtype=float)
start_x, start_y = max(0, shift_x), max(0, shift_y)
end_x, end_y = min(256, 256 + shift_x), min(256, 256 + shift_y)
#print(start_x, start_y, end_x, end_y)
new_array[:, start_x:end_x, start_y:end_y] = data[:, max(0, -shift_x):min(256, 256-shift_x), max(0, -shift_y):min(256, 256-shift_y)]
original_data = new_array
original_data = augment_data(original_data)
original_data = original_data[np.newaxis]"""
length = min(batch_size, predict_num - start_size)
cnn_datas = []
for i in range(length):
indexes = indexes_list[start_size + i]
cnn_data = []
for j in indexes[:n_channels]:
data = permute_data(original_data, keys[j])
data = data.transpose(0, 3, 1, 2)
data = data[:, 104:152, :, :]
cnn_data.append(data)
cnn_datas.append(cnn_data)
cnn_datas = cnn_datas[0]
cnn_datas = np.array(cnn_datas, np.float32)
#cnn_datas = torch.from_numpy(cnn_datas)
cnn_datas = hu_normalize(cnn_datas)
cnn_datas = torch.from_numpy(cnn_datas)
cnn_datas = cnn_datas.to('cuda')
temp_result = model.predict(cnn_datas)
temp_result = torch.from_numpy(temp_result)
        if result is None:
result = temp_result
else:
result = torch.cat((result, temp_result), 0)
#print(result)
result = torch.mean(result, dim=0).numpy()
"""#对数据进行归一化
data = hu_normalize(original_data)
#data = model.normalize(original_data)
#将数据进行指定处理
data = np.tile(data, (1, 1, 1, 1, 1))
#print(data.shape)
data = torch.from_numpy(data)
#data = get_data(data)
#print(data.shape)
data = data.type(torch.float32)
result = model.predict(data)"""
return result
# Fetch all entries for the given node_time from the database and predict on every one of them; mode=None means every slice of the bounding box is predicted
def predict_all_train_data(models, node_time, mode=None, start_label_id=0, end_label_id=0, seg=False, is_2d=False, threshold=0):
folder_names, select_boxs, label_ids, patient_ids, series_instance_uids= select_series_by_node_time(node_time)
select_boxs = np.array(select_boxs)
error_num = 0
logging.info('time: {}, node_time : {} 训练集测试结果展示:\n'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), node_time))
print(len(folder_names))
for index in range(len(folder_names)):
print(index)
if label_ids[index] > start_label_id and label_ids[index] < 7764:
#if label_ids[index] > start_label_id:
if mode is None:
folder_name = str(patient_ids[index])+'-'+str(series_instance_uids[index])
sum_result = 0
for i in range(len(models)):
patient_id, mean_result = predict_all_series_orignal(models[i], folder_name,
select_boxs[index], label_id=label_ids[index],
seg=seg, is_2d=is_2d, threshold=threshold)
#sum_result += mean_result
#print(type(index))
logging.info(
'time: {}, patiend_id: {}, label_id: {}, model_{},result: {}'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), patient_id, label_ids[index], i+1, mean_result))
"""summary_result = sum_result/len(models)
logging.info(
'time: {}, patiend_id: {}, label_id: {}, summary result: {}\n'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), patient_id, label_ids[int(index)], summary_result))
if summary_result <= 0.5:
error_num += 1"""
else:
result = predict(models, folder_names[index], select_boxs[index])
print(folder_names[index].split('-')[0], ' : ', result[0, 0])
#print(folder_names[index], ' label_id: ', label_ids[index])
#predict(model, folder_names[index], select_boxs[index])
logging.info(
'time: {}, 预测总个数: {}, 出错个数: {}, 正确率:{}%'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), len(label_ids), error_num, (1 - round(error_num / len(label_ids), 4))*100))
def predict_all_train_data_orignal(models, node_time, mode=None, start_label_id=0, end_label_id=0, seg=False, is_2d=False, threshold=0):
folder_names, select_boxs, label_ids, patient_ids, series_instance_uids= select_series_by_node_time(node_time)
select_boxs = np.array(select_boxs)
error_num = 0
logging.info('time: {}, node_time : {} 训练集测试结果展示:\n'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), node_time))
print(len(folder_names))
for index in range(len(folder_names)):
print(index)
if label_ids[index] > start_label_id and label_ids[index] < 7764:
#if label_ids[index] > start_label_id:
if mode is None:
folder_name = str(patient_ids[index])+'-'+str(series_instance_uids[index])
sum_result = 0
for i in range(len(models)):
patient_id, mean_result = predict_all_series(models[i], folder_name,
select_boxs[index], label_id=label_ids[index],
seg=seg, is_2d=is_2d, threshold=threshold)
sum_result += mean_result[0,0]
#print(type(index))
logging.info(
'time: {}, patiend_id: {}, label_id: {}, model_{},result: {}'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), patient_id, label_ids[index], i+1, mean_result[0, 0]))
summary_result = sum_result/len(models)
logging.info(
'time: {}, patiend_id: {}, label_id: {}, summary result: {}\n'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), patient_id, label_ids[int(index)], summary_result))
if summary_result >= 0.5:
error_num += 1
else:
result = predict(models, folder_names[index], select_boxs[index])
print(folder_names[index].split('-')[0], ' : ', result[0, 0])
#print(folder_names[index], ' label_id: ', label_ids[index])
#predict(model, folder_names[index], select_boxs[index])
logging.info(
'time: {}, 预测总个数: {}, 出错个数: {}, 正确率:{}%'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), len(label_ids), error_num, (1 - round(error_num / len(label_ids), 4))*100))
# Run model prediction for the specified label_id
def predict_by_label_id(model, label_id, mode=None, seg=False):
    # Name of the folder containing the dicom files
folder_name, select_box, patient_id, series_instance_uid = select_signal_series(label_id=label_id)
folder_name = str(patient_id)+'-'+str(series_instance_uid)
select_box = np.array(select_box)
if mode is None:
predict_all_series(model, folder_name, select_box,
label_id=label_id, seg=seg)
else:
result = predict_3d(model,label_id=label_id, folder_name=folder_name, select_box=select_box, seg=True)
logging.info('time: {}, patiend_id: {}, label_id: {}, result: {}\n'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), patient_id, label_id, result))
return result[0]
#print(patient_id, ' : ', result[0, 0])
# Predict according to the specified label_ids
# is_2d=True uses the 2D model
def run(args, is_2d=False, threshold=0.4):
""""""
models = []
if is_2d:
pretrain_ckpt_list = load_all_pretrain_ckpt(cfg['pretrain_folder'])
for i in range(len(pretrain_ckpt_list)):
model = TorchModel_2d(pretrain_ckpt_list[i], args.GPU)
models.append(model)
node_times = [1016]
for node_time in node_times:
predict_all_train_data(models, node_time=node_time, mode=None,
start_label_id=7762, seg=True, is_2d=is_2d,
threshold=threshold)
else:
"""log_path_1 = '/home/lung/project/ai-project/log_validation_0819.log'
log_path_2 = '/home/lung/project/ai-project/log_validation_0822.log'
error_label_ids_1 = extract_error_label(log_path_1)
error_label_ids_2 = extract_error_label(log_path_2)
error_label_ids = error_label_ids_1 + error_label_ids_2
error_label_ids = [int(label_id) for label_id in error_label_ids]"""
model = TorchModel(cfg['pretrain_ckpt'], args.GPU)
        # Evaluate the training data and the test data together; the first entry of node_times must be the negative class and the second the positive class
node_times = cfg['node_times']
csv_path = cfg['csv_path']
for index in range(len(node_times)):
node_list = node_times[index]
            # Collect the label_ids of the training data for the current node_list
all_train_label_ids = get_csv_all_label_ids_bylabel(csv_path, node_list, label=index)
            # Collect the label_ids of the test data for the current node_list
all_label_ids = []
for node_time in node_list:
_, _, label_ids, _, _ = select_series_by_node_time(node_time)
all_label_ids += label_ids
all_test_label_ids = [label_id for label_id in all_label_ids if label_id not in all_train_label_ids and label_id > 432]
logging.info('类别:{} 训练数据集------------------------------------\n'.format(index))
sum = 0
error = 0
for label_id in all_train_label_ids:
sum += 1
result = predict_by_label_id(model=model, label_id=label_id, mode='3d', seg=True)
if (index == 0 and result > 0.5) or (index == 1 and result < 0.5):
error += 1
logging.info('time: {}, 总数: {}, 出错个数: {}, 正确率: {}%\n'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), sum, error,(1 - round(error / sum, 4))*100))
logging.info('类别:{} 测试数据集------------------------------------\n'.format(index))
sum = 0
error = 0
for label_id in all_test_label_ids:
sum += 1
result = predict_by_label_id(model=model, label_id=label_id, mode='3d', seg=True)
if (index == 0 and result > 0.5) or (index == 1 and result < 0.5):
error += 1
logging.info('time: {}, 总数: {}, 出错个数: {}, 正确率: {}%\n'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), sum, error,(1 - round(error / sum, 4))*100))
#all_train_label_ids_n = get_csv_all_label_ids_bylabel(csv_path, node_list, label=1)
"""all_label_ids = []
for node_time in node_times:
_, _, label_ids, _, _= select_series_by_node_time(node_time)
all_label_ids += label_ids"""
"""all_label_ids = [8065]
sum = 0
error = 0
for label_id in all_label_ids:
if label_id > 432 :
sum += 1
result = predict_by_label_id(model=model, label_id=label_id, mode='3d', seg=True)
if result < 0.5:
error += 1
logging.info('time: {}, 总数: {}, 出错个数: {}, 正确率: {}%\n'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), sum, error,(1 - round(error / sum, 4))*100))"""
def main():
logging.basicConfig(level=logging.INFO, filename=cfg['validation_filename'], filemode='a')
args = parser.parse_args()
run(args, is_2d=False)
def summary_acc():
logging.basicConfig(level=logging.INFO, filename=cfg['compute_accuracy_filename'], filemode='a')
validation_folder = '/home/lung/ai-project/cls_train/log/validation/cls_234567_2031/20240815'
result_1_train_log_path = os.path.join(validation_folder, '1_train.log')
result_1_test_log_path = os.path.join(validation_folder, '1_test.log')
result_0_train_log_path = os.path.join(validation_folder, '0_train.log')
result_0_test_log_path = os.path.join(validation_folder, '0_test.log')
for threshold in np.arange(0.4, 0.9, 0.01):
logging.info("------------------------------------------------------------------------")
compute_accuracy(result_1_train_log_path, threshold, positive=True, train=True)
compute_accuracy(result_0_train_log_path, threshold, positive=False, train=True)
compute_accuracy(result_1_test_log_path, threshold, positive=True, train=False)
compute_accuracy(result_0_test_log_path, threshold, positive=False, train=False)
if __name__ == '__main__':
#main()
summary_acc()
# coding=utf-8
import sys, os
cur_path = os.path.dirname(os.path.abspath(__file__))
while os.path.basename(cur_path) != 'cls_train':
cur_path = os.path.dirname(cur_path)
sys.path.append(cur_path)
import argparse
import logging
import numpy as np
import torch
import time
import re
import copy
import cv2
import random
from matplotlib import pyplot as plt
import csv
import pandas as pd
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
from data.db import read_series_dicom, process_single_ct
from pytorch_train.torch_model import TorchModel
from pytorch_train.torch_model_2d import TorchModel_2d
#from pytorch_train.train import cfg
from cls_utils.data_utils import crop_ct_data, get_crop_data_padding, get_crop_data_2d
from cls_utils.utils import hu_value_to_uint8, normalize, base64_to_list, hu_normalize
from data.db import select_signal_series, select_series_by_node_time, get_all_contours_by_labelId, extract_error_label
from cls_utils.data import load_json, train_all_label_id, get_csv_all_label_ids_bylabel
from cls_utils.augement import generate_general_permute_keys, generate_general_indexs, permute_data, augment_data
cfg = load_json("/df_lung/ai-project/cls_train/config/predict_20241112.json")
parser = argparse.ArgumentParser(description='predict data')
parser.add_argument('--GPU', default='0', type=str, help='GPU index')
# Compute the accuracy of the specified data from a validation log file
def compute_accuracy(log_path, threshold=0.5, positive=False, train=True):
with open(log_path, 'r') as f:
log_contents = f.readlines()
error_num = 0
sum_num = 0
    # Scan every line and pick out the final averaged result of each nodule
for line in log_contents:
match = re.search(r"result: \[(.*?)\]\n", line)
if match:
sum_num += 1
result = match.group(1)
result = float(result)
#print(result)
if positive:
if result < threshold:
error_num += 1
else:
if result > threshold:
error_num += 1
label = '正类' if positive else '负类'
train_or_test = '训练' if train else '测试'
logging.info("{}{}, 阈值:{}, 总个数: {:3}, 出错个数: {:3}, summary result 准确率: {:5}%".format(train_or_test, label, round(threshold, 2), sum_num, error_num, round(1-error_num/ sum_num, 4)*100))
#print("总个数: {}, 出错个数: {}, summary result 准确率:{}%".format(sum_num, error_num, (1-round(error_num/ sum_num, 4))*100))
def load_all_pretrain_ckpt(base):
"""
加载指定文件夹下的所有的pretrain_ckpt文件
例如加载文件./cls_train/best_cls_0704/cls_1010_2046/cls_1010_2046_1/cls_1010_204620230708-1710.ckpt,
只需将base=‘./cls_train/best_cls_0704/cls_1010_2046’即可
"""
all_cakpt_path = {}
result = []
for root, dirs, names in os.walk(base):
for name in names:
index = root.split('/')[-1].split('_')[-1]
path = os.path.join(root, name)
all_cakpt_path[int(index)] = path
for key in sorted(all_cakpt_path):
result.append(all_cakpt_path[key])
return result
# Treat every slice where the nodule appears as the central slice, fill the rest with invalid values, and predict
# seg=True keeps only the segmented nodule in each slice; other regions are filled with invalid values
# is_2d=True prepares the data in the format required by the 2D model
def predict_all_series(model, folder_name, select_box=None, label_id=None, seg=False, is_2d=False, threshold=0):
dicom_folder = os.path.join(cfg['dicom_folder'], folder_name)
ct_data = read_series_dicom(dicom_folder=dicom_folder)
patient_id = folder_name.split('-')[0]
z_min = int(select_box[0, 0])
z_max = int(select_box[0, 1])
contours = get_all_contours_by_labelId(label_id) if seg else None
if contours is not None:
data = ct_data.get_raw_image()
img_np = np.zeros((data.shape[0], data.shape[1], data.shape[2]))
for i in range(z_max-z_min+1):
_, _, img = base64_to_list(contours[i])
img_np[z_min+i] = img
data[img_np == 0] = -1000
        ct_data.set_raw_image(data)
origin_data_2d = data
#sum_result = torch.zeros(1,1)
eliminate_num = int((threshold * (z_max - z_min + 1)) / 2)
    # Initialize the input tensor
in_data = torch.zeros((1, 8, 256, 256))
for z_index in range(z_min + eliminate_num, z_max - eliminate_num + 1):
temp_select_box = copy.deepcopy(select_box)
temp_select_box[0, 0], temp_select_box[0, 1] = z_index, z_index
if is_2d:
data_2d = origin_data_2d[z_index]
original_data = get_crop_data_2d(data=data_2d, select_box=temp_select_box,
crop_size=cfg['train_crop_size_2d'])
else:
original_data = get_crop_data_padding(ct_data=ct_data, select_box=temp_select_box,
crop_size=cfg['train_crop_size'])
data = hu_normalize(original_data)
#data = normalize(hu_value_to_uint8(original_data))
data = np.tile(data, (1, 8, 1, 1))
data = torch.from_numpy(data).type(torch.float32)
in_data = torch.cat((in_data, data), dim=0)
    # Concatenate the slices, feed them to the model in one batch, then average the results
result = model.predict(in_data[1:])
logging.info('time: {}, patiend_id: {}, result: {}'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), patient_id, result))
result = np.mean(result)
#mean_result = sum_result/(z_max - z_min + 1 - 2 * eliminate_num)
return patient_id, result
def predict_all_series_orignal(model, folder_name, select_box=None, label_id=None, seg=False, is_2d=False, threshold=0):
dicom_folder = os.path.join(cfg['dicom_folder'], folder_name)
ct_data = read_series_dicom(dicom_folder=dicom_folder)
patient_id = folder_name.split('-')[0]
z_min = int(select_box[0, 0])
z_max = int(select_box[0, 1])
contours = get_all_contours_by_labelId(label_id) if seg else None
if contours is not None:
data = ct_data.get_raw_image()
img_np = np.zeros((data.shape[0], data.shape[1], data.shape[2]))
for i in range(z_max-z_min+1):
_, _, img = base64_to_list(contours[i])
img_np[z_min+i] = img
data[img_np == 0] = -1000
        ct_data.set_raw_image(data)
origin_data_2d = data
sum_result = torch.zeros(1, 1)
#sum_result = torch.zeros(1,1)
eliminate_num = int((threshold * (z_max - z_min + 1)) / 2)
for z_index in range(z_min + eliminate_num, z_max - eliminate_num + 1):
temp_select_box = copy.deepcopy(select_box)
temp_select_box[0, 0], temp_select_box[0, 1] = z_index, z_index
if is_2d:
data_2d = origin_data_2d[z_index]
original_data = get_crop_data_2d(data=data_2d, select_box=temp_select_box,
crop_size=cfg['train_crop_size_2d'])
#plt.imsave(f'/home/lung/project/ai-project/cls_train/log/image/03/{z_index}.png', original_data, cmap='gray')
#np.save(f'/home/lung/project/ai-project/cls_train/log/npy/03/{z_index}.npy', original_data)
#cv2.imwrite(f'/home/lung/project/ai-project/cls_train/log/image/01/{z_index}.png', original_data)
else:
original_data = get_crop_data_padding(ct_data=ct_data, select_box=temp_select_box,
crop_size=cfg['train_crop_size'])
data = hu_normalize(original_data)
#data = normalize(hu_value_to_uint8(original_data))
data = np.tile(data, (1, 8, 1, 1))
data = torch.from_numpy(data).type(torch.float32)
        # Feed each slice to the model and accumulate the results for averaging
result = model.predict(data)
sum_result = sum_result + result
logging.info('time: {}, patiend_id: {}, z_index: {}, result: {}'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), patient_id, z_index, result))
mean_result = sum_result/(z_max - z_min + 1 - 2 * eliminate_num)
return patient_id, mean_result
# Run prediction with the 3D model
def predict_3d(model, label_id, folder_name, select_box=None, seg=False):
dicom_folder = os.path.join(cfg['dicom_folder'], folder_name)
ct_data = read_series_dicom(dicom_folder=dicom_folder)
z_min = int(select_box[0, 0])
z_max = int(select_box[0, 1])
contours = get_all_contours_by_labelId(label_id) if seg else None
if contours is not None:
data = ct_data.get_raw_image()
img_np = np.zeros((data.shape[0], data.shape[1], data.shape[2]))
for i in range(z_max-z_min+1):
_, _, img = base64_to_list(contours[i])
img_np[z_min+i] = img
data[img_np == 0] = -1000
        ct_data.set_raw_image(data)
result = predict(model, ct_data, select_box)
return result
# Model prediction
# Pass in the model, the CT data to be predicted and the corresponding select_box,
# then run the prediction and return the result
def predict(model, ct_data, select_box=None):
    # Process the data into shape (48, 256, 256)
if select_box[0, 0] == select_box[0, 1]:
original_data = get_crop_data_padding(ct_data=ct_data, select_box=select_box,
crop_size=cfg['train_crop_size'])
else:
print("预测")
original_data = crop_ct_data(ct_data=ct_data, select_box=select_box,
crop_size=cfg['train_crop_size'])
    # Alternatively, read a preprocessed npy file directly and run prediction on it
#npy_path = './cls_train/data/npy_data'
#original_data = get_data_from_file(npy_path=npy_path ,subject_id='cls_1010/395.npy')
#original_data = original_data[np.newaxis]
#print(original_data.shape)
original_data = augment_data(original_data)
original_data = original_data[np.newaxis]
keys = generate_general_permute_keys()
indexes_list = generate_general_indexs()
predict_num = len(indexes_list)
batch_size = 1
n_channels = 1
result = None
for start_size in range(0, predict_num, batch_size):
"""copy_data = original_data_3d.copy()
data = torch.from_numpy(copy_data)
#将数据平移
lower_bound, upper_bound = -10, 10
shift_x, shift_y = random.randint(lower_bound, upper_bound), random.randint(lower_bound, upper_bound)
new_array = np.full(data.shape, -1000, dtype=float)
start_x, start_y = max(0, shift_x), max(0, shift_y)
end_x, end_y = min(256, 256 + shift_x), min(256, 256 + shift_y)
#print(start_x, start_y, end_x, end_y)
new_array[:, start_x:end_x, start_y:end_y] = data[:, max(0, -shift_x):min(256, 256-shift_x), max(0, -shift_y):min(256, 256-shift_y)]
original_data = new_array
original_data = augment_data(original_data)
original_data = original_data[np.newaxis]"""
length = min(batch_size, predict_num - start_size)
cnn_datas = []
for i in range(length):
indexes = indexes_list[start_size + i]
cnn_data = []
for j in indexes[:n_channels]:
data = permute_data(original_data, keys[j])
data = data.transpose(0, 3, 1, 2)
data = data[:, 104:152, :, :]
cnn_data.append(data)
cnn_datas.append(cnn_data)
cnn_datas = cnn_datas[0]
cnn_datas = np.array(cnn_datas, np.float32)
#print("cnn_datas.shape: ", cnn_datas.shape)
#cnn_datas = torch.from_numpy(cnn_datas)
cnn_datas = hu_normalize(cnn_datas)
cnn_datas = torch.from_numpy(cnn_datas)
cnn_datas = cnn_datas.to('cuda')
#print("cnn_datas shape: ", cnn_datas.shape)
temp_result = model.predict(cnn_datas)
temp_result = torch.from_numpy(temp_result)
        if result is None:
result = temp_result
else:
result = torch.cat((result, temp_result), 0)
#print(result)
result = torch.mean(result, dim=0).numpy()
"""#对数据进行归一化
data = hu_normalize(original_data)
#data = model.normalize(original_data)
#将数据进行指定处理
data = np.tile(data, (1, 1, 1, 1, 1))
#print(data.shape)
data = torch.from_numpy(data)
#data = get_data(data)
#print(data.shape)
data = data.type(torch.float32)
result = model.predict(data)"""
return result
# Fetch all entries for the given node_time from the database and predict on every one of them; mode=None means every slice of the bounding box is predicted
def predict_all_train_data(models, node_time, mode=None, start_label_id=0, end_label_id=0, seg=False, is_2d=False, threshold=0):
folder_names, select_boxs, label_ids, patient_ids, series_instance_uids= select_series_by_node_time(node_time)
select_boxs = np.array(select_boxs)
error_num = 0
logging.info('time: {}, node_time : {} 训练集测试结果展示:\n'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), node_time))
print(len(folder_names))
for index in range(len(folder_names)):
print(index)
if label_ids[index] > start_label_id and label_ids[index] < 7764:
#if label_ids[index] > start_label_id:
if mode is None:
folder_name = str(patient_ids[index])+'-'+str(series_instance_uids[index])
sum_result = 0
for i in range(len(models)):
patient_id, mean_result = predict_all_series_orignal(models[i], folder_name,
select_boxs[index], label_id=label_ids[index],
seg=seg, is_2d=is_2d, threshold=threshold)
#sum_result += mean_result
#print(type(index))
logging.info(
'time: {}, patiend_id: {}, label_id: {}, model_{},result: {}'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), patient_id, label_ids[index], i+1, mean_result))
"""summary_result = sum_result/len(models)
logging.info(
'time: {}, patiend_id: {}, label_id: {}, summary result: {}\n'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), patient_id, label_ids[int(index)], summary_result))
if summary_result <= 0.5:
error_num += 1"""
else:
result = predict(models, folder_names[index], select_boxs[index])
print(folder_names[index].split('-')[0], ' : ', result[0, 0])
#print(folder_names[index], ' label_id: ', label_ids[index])
#predict(model, folder_names[index], select_boxs[index])
logging.info(
'time: {}, 预测总个数: {}, 出错个数: {}, 正确率:{}%'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), len(label_ids), error_num, (1 - round(error_num / len(label_ids), 4))*100))
def predict_all_train_data_orignal(models, node_time, mode=None, start_label_id=0, end_label_id=0, seg=False, is_2d=False, threshold=0):
folder_names, select_boxs, label_ids, patient_ids, series_instance_uids= select_series_by_node_time(node_time)
select_boxs = np.array(select_boxs)
error_num = 0
logging.info('time: {}, node_time : {} 训练集测试结果展示:\n'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), node_time))
print(len(folder_names))
for index in range(len(folder_names)):
print(index)
if label_ids[index] > start_label_id and label_ids[index] < 7764:
#if label_ids[index] > start_label_id:
if mode is None:
folder_name = str(patient_ids[index])+'-'+str(series_instance_uids[index])
sum_result = 0
for i in range(len(models)):
patient_id, mean_result = predict_all_series(models[i], folder_name,
select_boxs[index], label_id=label_ids[index],
seg=seg, is_2d=is_2d, threshold=threshold)
sum_result += mean_result[0,0]
#print(type(index))
logging.info(
'time: {}, patiend_id: {}, label_id: {}, model_{},result: {}'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), patient_id, label_ids[index], i+1, mean_result[0, 0]))
summary_result = sum_result/len(models)
logging.info(
'time: {}, patiend_id: {}, label_id: {}, summary result: {}\n'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), patient_id, label_ids[int(index)], summary_result))
if summary_result >= 0.5:
error_num += 1
else:
result = predict(models, folder_names[index], select_boxs[index])
print(folder_names[index].split('-')[0], ' : ', result[0, 0])
#print(folder_names[index], ' label_id: ', label_ids[index])
#predict(model, folder_names[index], select_boxs[index])
logging.info(
'time: {}, 预测总个数: {}, 出错个数: {}, 正确率:{}%'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), len(label_ids), error_num, (1 - round(error_num / len(label_ids), 4))*100))
# Run model prediction for the specified label_id
def predict_by_label_id(model, label_id, mode=None, seg=False):
    # Name of the folder containing the dicom files
folder_name, select_box, patient_id, series_instance_uid = select_signal_series(label_id=label_id)
folder_name = str(patient_id)+'-'+str(series_instance_uid)
select_box = np.array(select_box)
if mode is None:
predict_all_series(model, folder_name, select_box,
label_id=label_id, seg=seg)
else:
result = predict_3d(model,label_id=label_id, folder_name=folder_name, select_box=select_box, seg=True)
logging.info('time: {}, patiend_id: {}, label_id: {}, result: {}\n'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), patient_id, label_id, result))
print(patient_id, ' : ', result[0])
return result[0]
print(patient_id, ' : ', result[0, 0])
# Predict according to the specified label_ids
# is_2d=True uses the 2D model
def run(args, is_2d=False, threshold=0.4):
""""""
models = []
if is_2d:
pretrain_ckpt_list = load_all_pretrain_ckpt(cfg['pretrain_folder'])
for i in range(len(pretrain_ckpt_list)):
model = TorchModel_2d(pretrain_ckpt_list[i], args.GPU)
models.append(model)
node_times = [1016]
for node_time in node_times:
predict_all_train_data(models, node_time=node_time, mode=None,
start_label_id=7762, seg=True, is_2d=is_2d,
threshold=threshold)
else:
"""log_path_1 = '/home/lung/project/ai-project/log_validation_0819.log'
log_path_2 = '/home/lung/project/ai-project/log_validation_0822.log'
error_label_ids_1 = extract_error_label(log_path_1)
error_label_ids_2 = extract_error_label(log_path_2)
error_label_ids = error_label_ids_1 + error_label_ids_2
error_label_ids = [int(label_id) for label_id in error_label_ids]"""
model = TorchModel(cfg['pretrain_ckpt'], args.GPU)
        # Evaluate the training data and the test data together; the first entry of node_times must be the negative class and the second the positive class
node_times = cfg['node_times']
csv_path = cfg['csv_path']
for index in range(len(node_times)):
node_list = node_times[index]
            # Collect the label_ids of the training data for the current node_list
all_train_label_ids = get_csv_all_label_ids_bylabel(csv_path, node_list, label=index)
            # Collect the label_ids of the test data for the current node_list
all_label_ids = []
for node_time in node_list:
_, _, label_ids, _, _ = select_series_by_node_time(node_time)
all_label_ids += label_ids
all_test_label_ids = [label_id for label_id in all_label_ids if label_id not in all_train_label_ids and label_id > 432]
logging.info('类别:{} 训练数据集------------------------------------\n'.format(index))
sum = 0
error = 0
for label_id in all_train_label_ids:
sum += 1
result = predict_by_label_id(model=model, label_id=label_id, mode='3d', seg=True)
if (index == 0 and result > 0.5) or (index == 1 and result < 0.5):
error += 1
logging.info('time: {}, 总数: {}, 出错个数: {}, 正确率: {}%\n'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), sum, error,(1 - round(error / sum, 4))*100))
logging.info('类别:{} 测试数据集------------------------------------\n'.format(index))
sum = 0
error = 0
for label_id in all_test_label_ids:
sum += 1
result = predict_by_label_id(model=model, label_id=label_id, mode='3d', seg=True)
if (index == 0 and result > 0.5) or (index == 1 and result < 0.5):
error += 1
logging.info('time: {}, 总数: {}, 出错个数: {}, 正确率: {}%\n'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), sum, error,(1 - round(error / sum, 4))*100))
#all_train_label_ids_n = get_csv_all_label_ids_bylabel(csv_path, node_list, label=1)
"""all_label_ids = []
for node_time in node_times:
_, _, label_ids, _, _= select_series_by_node_time(node_time)
all_label_ids += label_ids"""
"""all_label_ids = [8065]
sum = 0
error = 0
for label_id in all_label_ids:
if label_id > 432 :
sum += 1
result = predict_by_label_id(model=model, label_id=label_id, mode='3d', seg=True)
if result < 0.5:
error += 1
logging.info('time: {}, 总数: {}, 出错个数: {}, 正确率: {}%\n'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), sum, error,(1 - round(error / sum, 4))*100))"""
def main():
validation_log_file = cfg['validation_filename']
os.makedirs(os.path.dirname(validation_log_file), exist_ok=True)
if not os.path.exists(validation_log_file):
open(validation_log_file, 'a').close()
logging.basicConfig(level=logging.INFO, filename=cfg['validation_filename'], filemode='a')
args = parser.parse_args()
run(args, is_2d=False)
def summary_acc():
log_file = cfg['compute_accuracy_filename']
os.makedirs(os.path.dirname(log_file), exist_ok=True)
if not os.path.exists(log_file):
open(log_file, 'a').close()
logging.basicConfig(level=logging.INFO, filename=cfg['compute_accuracy_filename'], filemode='a')
validation_folder = '/df_lung/ai-project/cls_train/log/validation/cls_234567_2041/20241104'
result_1_train_log_path = os.path.join(validation_folder, '1_train.log')
result_1_test_log_path = os.path.join(validation_folder, '1_test.log')
result_0_train_log_path = os.path.join(validation_folder, '0_train.log')
result_0_test_log_path = os.path.join(validation_folder, '0_test.log')
for threshold in np.arange(0.4, 0.9, 0.01):
logging.info("------------------------------------------------------------------------")
compute_accuracy(result_1_train_log_path, threshold, positive=True, train=True)
compute_accuracy(result_0_train_log_path, threshold, positive=False, train=True)
compute_accuracy(result_1_test_log_path, threshold, positive=True, train=False)
compute_accuracy(result_0_test_log_path, threshold, positive=False, train=False)
def get_csv_from_log(threshold = 0.5):
validation_file = '/df_lung/ai-project/cls_train/log/validation/cls_234567_2041/20241104/log_validation_20241104.log'
    # Path of the log file
log_file_path = validation_file
    # Output spreadsheet paths
train_dataset_output_file_path = '训练集推理输出结果.xlsx'
test_dataset_output_file_path = '测试集推理输出结果.xlsx'
    # Lists used to collect the extracted rows
train_data_output_list = []
test_data_output_list = []
is_train = True
pos_count = None
    # Open the log file
with open(log_file_path, 'r', encoding='utf-8') as log_file:
        # Read the log file line by line
lines = log_file.readlines()
for i in range(0, len(lines)):
if "测试数据集-" in lines[i]:
is_train = False
if "patiend_id" in lines[i] and "label_id" in lines[i] and "result" in lines[i]:
                # Extract patient_id
patient_id_start = lines[i].find('patiend_id: ') + len('patiend_id: ')
patient_id_end = lines[i].find(',', patient_id_start)
patient_id = lines[i][patient_id_start:patient_id_end].strip()
                # Extract label_id
label_id_start = lines[i].find('label_id: ') + len('label_id: ')
label_id_end = lines[i].find(',', label_id_start)
label_id = lines[i][label_id_start:label_id_end].strip()
                # Extract predict_prob
predict_prob_start = lines[i].find('result: [') + len('result: [')
predict_prob_end = lines[i].find(']', predict_prob_start)
predict_prob = lines[i][predict_prob_start:predict_prob_end].strip()
if float(predict_prob) > threshold:
pos_count = 1
else:
pos_count = 0
                # Append the extracted row to the appropriate list
if is_train:
train_data_output_list.append([str(patient_id), str(label_id), predict_prob, pos_count])
else:
test_data_output_list.append([str(patient_id), str(label_id), predict_prob, pos_count])
    # Convert the lists to DataFrames
train_df = pd.DataFrame(train_data_output_list, columns=['patient_id', 'label_id', 'predict_prob', 'pos_count'])
train_df['patient_id'] = train_df['patient_id'].astype(str)
test_df = pd.DataFrame(test_data_output_list, columns=['patient_id', 'label_id', 'predict_prob', 'pos_count'])
test_df['patient_id'] = test_df['patient_id'].astype(str)
    # Save as xlsx files
train_df.to_excel(train_dataset_output_file_path, index=False)
test_df.to_excel(test_dataset_output_file_path, index=False)
print(f"结果文件写入\n{train_dataset_output_file_path}\n{test_dataset_output_file_path}")
if __name__ == '__main__':
# main()
# summary_acc()
# get_csv_from_log()
model = TorchModel(cfg['pretrain_ckpt'], GPUIndex="0")
print(f"路径: {cfg['pretrain_ckpt']}")
label_id = 6154
predict_by_label_id(model=model, label_id=label_id, mode='3d', seg=True)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math
from functools import partial
__all__ = [
'ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnet200'
]
def conv3x3x3(in_planes, out_planes, stride=1, dilation=1):
# 3x3x3 convolution with padding
return nn.Conv3d(
in_planes,
out_planes,
kernel_size=3,
dilation=dilation,
stride=stride,
padding=dilation,
bias=False)
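# Shortcut type 'A': subsample the identity path with a stride-`stride` average pool
# (kernel size 1) and zero-pad the missing channels (no learnable parameters), in
# contrast to the 1x1 conv + BatchNorm shortcut built in _make_layer for type 'B'.
# The Variable wrapper below is a legacy no-op on recent PyTorch versions.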
def downsample_basic_block(x, planes, stride, no_cuda=False):
out = F.avg_pool3d(x, kernel_size=1, stride=stride)
zero_pads = torch.Tensor(
out.size(0), planes - out.size(1), out.size(2), out.size(3),
out.size(4)).zero_()
if not no_cuda:
if isinstance(out.data, torch.cuda.FloatTensor):
zero_pads = zero_pads.cuda()
out = Variable(torch.cat([out.data, zero_pads], dim=1))
return out
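# BasicBlock: two 3x3x3 conv + BN (+ ReLU) layers with a residual connection,
# expansion factor 1; used by resnet10/18/34.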
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3x3(inplanes, planes, stride=stride, dilation=dilation)
self.bn1 = nn.BatchNorm3d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3x3(planes, planes, dilation=dilation)
self.bn2 = nn.BatchNorm3d(planes)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
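# Bottleneck: 1x1x1 -> 3x3x3 -> 1x1x1 convolutions with a residual connection and a
# channel expansion factor of 4; used by resnet50/101/152/200.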
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm3d(planes)
self.conv2 = nn.Conv3d(
planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False)
self.bn2 = nn.BatchNorm3d(planes)
self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm3d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
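# ResNet (segmentation variant): a 3D ResNet backbone in which layer3/layer4 keep the
# spatial resolution (stride 1 with dilation 2/4) and a small ConvTranspose3d + conv
# head (conv_seg) produces num_seg_classes output channels. The sample_input_* arguments
# are accepted but not used here.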
class ResNet(nn.Module):
def __init__(self,
block,
layers,
sample_input_D,
sample_input_H,
sample_input_W,
num_seg_classes,
shortcut_type='B',
no_cuda = False):
self.inplanes = 64
self.no_cuda = no_cuda
super(ResNet, self).__init__()
self.conv1 = nn.Conv3d(
1,
64,
kernel_size=7,
stride=(2, 2, 2),
padding=(3, 3, 3),
bias=False)
self.bn1 = nn.BatchNorm3d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
self.layer2 = self._make_layer(
block, 128, layers[1], shortcut_type, stride=2)
self.layer3 = self._make_layer(
block, 256, layers[2], shortcut_type, stride=1, dilation=2)
self.layer4 = self._make_layer(
block, 512, layers[3], shortcut_type, stride=1, dilation=4)
self.conv_seg = nn.Sequential(
nn.ConvTranspose3d(
512 * block.expansion,
32,
2,
stride=2
),
nn.BatchNorm3d(32),
nn.ReLU(inplace=True),
nn.Conv3d(
32,
32,
kernel_size=3,
stride=(1, 1, 1),
padding=(1, 1, 1),
bias=False),
nn.BatchNorm3d(32),
nn.ReLU(inplace=True),
nn.Conv3d(
32,
num_seg_classes,
kernel_size=1,
stride=(1, 1, 1),
bias=False)
)
for m in self.modules():
if isinstance(m, nn.Conv3d):
m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')
elif isinstance(m, nn.BatchNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
if shortcut_type == 'A':
downsample = partial(
downsample_basic_block,
planes=planes * block.expansion,
stride=stride,
no_cuda=self.no_cuda)
else:
downsample = nn.Sequential(
nn.Conv3d(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False), nn.BatchNorm3d(planes * block.expansion))
layers = []
layers.append(block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, dilation=dilation))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.conv_seg(x)
return x
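# ResNetClassifier: same backbone as above, but the segmentation head is replaced by
# AdaptiveMaxPool3d + Linear + Sigmoid, producing num_classes - 1 probabilities per
# sample (squeezed to shape (batch,) in the binary case). Note that the Linear layer
# hard-codes 2048 input features (512 * Bottleneck.expansion), so this class assumes a
# Bottleneck backbone; the out_channels argument is currently unused.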
class ResNetClassifier(nn.Module):
def __init__(self,
block,
layers,
sample_input_D,
sample_input_H,
sample_input_W,
shortcut_type='B',
num_classes=2,
out_channels=2048,
no_cuda = False):
self.inplanes = 64
self.no_cuda = no_cuda
super(ResNetClassifier, self).__init__()
self.conv1 = nn.Conv3d(
1,
64,
kernel_size=7,
stride=(2, 2, 2),
padding=(3, 3, 3),
bias=False)
self.bn1 = nn.BatchNorm3d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
self.layer2 = self._make_layer(
block, 128, layers[1], shortcut_type, stride=2)
self.layer3 = self._make_layer(
block, 256, layers[2], shortcut_type, stride=1, dilation=2)
self.layer4 = self._make_layer(
block, 512, layers[3], shortcut_type, stride=1, dilation=4)
for m in self.modules():
if isinstance(m, nn.Conv3d):
m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')
elif isinstance(m, nn.BatchNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()
self.cls = nn.Sequential(
nn.AdaptiveMaxPool3d((1, 1, 1)),
nn.Flatten(),
nn.Linear(in_features=2048, out_features=num_classes-1),
)
self.sigmoid = nn.Sigmoid()
def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
if shortcut_type == 'A':
downsample = partial(
downsample_basic_block,
planes=planes * block.expansion,
stride=stride,
no_cuda=self.no_cuda)
else:
downsample = nn.Sequential(
nn.Conv3d(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False), nn.BatchNorm3d(planes * block.expansion))
layers = []
layers.append(block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, dilation=dilation))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.cls(x)
x = self.sigmoid(x)
x = x.squeeze(1)
return x
def resnet10(**kwargs):
"""Constructs a ResNet-18 model.
"""
model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)
return model
def resnet18(**kwargs):
"""Constructs a ResNet-18 model.
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
return model
def resnet34(**kwargs):
"""Constructs a ResNet-34 model.
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
return model
def resnet50(**kwargs):
"""Constructs a ResNet-50 model.
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
return model
def resnet101(**kwargs):
"""Constructs a ResNet-101 model.
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
return model
def resnet152(**kwargs):
"""Constructs a ResNet-101 model.
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
return model
def resnet200(**kwargs):
"""Constructs a ResNet-101 model.
"""
model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs)
return model
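# Minimal smoke-test sketch (not part of the original pipeline): it builds the
# segmentation variant via resnet50 and the classification variant directly, then runs
# dummy forward passes on CPU. The input sizes below are arbitrary assumptions.
if __name__ == '__main__':
    seg_net = resnet50(sample_input_D=16, sample_input_H=32, sample_input_W=32,
                       num_seg_classes=1, shortcut_type='B', no_cuda=True)
    cls_net = ResNetClassifier(Bottleneck, [3, 4, 6, 3],
                               sample_input_D=16, sample_input_H=32, sample_input_W=32,
                               shortcut_type='B', num_classes=2, no_cuda=True)
    seg_net.eval()
    cls_net.eval()
    with torch.no_grad():
        dummy = torch.randn(1, 1, 16, 32, 32)
        print(seg_net(dummy).shape)  # expected: torch.Size([1, 1, 4, 8, 8])
        print(cls_net(dummy).shape)  # expected: torch.Size([1]) -- one sigmoid probability per sample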
import os
import sys
import torch
from torch.nn import DataParallel
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
#from model.net_cls_1024u_13 import Net
#from net.net_cls_1024u_230425 import Net
from net.net_cls_1024u_231025 import Net
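# TorchModel wraps the 3D classification Net for inference: it restores the checkpoint's
# 'state_dict', moves the model to GPU/CPU, optionally wraps it in DataParallel when
# several GPU indices are given (e.g. GPUIndex="0,1"), and predict() returns the raw
# network outputs as a numpy array.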
class TorchModel(object):
def __init__(self, model_path, GPUIndex):
super(TorchModel, self).__init__()
os.environ['CUDA_VISIBLE_DEVICES'] = GPUIndex
gpus = len(GPUIndex.split(','))
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
# model = Net(n_channels=1, n_diff_classes=1, n_base_filters=1)
model = Net(n_channels=1, n_diff_classes=1, n_base_filters=8)
        # Load the trained parameters into the model
if 'cuda' == self.device:
model_param = torch.load(model_path)
else:
model_param = torch.load(model_path, map_location=lambda storage, loc: storage)
print(model_path)
model.load_state_dict(model_param['state_dict'], strict=True)
        # Move the model to the target device
model = model.to(self.device)
        # Deploy the model across multiple GPUs
if 'cuda' == self.device and gpus > 1:
device_ids = list(range(gpus))
model = DataParallel(model, device_ids=device_ids)
print('GPUIndex: ', model.device_ids)
self.original_model = model
def predict(self, data):
self.original_model.eval()
with torch.no_grad():
data = data.to(self.device)
return self.original_model(data).cpu().numpy()
def normalize(self, data, min_value=-1000, max_value=600):
new_data = data
new_data[new_data < min_value] = min_value
new_data[new_data > max_value] = max_value
# normalize to [-1, 1]
new_data = 2.0 * (new_data - min_value) / (max_value - min_value) - 1
return new_data
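# Worked example of the normalization above (illustration only; the [-1000, 600] window
# presumably corresponds to CT Hounsfield units): values are clamped to [-1000, 600] and
# mapped linearly to [-1, 1], so -1000 -> -1.0, 600 -> 1.0, and e.g.
# -200 -> 2.0 * (-200 - (-1000)) / (600 - (-1000)) - 1 = 2.0 * 800 / 1600 - 1 = 0.0.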
import os
import sys
import torch
from torch.nn import DataParallel
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
#from model.net_cls_1024u_13 import Net
from net.net_cls_1024u_2d import Net
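# TorchModel_2d is the 2D counterpart of TorchModel; the main differences are that the
# network needs no constructor arguments and predict() applies sigmoid to the output
# before converting it to numpy.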
class TorchModel_2d(object):
def __init__(self, model_path, GPUIndex):
super(TorchModel_2d, self).__init__()
os.environ['CUDA_VISIBLE_DEVICES'] = GPUIndex
gpus = len(GPUIndex.split(','))
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
#print(self.device)
model = Net()
        # Load the trained parameters into the model
if 'cuda' == self.device:
model_param = torch.load(model_path)
else:
model_param = torch.load(model_path, map_location=lambda storage, loc: storage)
print(model_path)
model.load_state_dict(model_param['state_dict'], strict=True)
        # Move the model to the target device
#model = model.to(self.device)
        # Deploy the model across multiple GPUs
if 'cuda' == self.device:
device_ids = list(range(gpus))
model = DataParallel(model, device_ids=device_ids)
print('GPUIndex: ', model.device_ids)
self.original_model = model
def predict(self, data):
self.original_model.eval()
with torch.no_grad():
data = data.to(self.device)
return self.original_model(data).cpu().sigmoid().numpy()
def normalize(self, data, min_value=-1000, max_value=600):
new_data = data
new_data[new_data < min_value] = min_value
new_data[new_data > max_value] = max_value
# normalize to [-1, 1]
new_data = 2.0 * (new_data - min_value) / (max_value - min_value) - 1
return new_data
import os
import sys
import time
import argparse
import logging
import glob
import torch
from torch.utils.data import DataLoader
from torch.nn import DataParallel
from torch.optim import lr_scheduler
from torch.optim import Adam
import torch.nn.functional as F
import torch.nn as nn
#from torch.utils.tensorboard import SummaryWriter
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
from data.dataset import SubjectDataset
#from model.net_cls_1024u_13 import Net
from net.net_cls_1024u_231025 import Net
#from net.net_cls_1024u_2d import Net
from cls_utils.data import save_summary_data, load_json
import pathlib
cur_path = pathlib.Path.cwd()
#torch.manual_seed(0)
#torch.cuda.manual_seed_all(0)
# cfg = load_json("/home/lung/ai-project/cls_train/config/train.json")
cfg = load_json("/df_lung/ai-project/cls_train/config/train_20241112.json")
parser = argparse.ArgumentParser(description='train data')
parser.add_argument('--num_workers', default=1, type=int,
help='Number of workers for each data loader')
parser.add_argument('--GPU', default='0', type=str, help='GPU index')
# Run the model over the full training set once (one epoch)
def train_epoch(epoch, device, model, loss, optimizer, dataloader, summary):
print('迭代第{}次'.format(epoch+1))
model.train()
loss_sum = 0
time_now = time.time()
for step, (data, target) in enumerate(dataloader):
"""pos_weight = 1.0 / (target == 1).sum().item()
neg_weight = 1.0 / (target == 0).sum().item()
print(pos_weight, (target == 1).sum().item())
print(neg_weight, (target == 0).sum().item())
weights = torch.tensor([pos_weight if label == 1 else neg_weight for label in target]).reshape(target.shape[0], 1)
loss = nn.BCEWithLogitsLoss(weight=weights)
loss.to(device)"""
data, target = data.to(device), target.to(device)
output = model(data)
"""print('输出和标签:')
print("output: ", output.detach().reshape(1,-1))
print('target: ', target.detach().reshape(1, -1))"""
#l = loss(output.sigmoid(), target)
l = loss(output, target)
        # Clear parameter gradients
optimizer.zero_grad()
l.backward()
        # Update the parameters
optimizer.step()
loss_data = l.item()
loss_sum += loss_data
time_spent = time.time() - time_now
time_now = time.time()
        logging.info(
            '{}, Epoch : {:3d}, Step : {:3d}, Training Loss : {:.3f}, Step Time : {:.3f}s'
            .format(time.strftime('%Y-%m-%d %H:%M:%S'), epoch, step+1,
                    loss_data, time_spent))
summary['loss'] = loss_sum / len(dataloader)
return summary
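# run() builds the Net classifier, an Adam optimizer with a StepLR schedule, a BCELoss
# criterion and the SubjectDataset/DataLoader for the selected CSV split, optionally
# restores weights from cfg['ckpt_pretrain_path'], trains for cfg['epoch'] epochs
# (scheduler.step() is called at the start of each epoch; a checkpoint is written every
# 200 epochs and once more at the end), and finally saves the per-epoch loss curve as an
# image via save_summary_data.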
def run(args, cfg, folder_name='', model_index=0, is_2d=False, load_pretrain = False):
os.environ['CUDA_VISIBLE_DEVICES'] = args.GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
    gpus = len(args.GPU.split(','))  # number of GPUs, e.g. '0,1' -> 2
num_workers = args.num_workers * gpus
batch_size = cfg['batch_size'] * gpus
    # Initialize the model
model = Net(n_channels=cfg['n_channels'],
n_diff_classes=cfg['n_diff_classes'])
    # Define the optimizer and learning-rate schedule
optimizer = Adam(model.parameters(), lr=cfg['lr'])
scheduler = lr_scheduler.StepLR(optimizer, step_size=cfg['patience'], gamma=cfg['learning_rate_drop'])
    # Optionally resume training from a previously saved checkpoint
"""base_path = cfg['ckpt_save_path'] + folder_name if len(folder_name) > 1 else cfg['ckpt_save_path']
if is_2d:
ckpt_folder = os.path.join(base_path, cfg['train_csv_file'], cfg['train_csv_file']+'_'+str(model_index))
else:
ckpt_folder = os.path.join(base_path, cfg['train_csv_file'])
ckpt_path = os.listdir(ckpt_folder)
pretrain_ckpt_path = os.path.join(ckpt_folder,ckpt_path[0])
print('checkpoint filename:', pretrain_ckpt_path)
#pretrain_ckpt_path = os.path.join(cfg['ckpt_pretrain_path'], cfg["ckpt_file"])
model_param = torch.load(pretrain_ckpt_path)
model.load_state_dict(model_param['state_dict'], strict=True)"""
if load_pretrain:
pretrain_ckpt_path = cfg['ckpt_pretrain_path']
model_param = torch.load(pretrain_ckpt_path)
model.load_state_dict(model_param['state_dict'], strict=True)
print(f"加载之前训练的参数: {pretrain_ckpt_path}")
    # Put the model and the loss function on the same device
model = model.to(device=device)
    # Define the loss function
loss = nn.BCELoss()
loss.to(device)
    # Multi-GPU data-parallel training
if 'cuda' == device:
device_ids = list(range(gpus))
model = DataParallel(model, device_ids=device_ids)
    # Training dataset
if is_2d:
csv_path = os.path.join(cfg['csv_path'], folder_name) if len(folder_name) > 1 else cfg['csv_path']
"""train_csv_path = os.path.join(cfg['train_data_path'], csv_path,
cfg["train_csv_file"], 'train.csv')"""
train_csv_path = os.path.join(cfg['train_data_path'], csv_path,
cfg["train_csv_file"], cfg["train_csv_file"]+'_'+str(model_index),
'train.csv')
else:
#print('3d')
csv_path = os.path.join(cfg['csv_path'], folder_name) if len(folder_name) > 1 else cfg['csv_path']
train_csv_path = os.path.join(cfg['train_data_path'], csv_path,
cfg["train_csv_file"], 'train.csv')
train_npy_path = os.path.join(cfg['train_data_path'], cfg["npy_folder"])
dataset_train = SubjectDataset(train_npy_path, train_csv_path, is_train=True, is_2d=is_2d, augment=True, permute=True)
    # Print the number of training samples
print('train number:', len(dataset_train))
dataloader_train = DataLoader(dataset_train, batch_size=batch_size,
num_workers=num_workers, shuffle=True)
    # Keep all per-epoch training losses
summary_train_loss = []
summary_train = {'loss': float('inf')}
if is_2d:
base_path = cfg['ckpt_save_path'] + folder_name if len(folder_name) > 1 else cfg['ckpt_save_path']
ckpt_folder = os.path.join(base_path, cfg['train_csv_file'], cfg['train_csv_file']+'_'+str(model_index))
else:
base_path = cfg['ckpt_save_path'] + folder_name if len(folder_name) > 1 else cfg['ckpt_save_path']
ckpt_folder = os.path.join(base_path, cfg['train_csv_file'])
if not os.path.exists(ckpt_folder):
os.makedirs(ckpt_folder)
for epoch in range(cfg['epoch']):
scheduler.step()
logging.info('{}, Epoch : {:3d}, lr : {}'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), epoch, scheduler.get_last_lr()))
time_now = time.time()
summary_train = train_epoch(epoch=epoch, device=device, model=model, loss=loss,
optimizer=optimizer, dataloader=dataloader_train, summary=summary_train)
time_spent = time.time() - time_now
logging.info(
'{}, Epoch : {:3d}, Summary Training Loss : {:.3f}, '
'Run Time : {:.2f}'
.format(time.strftime('%Y-%m-%d %H:%M:%S'), epoch, summary_train['loss'], time_spent)
)
summary_train_loss.append(summary_train['loss'])
if (epoch + 1) % 200 == 0:
ckpt_path = os.path.join(ckpt_folder, cfg['train_csv_file'] + time.strftime('%Y%m%d-%H%M') + '.ckpt')
torch.save({'epoch': epoch,
'loss':summary_train['loss'],
'state_dict':model.module.state_dict()},
ckpt_path)
ckpt_path = os.path.join(ckpt_folder, cfg['train_csv_file'] + time.strftime('%Y%m%d-%H%M') + '.ckpt')
    # Save the final checkpoint together with the last loss
torch.save({'epoch': epoch,
'loss':summary_train['loss'],
'state_dict':model.module.state_dict()},
ckpt_path)
print(f"train done, save to: {ckpt_path}")
    # Save the training-loss curve as an image
img_path = os.path.join(cfg['train_data_path'], cfg['image_path'], cfg['train_csv_file'])
if not os.path.exists(img_path):
os.makedirs(img_path)
result_img_path = os.path.join(img_path, time.strftime('%Y%m%d-%H%M') + '.png')
save_summary_data(summary_trains=summary_train_loss, result_img_path=result_img_path)
print(f"save_summary_data: {result_img_path}")
def main(train_csv_file, folder_name='', load_pretrain = False):
filename = cfg['training_filename']
directory = os.path.dirname(filename)
if not os.path.exists(directory):
os.makedirs(directory)
if not os.path.exists(filename):
with open(filename, 'w') as file:
pass
print(f"文件 '{filename}' 已创建。")
else:
print(f"文件 '{filename}' 已存在。")
logging.basicConfig(level=logging.INFO, filename=cfg['training_filename'], filemode='a')
args = parser.parse_args()
    # Train every batch split of this binary-classification task;
    # the number of splits is derived automatically from the matching data folders
cfg['train_csv_file'] = train_csv_file
csv_path = os.path.join(cfg['csv_path'], folder_name)
print(f"cur_path: {cur_path}")
paths = os.path.join(cur_path, cfg['train_data_path'], csv_path, cfg['train_csv_file'], '*')
print(f"paths: {paths}")
model_sum = len(glob.glob(paths))
for i in range(model_sum):
print(f"run: {i}")
run(args, cfg, model_index=i+1, folder_name=folder_name, is_2d=False, load_pretrain = load_pretrain)
def run_signal_3d():
cfg = load_json("/home/lung/ai-project/cls_train/config/train.json")
logging.basicConfig(level=logging.INFO, filename=cfg['training_filename'], filemode='a')
args = parser.parse_args()
run(args, cfg, is_2d=False)
if __name__ == '__main__':
folder_names = ['08']
# base_path = '/home/lung/ai-project/cls_train/data/train_data/plus_3d_0818/subject_all_csv/'
base_path = '/df_lung/ai-project/cls_train/data/train_data/plus_3d_0818/subject_all_csv/'
# train_csv_file = "cls_20241112_2041"
# main(train_csv_file, folder_name="08", load_pretrain = True)
for name in folder_names:
train_path = base_path + name
for train_csv_file in os.listdir(train_path):
if train_csv_file == 'cls_1_2041':
main(train_csv_file, folder_name=name, load_pretrain = True)
import argparse
import os, sys
import pathlib
current_dir = pathlib.Path(__file__).parent.resolve()
while "cls_train" != os.path.basename(current_dir):
current_dir = current_dir.parent
sys.path.append(current_dir.as_posix())
os.environ["OMP_NUM_THREADS"] = "1"
from cls_utils.log_utils import get_logger
from sklearn.metrics import classification_report
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from pytorch_train.classification_model_2d3d import Net2d3d, Net2d, Net3d, init_modules
from pytorch_train.encoder_cls import S3dClassifier as NetS3d
from pytorch_train.encoder_cls import D2dClassifier as NetD2d
from pytorch_train.resnet import ResNetClassifier as NetResNet3d
from pytorch_train.resnet import Bottleneck
from data.dataset_2d3d import ClassificationDataset2d3d, ClassificationDataset3d, ClassificationDataset2d, custom_collate_fn_2d, custom_collate_fn_3d, custom_collate_fn_2d3d, cls_report_dict_to_string, ClassificationDatasetError2d, ClassificationDatasetError3d, ClassificationDatasetError2d3d, custom_collate_fn_2d_error, custom_collate_fn_3d_error, custom_collate_fn_2d3d_error
from transformers import get_cosine_schedule_with_warmup
from data.dataset_2d3d import ClassificationDatasetS3d, ClassificationDatasetErrorS3d, custom_collate_fn_s3d, custom_collate_fn_s3d_error
from data.dataset_2d3d import ClassificationDatasetResnet3d, ClassificationDatasetErrorResnet3d, custom_collate_fn_resnet3d, custom_collate_fn_resnet3d_error
from data.dataset_2d3d import ClassificationDatasetD2d, ClassificationDatasetErrorD2d, custom_collate_fn_d2d, custom_collate_fn_d2d_error
import traceback
import pandas as pd
from pathlib import Path
from torch.optim import lr_scheduler, Adam
import random
import numpy as np
def set_seed(seed=1004):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def predict_on_single_gpu(net_id, model, dataloader=None, criterion=None, return_info_list=False, return_classification_report=False, return_cls_report_dict=False, threshold=0.5):
"""
    Evaluate the model on the given dataloader; optionally return per-step loss info and a classification report.
"""
model.eval()
all_preds = []
all_labels = []
all_label_id = []
all_preds_2d = []
all_preds_3d = []
all_labels_2d = []
all_labels_3d = []
all_label_id_2d = []
all_label_id_3d = []
val_info_list = []
epoch = 0
device = next(model.parameters()).device
with torch.no_grad():
for step, batch_data in enumerate(dataloader):
if net_id == "2d3d":
label_id_2d, data_2d, label_id_3d, data_3d, label_2d, label_3d = batch_data
if return_classification_report:
all_label_id_2d.extend(label_id_2d)
all_label_id_3d.extend(label_id_3d)
all_labels_2d.extend(label_2d.numpy())
all_labels_3d.extend(label_3d.numpy())
data_2d = data_2d.to(device)
data_3d = data_3d.to(device)
label_2d = label_2d.to(device)
label_3d = label_3d.to(device)
output_2d, output_3d = model(data_2d, data_3d)
if return_classification_report:
all_preds_2d.extend(output_2d.detach().cpu().numpy())
all_preds_3d.extend(output_3d.detach().cpu().numpy())
else:
label_id, data, label = batch_data
if return_classification_report:
all_label_id.extend(label_id)
all_labels.extend(label.numpy())
data = data.to(device)
label = label.to(device)
output = model(data)
if return_classification_report:
all_preds.extend(output.detach().cpu().numpy())
if return_info_list:
if net_id == "2d3d":
loss_2d = criterion(output_2d, label_2d)
loss_3d = criterion(output_3d, label_3d)
loss = loss_2d.sum() + loss_3d.sum()
else:
loss = criterion(output, label)
val_info_list.append([epoch, step+1, loss.item()])
torch.cuda.empty_cache()
cls_report = None
cls_report_str = ""
cls_report_str_2d = ""
cls_report_str_3d = ""
if return_classification_report:
if net_id == "2d3d":
all_preds_2d = np.array(all_preds_2d)
all_preds_3d = np.array(all_preds_3d)
all_labels_2d = np.array(all_labels_2d)
all_labels_3d = np.array(all_labels_3d)
all_preds_2d = (all_preds_2d > threshold).astype(np.int32)
all_preds_3d = (all_preds_3d > threshold).astype(np.int32)
cls_report_2d = classification_report(
all_labels_2d, all_preds_2d,
labels=[0, 1],
target_names=['negative', 'positive'],
zero_division=0,
output_dict=return_cls_report_dict
)
cls_report_3d = classification_report(
all_labels_3d, all_preds_3d,
labels=[0, 1],
target_names=['negative', 'positive'],
zero_division=0,
output_dict=return_cls_report_dict
)
cls_report_str_2d = f"2d_cls_report: \n{cls_report_dict_to_string(cls_report_2d)}\n"
cls_report_str_3d = f"3d_cls_report: \n{cls_report_dict_to_string(cls_report_3d)}\n"
cls_report_str = cls_report_str_2d + cls_report_str_3d
all_preds_2d = []
all_preds_3d = []
all_labels_2d = []
all_labels_3d = []
else:
all_labels = np.array(all_labels)
all_preds = np.array(all_preds)
all_preds = (all_preds > threshold).astype(np.int32)
cls_report = classification_report(
all_labels, all_preds,
labels=[0, 1],
target_names=['negative', 'positive'],
zero_division=0,
output_dict=return_cls_report_dict
)
cls_report_str = f"cls_report: \n{cls_report_dict_to_string(cls_report)}\n"
all_preds = []
all_labels = []
if net_id == "2d3d":
return all_label_id_2d, all_label_id_3d, all_preds_2d, all_preds_3d, all_labels_2d, all_labels_3d, val_info_list, cls_report_2d, cls_report_3d, cls_report_str
else:
return all_label_id, all_preds, all_labels, val_info_list, cls_report, cls_report_str
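# start_train_ddp drives training for a single net_id ("2d", "3d", "2d3d", "s3d",
# "resnet3d" or "d2d"): it selects the matching dataset/collate_fn, wraps the model in
# DistributedDataParallel when more than one device index is configured, trains with
# Adam (optionally ZeroRedundancyOptimizer) under a cosine schedule with warmup,
# validates every val_interval steps and keeps the checkpoint with the best averaged
# positive/negative val_metric, and can optionally re-train on misclassified samples.
# When train_on_error_data_flag is set, each entry of train_on_error_data_epoch_dict_list
# is expected to look roughly like this (illustrative values only; the real ones come
# from the caller's config):
#   {"train_epoch": 5, "positive_threshold": 0.2, "negative_threshold": 0.8,
#    "error_train_batch_size": 8, "error_train_epoch": 2}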
def start_train_ddp(model, config, args):
logger = config['logger']
net_id = config['net_id']
train_2d3d_data_2d_csv_file = config['train_2d3d_data_2d_csv_file']
val_2d3d_data_2d_csv_file = config['val_2d3d_data_2d_csv_file']
test_2d3d_data_2d_csv_file = config['test_2d3d_data_2d_csv_file']
train_2d3d_data_3d_csv_file = config['train_2d3d_data_3d_csv_file']
val_2d3d_data_3d_csv_file = config['val_2d3d_data_3d_csv_file']
test_2d3d_data_3d_csv_file = config['test_2d3d_data_3d_csv_file']
train_csv_file = config['train_csv_file']
val_csv_file = config['val_csv_file']
pos_label = config['pos_label']
neg_label = config['neg_label']
train_batch_size = config['train_batch_size']
val_batch_size = config['val_batch_size']
val_metric = config['val_metric']
num_workers = config['num_workers']
device_index = config['device_index']
num_epochs = config['num_epochs']
weight_decay = config['weight_decay']
lr = config['lr']
criterion = config['criterion']
val_interval = config['val_interval']
save_dir = config['save_dir']
threshold = config['threshold']
logger_train_cls_report_flag = config['logger_train_cls_report_flag']
train_on_error_data_flag = config['train_on_error_data_flag']
train_on_error_data_epoch_dict_list = config['train_on_error_data_epoch_dict_list'] if train_on_error_data_flag else []
train_on_error_data_epoch_dict_list = sorted(train_on_error_data_epoch_dict_list, key=lambda x: x["train_epoch"])
current_train_epoch = 0
if len(device_index) > 1:
world_size = torch.distributed.get_world_size()
if len(device_index) > 1 and args.local_rank not in [-1, 0]:
torch.distributed.barrier()
if net_id == "2d":
train_dataset = ClassificationDataset2d(train_csv_file, data_info="train_dataset_2d")
val_dataset = ClassificationDataset2d(val_csv_file, data_info="val_dataset_2d")
custom_collate_fn = custom_collate_fn_2d
train_error_dataset_class = ClassificationDatasetError2d
custom_collate_fn_error = custom_collate_fn_2d_error
elif net_id == "3d":
train_dataset = ClassificationDataset3d(train_csv_file, data_info="train_dataset_3d")
val_dataset = ClassificationDataset3d(val_csv_file, data_info="val_dataset_3d")
custom_collate_fn = custom_collate_fn_3d
train_error_dataset_class = ClassificationDatasetError3d
custom_collate_fn_error = custom_collate_fn_3d_error
elif net_id == "2d3d":
train_dataset = ClassificationDataset2d3d(train_2d3d_data_2d_csv_file, train_2d3d_data_3d_csv_file, data_info="train_dataset_2d3d")
val_dataset = ClassificationDataset2d3d(val_2d3d_data_2d_csv_file, val_2d3d_data_3d_csv_file, data_info="val_dataset_2d3d")
custom_collate_fn = custom_collate_fn_2d3d
train_error_dataset_class = ClassificationDatasetError2d3d
custom_collate_fn_error = custom_collate_fn_2d3d_error
elif net_id == "s3d":
train_dataset = ClassificationDatasetS3d(train_csv_file, data_info="train_dataset_S3d")
val_dataset = ClassificationDatasetS3d(val_csv_file, data_info="val_dataset_S3d")
custom_collate_fn = custom_collate_fn_s3d
train_error_dataset_class = ClassificationDatasetErrorS3d
custom_collate_fn_error = custom_collate_fn_s3d_error
elif net_id == "resnet3d":
train_dataset = ClassificationDatasetResnet3d(train_csv_file, data_info="train_dataset_resnet3d")
val_dataset = ClassificationDatasetResnet3d(val_csv_file, data_info="val_dataset_resnet3d")
custom_collate_fn = custom_collate_fn_resnet3d
train_error_dataset_class = ClassificationDatasetErrorResnet3d
custom_collate_fn_error = custom_collate_fn_resnet3d_error
elif net_id == "d2d":
train_dataset = ClassificationDatasetD2d(train_csv_file, data_info="train_dataset_d2d")
val_dataset = ClassificationDatasetD2d(val_csv_file, data_info="val_dataset_d2d")
custom_collate_fn = custom_collate_fn_d2d
train_error_dataset_class = ClassificationDatasetErrorD2d
custom_collate_fn_error = custom_collate_fn_d2d_error
else:
raise ValueError(f"net_id {net_id} not supported")
if len(device_index) > 1 and args.local_rank == 0:
torch.distributed.barrier()
if len(device_index) > 1:
train_sampler = DistributedSampler(train_dataset, shuffle=True, rank=args.local_rank, num_replicas=world_size)
val_sampler = DistributedSampler(val_dataset, shuffle=False, rank=args.local_rank, num_replicas=world_size)
train_data_loader = DataLoader(
train_dataset,
batch_size=train_batch_size // len(device_index),
drop_last=False,
shuffle=False,
num_workers=num_workers,
sampler=train_sampler,
collate_fn=custom_collate_fn
)
val_data_loader = DataLoader(
val_dataset,
batch_size=val_batch_size // len(device_index),
drop_last=False,
shuffle=False,
num_workers=num_workers,
sampler=val_sampler,
collate_fn=custom_collate_fn
)
else:
train_data_loader = DataLoader(
train_dataset,
batch_size=train_batch_size,
drop_last=False,
shuffle=True,
num_workers=num_workers,
collate_fn=custom_collate_fn
)
val_data_loader = DataLoader(
val_dataset,
batch_size=val_batch_size,
drop_last=False,
shuffle=False,
num_workers=num_workers,
collate_fn=custom_collate_fn
)
logger.info(f"start_train_ddp, net_id: {net_id}, device_index: {device_index}")
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, train data total batch: {len(train_data_loader)}, val data total batch: {len(val_data_loader)}")
logger.info(f"local rank: {args.local_rank}, world size: {world_size}")
else:
logger.info(f"local rank: {device_index[0]}, train data total batch: {len(train_data_loader)}, val data total batch: {len(val_data_loader)}")
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, train with DistributedDataParallel")
model = DistributedDataParallel(
model.cuda(),
device_ids=[args.local_rank],
output_device=args.local_rank
)
else:
logger.info(f"local rank: {device_index[0]}, train with single gpu")
model = model.to(device_index[0])
if args.use_zero:
optimizer = ZeroRedundancyOptimizer(
model.parameters(),
optimizer_class=torch.optim.Adam,
lr=lr,
weight_decay=weight_decay
)
else:
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, train with optimizer: {optimizer}")
else:
logger.info(f"local rank: {device_index[0]}, train with optimizer: {optimizer}")
lr_scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=10,
num_training_steps=len(train_data_loader) * num_epochs
)
train_info_list = []
val_info_list = []
best_val_metric_score = 0
current_step = 0
if val_interval is None:
val_interval = len(train_data_loader)
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, before train batch count: {len(train_data_loader)}, val interval batch count: {val_interval}")
else:
logger.info(f"local rank: {device_index[0]}, before train batch count: {len(train_data_loader)}, val interval batch count: {val_interval}")
train_all_preds = []
train_all_labels = []
train_all_label_id = []
train_all_label_id_2d = []
train_all_label_id_3d = []
train_all_preds_2d = []
train_all_preds_3d = []
train_all_labels_2d = []
train_all_labels_3d = []
error_data_dict = {}
error_data_2d_dict = {}
error_data_3d_dict = {}
print_init_params(model, string=f"start_train_ddp, train_{net_id}", logger=logger)
for epoch in range(num_epochs):
model.train()
if len(device_index) > 1:
train_sampler.set_epoch(epoch)
current_train_epoch += 1
for step, batch_data in enumerate(train_data_loader):
current_step += 1
if net_id == "2d3d":
label_id_2d, data_2d, label_2d, z_index_2d, rotate_count_2d, label_id_3d, data_3d, label_3d, z_index_3d, rotate_count_3d = batch_data
if logger_train_cls_report_flag or (train_on_error_data_flag and len(train_on_error_data_epoch_dict_list) > 0 and current_train_epoch == train_on_error_data_epoch_dict_list[0]["train_epoch"]):
train_all_label_id_2d.extend(label_id_2d)
train_all_label_id_3d.extend(label_id_3d)
train_all_labels_2d.append(label_2d.cpu())
train_all_labels_3d.append(label_3d.cpu())
if len(device_index) > 1:
data_2d = data_2d.cuda()
data_3d = data_3d.cuda()
label_2d = label_2d.cuda()
label_3d = label_3d.cuda()
else:
data_2d = data_2d.to(device_index[0])
data_3d = data_3d.to(device_index[0])
label_2d = label_2d.to(device_index[0])
label_3d = label_3d.to(device_index[0])
y_pred_2d, y_pred_3d = model(data_2d, data_3d)
if logger_train_cls_report_flag or (train_on_error_data_flag and len(train_on_error_data_epoch_dict_list) > 0 and current_train_epoch == train_on_error_data_epoch_dict_list[0]["train_epoch"]):
train_all_preds_2d.append(y_pred_2d.detach().cpu())
train_all_preds_3d.append(y_pred_3d.detach().cpu())
else:
label_id, data, label, z_index, rotate_count = batch_data
if logger_train_cls_report_flag or (train_on_error_data_flag and len(train_on_error_data_epoch_dict_list) > 0 and current_train_epoch == train_on_error_data_epoch_dict_list[0]["train_epoch"]):
train_all_label_id.extend(label_id)
train_all_labels.append(label.cpu())
if len(device_index) > 1:
data = data.cuda()
label = label.cuda()
else:
data = data.to(device_index[0])
label = label.to(device_index[0])
y_pred = model(data)
if logger_train_cls_report_flag or (train_on_error_data_flag and len(train_on_error_data_epoch_dict_list) > 0 and current_train_epoch == train_on_error_data_epoch_dict_list[0]["train_epoch"]):
train_all_preds.append(y_pred.detach().cpu())
if net_id == "2d3d":
loss = criterion(y_pred_2d, label_2d).sum() + criterion(y_pred_3d, label_3d).sum()
else:
loss = criterion(y_pred, label).sum()
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, train epoch: {epoch+1}, step: {step+1}, loss: {loss.item()}")
else:
logger.info(f"local rank: {device_index[0]}, train epoch: {epoch+1}, step: {step+1}, loss: {loss.item()}")
optimizer.zero_grad()
loss.backward()
optimizer.step()
if len(device_index) > 1:
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)
lr_scheduler.step()
if net_id == "2d3d" and current_step > 0 and current_step % 1000 == 0:
train_step_pt_file = f"train_epoch_{epoch+1}_step_{current_step}_net_{net_id}.pt"
save_train_step_pt_file = os.path.join(save_dir, train_step_pt_file)
if len(device_index) > 1:
torch.save(model.module.state_dict(), save_train_step_pt_file)
logger.info(f"local rank: {args.local_rank}, save train epoch step pt file to {save_train_step_pt_file}")
else:
torch.save(model.state_dict(), save_train_step_pt_file)
logger.info(f"local rank: {device_index[0]}, save train epoch step pt file to {save_train_step_pt_file}")
            # Validation
if current_step % val_interval == 0 and ((len(device_index) > 1 and args.local_rank == 0) or len(device_index) == 1):
if net_id == "2d3d":
_, _, _, _, _, _, current_val_info_list, current_val_classification_report_2d, current_val_classification_report_3d, current_val_classification_report_str = predict_on_single_gpu(
net_id,
model,
dataloader=val_data_loader,
criterion=criterion,
return_info_list=True,
return_classification_report=True,
return_cls_report_dict=True,
threshold=threshold
)
else:
_, _, _, current_val_info_list, current_val_classification_report, current_val_classification_report_str = predict_on_single_gpu(
net_id,
model,
dataloader=val_data_loader,
criterion=criterion,
return_info_list=True,
return_classification_report=True,
return_cls_report_dict=True,
threshold=threshold
)
val_info_list += current_val_info_list
if len(device_index) > 1:
logger.info(f"current_local_rank: {args.local_rank}, current_device_index: {device_index[args.local_rank]},\nepoch: {epoch+1}, step: {step+1}, loss: {current_val_info_list[-1][2]}\n 训练评测--验证集: classification report:\n {current_val_classification_report_str}")
else:
logger.info(f"current_local_rank: {device_index[0]}, current_device_index: {device_index[0]},\nepoch: {epoch+1}, step: {step+1}, loss: {current_val_info_list[-1][2]}\n 训练评测--验证集: classification report:\n {current_val_classification_report_str}")
if net_id == "2d3d":
current_score_list = [
current_val_classification_report_2d[pos_label][val_metric],
current_val_classification_report_3d[pos_label][val_metric],
current_val_classification_report_2d[neg_label][val_metric],
current_val_classification_report_3d[neg_label][val_metric]
]
current_val_metric_score = sum(current_score_list) / len(current_score_list)
else:
current_score_list = [
current_val_classification_report[pos_label][val_metric],
current_val_classification_report[neg_label][val_metric]
]
current_val_metric_score = sum(current_score_list) / len(current_score_list)
if current_val_metric_score > best_val_metric_score + 0.000000001:
best_val_metric_score = current_val_metric_score
save_path = os.path.join(save_dir, f"best_epoch_{epoch+1}_step_{step+1}_score_{best_val_metric_score:.4f}.pth")
if len(device_index) > 1:
torch.save(model.module.state_dict(), save_path)
else:
torch.save(model.state_dict(), save_path)
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, best val {val_metric} score: {best_val_metric_score}, save model to {save_path}")
else:
logger.info(f"local rank: {device_index[0]}, best val {val_metric} score: {best_val_metric_score}, save model to {save_path}")
        # Evaluate on the training set at the end of the epoch
if logger_train_cls_report_flag:
train_cls_report = None
if net_id == "2d3d":
cls_train_all_label_id_2d = train_all_label_id_2d[:]
cls_train_all_label_id_3d = train_all_label_id_3d[:]
cls_train_all_preds_2d = torch.cat(train_all_preds_2d, dim=0)
cls_train_all_preds_3d = torch.cat(train_all_preds_3d, dim=0)
cls_train_all_labels_2d = torch.cat(train_all_labels_2d, dim=0)
cls_train_all_labels_3d = torch.cat(train_all_labels_3d, dim=0)
cls_train_all_preds_2d_binary = (cls_train_all_preds_2d > threshold).int()
cls_train_all_preds_3d_binary = (cls_train_all_preds_3d > threshold).int()
# positive_indices_2d = torch.where(cls_train_all_labels_2d == 1)
# negative_indices_2d = torch.where(cls_train_all_labels_2d == 0)
# positive_indices_3d = torch.where(cls_train_all_labels_3d == 1)
# negative_indices_3d = torch.where(cls_train_all_labels_3d == 0)
positive_error_2d_indices = torch.where((cls_train_all_labels_2d == 1) & (cls_train_all_preds_2d_binary == 0))
negative_error_2d_indices = torch.where((cls_train_all_labels_2d == 0) & (cls_train_all_preds_2d_binary == 1))
positive_error_3d_indices = torch.where((cls_train_all_labels_3d == 1) & (cls_train_all_preds_3d_binary == 0))
negative_error_3d_indices = torch.where((cls_train_all_labels_3d == 0) & (cls_train_all_preds_3d_binary == 1))
def get_error_samples(indices, label_id, labels, preds):
error_label_ids = [label_id[i] for i in indices[0]]
error_labels = labels[indices]
error_preds = preds[indices]
return error_label_ids, error_labels, error_preds
error_label_id_2d_pos, error_labels_2d_pos, error_preds_2d_pos = get_error_samples(positive_error_2d_indices, cls_train_all_label_id_2d, cls_train_all_labels_2d, cls_train_all_preds_2d)
error_label_id_2d_neg, error_labels_2d_neg, error_preds_2d_neg = get_error_samples(negative_error_2d_indices, cls_train_all_label_id_2d, cls_train_all_labels_2d, cls_train_all_preds_2d)
error_label_id_3d_pos, error_labels_3d_pos, error_preds_3d_pos = get_error_samples(positive_error_3d_indices, cls_train_all_label_id_3d, cls_train_all_labels_3d, cls_train_all_preds_3d)
error_label_id_3d_neg, error_labels_3d_neg, error_preds_3d_neg = get_error_samples(negative_error_3d_indices, cls_train_all_label_id_3d, cls_train_all_labels_3d, cls_train_all_preds_3d)
train_cls_report_2d = classification_report(cls_train_all_labels_2d.flatten().cpu().numpy(), cls_train_all_preds_2d_binary.flatten().cpu().numpy(), labels=[0, 1], target_names=['negative', 'positive'], zero_division=0)
train_cls_report_3d = classification_report(cls_train_all_labels_3d.flatten().cpu().numpy(), cls_train_all_preds_3d_binary.flatten().cpu().numpy(), labels=[0, 1], target_names=['negative', 'positive'], zero_division=0)
train_cls_report = f"2d_cls_report: \n{train_cls_report_2d}\n3d_cls_report: \n{train_cls_report_3d}"
cls_error_info_str = ""
cls_error_info_str += f"net_id_{net_id}_2d分类: 预测错误的正样本信息, 个数: {len(error_labels_2d_pos)}:\n"
for label_id, label, pred in zip(error_label_id_2d_pos, error_labels_2d_pos, error_preds_2d_pos):
cls_error_info_str += f"label_id: {label_id}, label: {label}, pred: {pred}\n"
cls_error_info_str += f"net_id_{net_id}_2d分类: 预测错误的负样本信息, 个数: {len(error_labels_2d_neg)}:\n"
for label_id, label, pred in zip(error_label_id_2d_neg, error_labels_2d_neg, error_preds_2d_neg):
cls_error_info_str += f"label_id: {label_id}, label: {label}, pred: {pred}\n"
cls_error_info_str += f"net_id_{net_id}_3d分类: 预测错误的正样本信息, 个数: {len(error_labels_3d_pos)}:\n"
for label_id, label, pred in zip(error_label_id_3d_pos, error_labels_3d_pos, error_preds_3d_pos):
cls_error_info_str += f"label_id: {label_id}, label: {label}, pred: {pred}\n"
cls_error_info_str += f"net_id_{net_id}_3d分类: 预测错误的负样本信息, 个数: {len(error_labels_3d_neg)}:\n"
for label_id, label, pred in zip(error_label_id_3d_neg, error_labels_3d_neg, error_preds_3d_neg):
cls_error_info_str += f"label_id: {label_id}, label: {label}, pred: {pred}\n"
del cls_train_all_label_id_2d, cls_train_all_label_id_3d
del cls_train_all_preds_2d, cls_train_all_preds_3d, cls_train_all_labels_2d, cls_train_all_labels_3d
del cls_train_all_preds_2d_binary, cls_train_all_preds_3d_binary
# del positive_indices_2d, negative_indices_2d, positive_indices_3d, negative_indices_3d
del positive_error_2d_indices, negative_error_2d_indices, positive_error_3d_indices, negative_error_3d_indices
del error_label_id_2d_pos, error_labels_2d_pos, error_preds_2d_pos
del error_label_id_2d_neg, error_labels_2d_neg, error_preds_2d_neg
del error_label_id_3d_pos, error_labels_3d_pos, error_preds_3d_pos
del error_label_id_3d_neg, error_labels_3d_neg, error_preds_3d_neg
del train_cls_report_2d, train_cls_report_3d
torch.cuda.empty_cache()
else:
cls_train_all_label_id = train_all_label_id[:]
cls_train_all_labels = torch.cat(train_all_labels, dim=0)
cls_train_all_preds = torch.cat(train_all_preds, dim=0)
cls_train_all_preds_binary = (cls_train_all_preds > threshold).int()
# positive_indices = torch.where(cls_train_all_labels == 1)
# negative_indices = torch.where(cls_train_all_labels == 0)
positive_error_indices = torch.where((cls_train_all_labels == 1) & (cls_train_all_preds_binary == 0))
negative_error_indices = torch.where((cls_train_all_labels == 0) & (cls_train_all_preds_binary == 1))
def get_error_samples(indices, label_id, labels, preds):
error_label_ids = [label_id[i] for i in indices[0]]
error_labels = labels[indices]
error_preds = preds[indices]
return error_label_ids, error_labels, error_preds
error_label_id_pos, error_labels_pos, error_preds_pos = get_error_samples(positive_error_indices, cls_train_all_label_id, cls_train_all_labels, cls_train_all_preds)
error_label_id_neg, error_labels_neg, error_preds_neg = get_error_samples(negative_error_indices, cls_train_all_label_id, cls_train_all_labels, cls_train_all_preds)
train_cls_report = classification_report(cls_train_all_labels.flatten().cpu().numpy(), cls_train_all_preds_binary.flatten().cpu().numpy(), labels=[0, 1], target_names=['negative', 'positive'], zero_division=0)
train_cls_report = f"cls_report: \n{train_cls_report}"
cls_error_info_str = ""
cls_error_info_str += f"net_id_{net_id}_分类: 预测错误的正样本信息, 个数: {len(error_labels_pos)}:\n"
for label_id, label, pred in zip(error_label_id_pos, error_labels_pos, error_preds_pos):
cls_error_info_str += f"label_id: {label_id}, label: {label}, pred: {pred}\n"
cls_error_info_str += f"net_id_{net_id}_分类: 预测错误的负样本信息, 个数: {len(error_labels_neg)}:\n"
for label_id, label, pred in zip(error_label_id_neg, error_labels_neg, error_preds_neg):
cls_error_info_str += f"label_id: {label_id}, label: {label}, pred: {pred}\n"
del cls_train_all_label_id, cls_train_all_labels, cls_train_all_preds
del cls_train_all_preds_binary, positive_error_indices, negative_error_indices
del error_label_id_pos, error_labels_pos, error_preds_pos
del error_label_id_neg, error_labels_neg, error_preds_neg
torch.cuda.empty_cache()
train_info_list.append([epoch+1, step+1, loss.item(), str(train_cls_report)])
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, 评测训练集, train epoch: {epoch+1}, step: {step+1}, loss: {loss.item()}\n train classification report:\n {train_cls_report}")
logger.info(f"local rank: {args.local_rank}, 评测训练集, train epoch: {epoch+1}, 训练集分类错误的信息\n{cls_error_info_str}")
else:
logger.info(f"local rank: {device_index[0]}, 评测训练集, train epoch: {epoch+1}, step: {step+1}, loss: {loss.item()}\n train classification report:\n {train_cls_report}")
logger.info(f"local rank: {device_index[0]}, 评测训练集, train epoch: {epoch+1}, 训练集分类错误的信息\n{cls_error_info_str}")
        # Save a checkpoint at the end of the epoch
if len(device_index) > 1:
torch.distributed.barrier()
if (len(device_index) > 1 and args.local_rank == 0) or len(device_index) == 1:
train_epoch_pt_file = f"train_epoch_{epoch+1}_net_{net_id}.pt"
save_train_epoch_pt_file = os.path.join(save_dir, train_epoch_pt_file)
if len(device_index) > 1:
torch.save(model.module.state_dict(), save_train_epoch_pt_file)
else:
torch.save(model.state_dict(), save_train_epoch_pt_file)
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, train epoch: {epoch+1}, save model to {save_train_epoch_pt_file}")
else:
logger.info(f"local rank: {device_index[0]}, train epoch: {epoch+1}, save model to {save_train_epoch_pt_file}")
        # Optionally retrain on the misclassified training samples
if train_on_error_data_flag and len(train_on_error_data_epoch_dict_list) > 0 and current_train_epoch == train_on_error_data_epoch_dict_list[0]["train_epoch"]:
current_train_on_error_data_epoch_dict = train_on_error_data_epoch_dict_list.pop(0)
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, train_on_error_data_epoch_dict_list: {len(train_on_error_data_epoch_dict_list)}")
logger.info(f"local rank: {args.local_rank}, train on error data epochs: {current_train_on_error_data_epoch_dict['train_epoch']}, train epoch: {epoch+1}")
else:
logger.info(f"local rank: {device_index[0]}, train_on_error_data_epoch_dict_list: {len(train_on_error_data_epoch_dict_list)}")
logger.info(f"local rank: {device_index[0]}, train on error data epochs: {current_train_on_error_data_epoch_dict['train_epoch']}, train epoch: {epoch+1}")
if net_id == "2d3d":
cls_train_all_label_id_2d = train_all_label_id_2d[:]
cls_train_all_label_id_3d = train_all_label_id_3d[:]
cls_train_all_preds_2d = torch.cat(train_all_preds_2d, dim=0)
cls_train_all_preds_3d = torch.cat(train_all_preds_3d, dim=0)
cls_train_all_labels_2d = torch.cat(train_all_labels_2d, dim=0)
cls_train_all_labels_3d = torch.cat(train_all_labels_3d, dim=0)
                # Select positives predicted below positive_threshold and negatives predicted above negative_threshold
current_positive_threshold = current_train_on_error_data_epoch_dict["positive_threshold"]
current_negative_threshold = current_train_on_error_data_epoch_dict["negative_threshold"]
positive_error_2d_indices = torch.where((cls_train_all_labels_2d == 1) & (cls_train_all_preds_2d < current_positive_threshold))
negative_error_2d_indices = torch.where((cls_train_all_labels_2d == 0) & (cls_train_all_preds_2d > current_negative_threshold))
positive_error_3d_indices = torch.where((cls_train_all_labels_3d == 1) & (cls_train_all_preds_3d < current_positive_threshold))
negative_error_3d_indices = torch.where((cls_train_all_labels_3d == 0) & (cls_train_all_preds_3d > current_negative_threshold))
log_train_error_data_str = ""
log_train_error_data_str += f"train on error data, net_id_{net_id}_2d分类: 预测错误的正样本信息, 个数: {len(positive_error_2d_indices[0])}:\n"
for idx_index in positive_error_2d_indices[0]:
log_train_error_data_str += f"label_id: {cls_train_all_label_id_2d[idx_index]}, label: {cls_train_all_labels_2d[idx_index]}, pred: {cls_train_all_preds_2d[idx_index]}\n"
log_train_error_data_str += f"train on error data, net_id_{net_id}_2d分类: 预测错误的负样本信息, 个数: {len(negative_error_2d_indices[0])}:\n"
for idx_index in negative_error_2d_indices[0]:
log_train_error_data_str += f"label_id: {cls_train_all_label_id_2d[idx_index]}, label: {cls_train_all_labels_2d[idx_index]}, pred: {cls_train_all_preds_2d[idx_index]}\n"
log_train_error_data_str += f"train on error data, net_id_{net_id}_3d分类: 预测错误的正样本信息, 个数: {len(positive_error_3d_indices[0])}:\n"
for idx_index in positive_error_3d_indices[0]:
log_train_error_data_str += f"label_id: {cls_train_all_label_id_3d[idx_index]}, label: {cls_train_all_labels_3d[idx_index]}, pred: {cls_train_all_preds_3d[idx_index]}\n"
log_train_error_data_str += f"train on error data, net_id_{net_id}_3d分类: 预测错误的负样本信息, 个数: {len(negative_error_3d_indices[0])}:\n"
for idx_index in negative_error_3d_indices[0]:
log_train_error_data_str += f"label_id: {cls_train_all_label_id_3d[idx_index]}, label: {cls_train_all_labels_3d[idx_index]}, pred: {cls_train_all_preds_3d[idx_index]}\n"
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, train on error data, train_error_data_epochs: {current_train_on_error_data_epoch_dict['train_epoch']}\ntrain epoch: {epoch+1}, 训练集错误样本单独训练,训练集预测错误的信息:\n{log_train_error_data_str}")
else:
logger.info(f"local rank: {device_index[0]}, train on error data, train_error_data_epochs: {current_train_on_error_data_epoch_dict['train_epoch']}\ntrain epoch: {epoch+1}, 训练集错误样本单独训练,训练集预测错误的信息:\n{log_train_error_data_str}")
error_2d_indices = torch.cat((positive_error_2d_indices[0], negative_error_2d_indices[0]))
error_3d_indices = torch.cat((positive_error_3d_indices[0], negative_error_3d_indices[0]))
error_data_2d_label_id_list = [cls_train_all_label_id_2d[i] for i in error_2d_indices.cpu().numpy()]
error_data_3d_label_id_list = [cls_train_all_label_id_3d[i] for i in error_3d_indices.cpu().numpy()]
if len(error_data_2d_label_id_list) == 0 and len(error_data_3d_label_id_list) == 0:
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, train on error data, 预测错误的样本为0, 跳过错误样本单独训练, error_data_2d_label_id_list: {error_data_2d_label_id_list}\nerror_data_3d_label_id_list: {error_data_3d_label_id_list}")
else:
logger.info(f"local rank: {device_index[0]}, train on error data, 预测错误的样本为0, 跳过错误样本单独训练, error_data_2d_label_id_list: {error_data_2d_label_id_list}\nerror_data_3d_label_id_list: {error_data_3d_label_id_list}")
del cls_train_all_label_id_2d
del cls_train_all_label_id_3d
del cls_train_all_preds_2d
del cls_train_all_preds_3d
del cls_train_all_labels_2d
del cls_train_all_labels_3d
del current_positive_threshold
del current_negative_threshold
del positive_error_2d_indices
del negative_error_2d_indices
del positive_error_3d_indices
del negative_error_3d_indices
del log_train_error_data_str
del error_2d_indices
del error_3d_indices
del error_data_2d_label_id_list
del error_data_3d_label_id_list
torch.cuda.empty_cache()
continue
current_train_batch_size = current_train_on_error_data_epoch_dict["error_train_batch_size"]
if len(device_index) > 1:
error_file = f"train_error_data_local_rank_{args.local_rank}_current_train_epoch_{current_train_on_error_data_epoch_dict['train_epoch']}_[{num_epochs}]_train_error_data_epoch_{current_train_on_error_data_epoch_dict['error_train_epoch']}_net_{net_id}.csv"
else:
error_file = f"train_error_data_local_rank_{device_index[0]}_current_train_epoch_{current_train_on_error_data_epoch_dict['train_epoch']}_[{num_epochs}]_train_error_data_epoch_{current_train_on_error_data_epoch_dict['error_train_epoch']}_net_{net_id}.csv"
train_error_csv_file = os.path.join(save_dir, error_file)
error_data_2d_preds_list = [cls_train_all_preds_2d[i].cpu().numpy() for i in error_2d_indices.cpu().numpy()]
error_data_3d_preds_list = [cls_train_all_preds_3d[i].cpu().numpy() for i in error_3d_indices.cpu().numpy()]
error_data_2d_labels_list = [cls_train_all_labels_2d[i].cpu().numpy() for i in error_2d_indices.cpu().numpy()]
error_data_3d_labels_list = [cls_train_all_labels_3d[i].cpu().numpy() for i in error_3d_indices.cpu().numpy()]
if len(device_index) > 1:
print(f"local rank: {args.local_rank}, train on error data\nerror_data_2d_label_id_list: {len(error_data_2d_label_id_list)}\nerror_data_3d_label_id_list: {len(error_data_3d_label_id_list)}")
else:
print(f"local rank: {device_index[0]}, train on error data\nerror_data_2d_label_id_list: {len(error_data_2d_label_id_list)}\nerror_data_3d_label_id_list: {len(error_data_3d_label_id_list)}")
train_error_dataset = train_error_dataset_class(
csv_file=train_csv_file,
label_id_2d_list=error_data_2d_label_id_list,
label_id_3d_list=error_data_3d_label_id_list,
error_file=train_error_csv_file,
preds_2d=error_data_2d_preds_list,
preds_3d=error_data_3d_preds_list,
labels_2d=error_data_2d_labels_list,
labels_3d=error_data_3d_labels_list
)
train_error_dataloader = DataLoader(
train_error_dataset,
batch_size=current_train_batch_size,
shuffle=True,
num_workers=num_workers,
collate_fn = custom_collate_fn_error
)
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, train on error data, 原始训练数据, error_data_2d_label_id_list: {len(error_data_2d_label_id_list)}\nerror_data_3d_label_id_list: {len(error_data_3d_label_id_list)}")
logger.info(f"local rank: {args.local_rank}, train on error data, 原始训练数据, error_data_2d_label_id_list: {error_data_2d_label_id_list}\nerror_data_3d_label_id_list: {error_data_3d_label_id_list}")
logger.info(f"local rank: {args.local_rank}, train on error data, 处理训练数据, train_error_dataloader: {len(train_error_dataloader)}")
else:
logger.info(f"local rank: {device_index[0]}, train on error data, 原始训练数据, error_data_2d_label_id_list: {len(error_data_2d_label_id_list)}\nerror_data_3d_label_id_list: {len(error_data_3d_label_id_list)}")
logger.info(f"local rank: {device_index[0]}, train on error data, 原始训练数据, error_data_2d_label_id_list: {error_data_2d_label_id_list}\nerror_data_3d_label_id_list: {error_data_3d_label_id_list}")
logger.info(f"local rank: {device_index[0]}, train on error data, 处理训练数据, train_error_dataloader: {len(train_error_dataloader)}")
current_train_on_error_epochs = current_train_on_error_data_epoch_dict["error_train_epoch"]
for idx_train_error_data_epoch in range(current_train_on_error_epochs):
for idx_batch in train_error_dataloader:
idx_train_error_data_2d, idx_train_error_data_3d, idx_train_error_label_2d, idx_train_error_label_3d = idx_batch
if len(device_index) > 1:
idx_train_error_data_2d = idx_train_error_data_2d.cuda()
idx_train_error_data_3d = idx_train_error_data_3d.cuda()
idx_train_error_label_2d = idx_train_error_label_2d.cuda()
idx_train_error_label_3d = idx_train_error_label_3d.cuda()
else:
idx_train_error_data_2d = idx_train_error_data_2d.to(device_index[0])
idx_train_error_data_3d = idx_train_error_data_3d.to(device_index[0])
idx_train_error_label_2d = idx_train_error_label_2d.to(device_index[0])
idx_train_error_label_3d = idx_train_error_label_3d.to(device_index[0])
idx_train_error_y_pred_2d, idx_train_error_y_pred_3d = model(idx_train_error_data_2d, idx_train_error_data_3d)
idx_train_error_loss = criterion(idx_train_error_y_pred_2d, idx_train_error_label_2d) + criterion(idx_train_error_y_pred_3d, idx_train_error_label_3d)
optimizer.zero_grad()
idx_train_error_loss.backward()
optimizer.step()
if len(device_index) > 1:
torch.distributed.all_reduce(idx_train_error_loss, op=torch.distributed.ReduceOp.AVG)
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, train on error data, train_error_data_epoch: {idx_train_error_data_epoch+1}, train_error_epochs: {current_train_on_error_epochs}, current_train_epoch: {epoch+1}, train_epoch: {num_epochs}, loss: {idx_train_error_loss.item()}")
else:
logger.info(f"local rank: {device_index[0]}, train on error data, train_error_data_epoch: {idx_train_error_data_epoch+1}, train_error_epochs: {current_train_on_error_epochs}, current_train_epoch: {epoch+1}, train_epoch: {num_epochs}, loss: {idx_train_error_loss.item()}")
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, train on error data, 错误样本单独训练结束")
else:
logger.info(f"local rank: {device_index[0]}, train on error data, 错误样本单独训练结束")
del cls_train_all_label_id_2d
del cls_train_all_label_id_3d
del cls_train_all_preds_2d
del cls_train_all_preds_3d
del cls_train_all_labels_2d
del cls_train_all_labels_3d
del current_positive_threshold
del current_negative_threshold
del positive_error_2d_indices
del negative_error_2d_indices
del positive_error_3d_indices
del negative_error_3d_indices
del log_train_error_data_str
del error_2d_indices
del error_3d_indices
del error_data_2d_label_id_list
del error_data_3d_label_id_list
del current_train_batch_size
del error_file
del train_error_csv_file
del error_data_2d_preds_list
del error_data_3d_preds_list
del error_data_2d_labels_list
del error_data_3d_labels_list
del train_error_dataset
del train_error_dataloader
del current_train_on_error_epochs
del idx_train_error_data_2d
del idx_train_error_data_3d
del idx_train_error_label_2d
del idx_train_error_label_3d
del idx_train_error_y_pred_2d
del idx_train_error_y_pred_3d
del idx_train_error_loss
torch.cuda.empty_cache()
else:
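# Single-branch nets (2d, 3d, s3d, resnet3d, d2d): mine mispredicted samples from the accumulated training predictions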
cls_train_all_label_id = train_all_label_id[:]
cls_train_all_preds = torch.cat(train_all_preds, dim=0)
cls_train_all_labels = torch.cat(train_all_labels, dim=0)
# Select mispredicted samples using the configured thresholds: positives whose predicted probability falls below positive_threshold, negatives whose predicted probability rises above negative_threshold
current_positive_threshold = current_train_on_error_data_epoch_dict["positive_threshold"]
current_negative_threshold = current_train_on_error_data_epoch_dict["negative_threshold"]
positive_error_indices = torch.where((cls_train_all_labels == 1) & (cls_train_all_preds < current_positive_threshold))
negative_error_indices = torch.where((cls_train_all_labels == 0) & (cls_train_all_preds > current_negative_threshold))
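# Collect label_id, label and prediction for every mispredicted sample so they can be logged below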
log_train_error_data_str = ""
log_train_error_data_str += f"train on error data, net_id_{net_id}_分类: 预测错误的正样本信息, 个数: {len(positive_error_indices[0])}:\n"
for idx_index in positive_error_indices[0]:
log_train_error_data_str += f"label_id: {cls_train_all_label_id[idx_index]}, label: {cls_train_all_labels[idx_index]}, pred: {cls_train_all_preds[idx_index]}\n"
log_train_error_data_str += f"train on error data, net_id_{net_id}_分类: 预测错误的负样本信息, 个数: {len(negative_error_indices[0])}:\n"
for idx_index in negative_error_indices[0]:
log_train_error_data_str += f"label_id: {cls_train_all_label_id[idx_index]}, label: {cls_train_all_labels[idx_index]}, pred: {cls_train_all_preds[idx_index]}\n"
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, train on error data, train_error_data_epochs: {current_train_on_error_data_epoch_dict['train_epoch']}\ntrain epoch: {epoch+1}, 训练集错误样本单独训练,训练集预测错误的信息:\n{log_train_error_data_str}")
else:
logger.info(f"local rank: {device_index[0]}, train on error data, train_error_data_epochs: {current_train_on_error_data_epoch_dict['train_epoch']}\ntrain epoch: {epoch+1}, 训练集错误样本单独训练,训练集预测错误的信息:\n{log_train_error_data_str}")
error_indices = torch.cat((positive_error_indices[0], negative_error_indices[0]))
error_data_label_id_list = [cls_train_all_label_id[i] for i in error_indices.cpu().numpy()]
if len(error_data_label_id_list) == 0:
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, train on error data, 预测错误的样本为0, 跳过错误样本单独训练, error_data_label_id_list: {error_data_label_id_list}")
else:
logger.info(f"local rank: {device_index[0]}, train on error data, 预测错误的样本为0, 跳过错误样本单独训练, error_data_label_id_list: {error_data_label_id_list}")
del cls_train_all_label_id
del cls_train_all_preds
del cls_train_all_labels
del current_positive_threshold
del current_negative_threshold
del positive_error_indices
del negative_error_indices
del log_train_error_data_str
del error_indices
del error_data_label_id_list
torch.cuda.empty_cache()
continue
current_train_batch_size = current_train_on_error_data_epoch_dict["error_train_batch_size"]
error_file = f"train_error_data_current_train_epoch_{current_train_on_error_data_epoch_dict['train_epoch']}_[{num_epochs}]_train_error_data_epoch_{current_train_on_error_data_epoch_dict['error_train_epoch']}_net_{net_id}.csv"
train_error_csv_file = os.path.join(save_dir, error_file)
error_data_preds_list = [cls_train_all_preds[i].cpu().numpy() for i in error_indices.cpu().numpy()]
error_data_labels_list = [cls_train_all_labels[i].cpu().numpy() for i in error_indices.cpu().numpy()]
train_error_dataset = None
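# Build the error dataset with 2d-style keyword arguments for the 2d/d2d nets and 3d-style ones for the 3d/s3d/resnet3d nets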
if net_id == "2d":
train_error_dataset = train_error_dataset_class(
csv_file=train_csv_file,
label_id_2d_list=error_data_label_id_list,
error_file=train_error_csv_file,
preds_2d=error_data_preds_list,
labels_2d=error_data_labels_list
)
elif net_id == "3d":
train_error_dataset = train_error_dataset_class(
csv_file=train_csv_file,
label_id_3d_list=error_data_label_id_list,
error_file=train_error_csv_file,
preds_3d=error_data_preds_list,
labels_3d=error_data_labels_list
)
elif net_id == "s3d":
train_error_dataset = train_error_dataset_class(
csv_file=train_csv_file,
label_id_3d_list=error_data_label_id_list,
error_file=train_error_csv_file,
preds_3d=error_data_preds_list,
labels_3d=error_data_labels_list
)
elif net_id == "resnet3d":
train_error_dataset = train_error_dataset_class(
csv_file=train_csv_file,
label_id_3d_list=error_data_label_id_list,
error_file=train_error_csv_file,
preds_3d=error_data_preds_list,
labels_3d=error_data_labels_list
)
elif net_id == "d2d":
train_error_dataset = train_error_dataset_class(
csv_file=train_csv_file,
label_id_2d_list=error_data_label_id_list,
error_file=train_error_csv_file,
preds_2d=error_data_preds_list,
labels_2d=error_data_labels_list
)
if train_error_dataset is None:
if len(device_index) > 1:
raise ValueError(f"local rank: {args.local_rank}, train on error data, train_error_dataset is None, net_id: {net_id}")
else:
raise ValueError(f"local rank: {device_index[0]}, train on error data, train_error_dataset is None, net_id: {net_id}")
train_error_dataloader = DataLoader(
train_error_dataset,
batch_size=current_train_batch_size,
shuffle=True,
num_workers=num_workers,
collate_fn=custom_collate_fn_error
)
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, train on error data, 原始训练数据, error_data_label_id_list: {len(error_data_label_id_list)}")
logger.info(f"local rank: {args.local_rank}, train on error data, 原始训练数据, error_data_label_id_list: {error_data_label_id_list}")
logger.info(f"local rank: {args.local_rank}, train on error data, 处理训练数据, train_error_dataloader: {len(train_error_dataloader)}")
else:
logger.info(f"local rank: {device_index[0]}, train on error data, 原始训练数据, error_data_label_id_list: {len(error_data_label_id_list)}")
logger.info(f"local rank: {device_index[0]}, train on error data, 原始训练数据, error_data_label_id_list: {error_data_label_id_list}")
logger.info(f"local rank: {device_index[0]}, train on error data, 处理训练数据, train_error_dataloader: {len(train_error_dataloader)}")
# Separate training on the mispredicted samples
current_train_on_error_epochs = current_train_on_error_data_epoch_dict["error_train_epoch"]
for idx_train_error_data_epoch in range(current_train_on_error_epochs):
for idx_batch in train_error_dataloader:
idx_train_error_data, idx_train_error_label = idx_batch
if len(device_index) > 1:
idx_train_error_data = idx_train_error_data.cuda()
idx_train_error_label = idx_train_error_label.cuda()
else:
idx_train_error_data = idx_train_error_data.to(device_index[0])
idx_train_error_label = idx_train_error_label.to(device_index[0])
idx_train_error_y_pred = model(idx_train_error_data)
idx_train_error_loss = criterion(idx_train_error_y_pred, idx_train_error_label)
optimizer.zero_grad()
idx_train_error_loss.backward()
optimizer.step()
if len(device_index) > 1:
torch.distributed.all_reduce(idx_train_error_loss, op=torch.distributed.ReduceOp.AVG)
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, train on error data, train_error_data_epoch: {idx_train_error_data_epoch+1}, train_error_epochs: {current_train_on_error_epochs}, current_train_epoch: {epoch+1}, train_epoch: {num_epochs}, loss: {idx_train_error_loss.item()}")
else:
logger.info(f"local rank: {device_index[0]}, train on error data, train_error_data_epoch: {idx_train_error_data_epoch+1}, train_error_epochs: {current_train_on_error_epochs}, current_train_epoch: {epoch+1}, train_epoch: {num_epochs}, loss: {idx_train_error_loss.item()}")
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, train on error data, 错误样本单独训练结束")
else:
logger.info(f"local rank: {device_index[0]}, train on error data, 错误样本单独训练结束")
del cls_train_all_label_id
del cls_train_all_preds
del cls_train_all_labels
del current_positive_threshold
del current_negative_threshold
del positive_error_indices
del negative_error_indices
del log_train_error_data_str
del error_indices
del error_data_label_id_list
del current_train_batch_size
del error_file
del train_error_csv_file
del error_data_preds_list
del error_data_labels_list
del train_error_dataset
del train_error_dataloader
del current_train_on_error_epochs
del idx_train_error_data
del idx_train_error_label
del idx_train_error_y_pred
del idx_train_error_loss
torch.cuda.empty_cache()
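# Reset the per-epoch prediction/label accumulators before the next training epoch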
train_all_preds = []
train_all_labels = []
train_all_label_id = []
train_all_label_id_2d = []
train_all_label_id_3d = []
train_all_preds_2d = []
train_all_preds_3d = []
train_all_labels_2d = []
train_all_labels_3d = []
error_data_dict = {}
error_data_2d_dict = {}
error_data_3d_dict = {}
if len(device_index) > 1:
torch.distributed.barrier()
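# Only rank 0 (or the single GPU) logs the best score and saves the final checkpoint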
if (len(device_index) > 1 and args.local_rank == 0) or len(device_index) == 1:
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, train finished, best score: {best_val_metric_score}")
else:
logger.info(f"local rank: {device_index[0]}, train finished, best score: {best_val_metric_score}")
save_path = os.path.join(save_dir, f"final_epoch_{epoch+1}_net_id_{net_id}.pth")
if len(device_index) > 1:
torch.save(model.module.state_dict(), save_path)
else:
torch.save(model.state_dict(), save_path)
if len(device_index) > 1:
logger.info(f"local rank: {args.local_rank}, final model saved to {save_path}")
else:
logger.info(f"local rank: {device_index[0]}, final model saved to {save_path}")
def train_ddp(model, config, args):
"""
Distributed (multi-GPU) model training.
"""
try:
device_index = config['device_index']
if len(device_index) > 1:
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
else:
device = torch.device(f"cuda:{device_index[0]}")
print(f"train_ddp, device: {device_index}")
if len(device_index) > 1:
if torch.distributed.is_available() and not torch.distributed.is_initialized():
torch.distributed.init_process_group(
backend="nccl",
init_method="env://"
)
world_size = torch.distributed.get_world_size()
print(f"train_ddp, 多卡, local_rank: {args.local_rank}, world_size: {world_size}")
print(f"train_ddp, 多卡, device: {device}")
else:
world_size = 1
print(f"train_ddp, 单卡, local_rank: {device_index[0]}, world_size: {world_size}")
print(f"train_ddp, 单卡, device: {device}")
set_seed()
start_train_ddp(model, config, args)
except Exception as e:
print(f"error in train_ddp: {traceback.format_exc()}")
raise e
finally:
if len(device_index) > 1 and torch.distributed.is_initialized():
print(f"finished ddp training, destroy process group")
torch.distributed.destroy_process_group()
else:
print(f"finished single gpu training")
def print_init_params(model, string="s3d", logger=None):
for name, param in model.named_parameters():
if len(param.shape) > 3:
if logger is not None:
logger.info(f"{string}, name: {name}, param: {param[0][0][0][:1]}")
else:
print(f"{string}, name: {name}, param: {param[0][0][0][:1]}")
elif len(param.shape) > 2:
if logger is not None:
logger.info(f"{string}, name: {name}, param: {param[0][0][:1]}")
else:
print(f"{string}, name: {name}, param: {param[0][0][:1]}")
elif len(param.shape) > 1:
if logger is not None:
logger.info(f"{string}, name: {name}, param: {param[0][:1]}")
else:
print(f"{string}, name: {name}, param: {param[0][:1]}")
else:
if logger is not None:
logger.info(f"{string}, name: {name}, param: {param[:1]}")
else:
print(f"{string}, name: {name}, param: {param[:1]}")
# print(f"{string}, name: {name}, param.shape: {param.shape}")
def get_train_config():
from pytorch_train.train_2d3d_config import node_net_train_file_dict
config = {}
net_id = "2d3d"
epochs = 20
local_rank_index = 0
'''
candidate node_id values:
1010_1020_2011_2021_2041_2031
1020_2011_2021_2031
2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2031_2021_2041
2011_2021_2031_2041_1010_1020_1016
2011_2021_2031_2041
'''
node_id = "2021_2031"
date_id = "20241217"
task_info = f"train_{net_id}_{node_id}_{date_id}_rotate_10_ddp_count_00000005"
train_log_dir = "/df_lung/ai-project/cls_train/log/train"
save_dir = "/df_lung/ai-project/cls_train/cls_ckpt"
save_folder = f"{net_id}_{node_id}_{date_id}"
log_file = os.path.join(train_log_dir, f"{task_info}.log")
save_dir = os.path.join(save_dir, save_folder)
train_data_dir = "/df_lung/cls_train_data/train_csv_data/"
train_csv_file = None
val_csv_file = None
test_csv_file = None
train_2d3d_data_2d_csv_file = None
val_2d3d_data_2d_csv_file = None
test_2d3d_data_2d_csv_file = None
train_2d3d_data_3d_csv_file = None
val_2d3d_data_3d_csv_file = None
test_2d3d_data_3d_csv_file = None
if net_id == "2d3d":
train_2d3d_data_2d_csv_file = node_net_train_file_dict[(node_id, net_id)]["2d_train_file"]
val_2d3d_data_2d_csv_file = node_net_train_file_dict[(node_id, net_id)]["2d_val_file"]
test_2d3d_data_2d_csv_file = node_net_train_file_dict[(node_id, net_id)]["2d_test_file"]
train_2d3d_data_3d_csv_file = node_net_train_file_dict[(node_id, net_id)]["3d_train_file"]
val_2d3d_data_3d_csv_file = node_net_train_file_dict[(node_id, net_id)]["3d_val_file"]
test_2d3d_data_3d_csv_file = node_net_train_file_dict[(node_id, net_id)]["3d_test_file"]
else:
train_csv_file = node_net_train_file_dict[(node_id, net_id)]["train_file"]
val_csv_file = node_net_train_file_dict[(node_id, net_id)]["val_file"]
test_csv_file = node_net_train_file_dict[(node_id, net_id)]["test_file"]
s3d_file = "/df_lung/cls_train_data/encoder_3d_classifier_state_dict.pth"
resnet_3d_file = "/df_lung/cls_train_data/resnet_classifier_state_dict.pth"
d2d_dir = "/df_lung/cls_train_data/dino2"
config['s3d_file'] = s3d_file
config['resnet_3d_file'] = resnet_3d_file
config['d2d_dir'] = d2d_dir
logger = get_logger(log_file)
logger.info(f"train_config, node_id: {node_id}")
if net_id == "2d3d":
model = Net2d3d()
batch_size = 6 if local_rank_index is not None else 24
elif net_id == "2d":
model = Net2d()
batch_size = 200
elif net_id == "3d":
model = Net3d()
batch_size = 6
elif net_id == "s3d":
model = NetS3d()
model.load_state_dict(torch.load(config['s3d_file'], weights_only=False, map_location="cpu"))
batch_size = 1
elif net_id == "resnet3d":
model = NetResNet3d(Bottleneck, [3, 4, 23, 3], 128, 128, 128)
model.load_state_dict(torch.load(config['resnet_3d_file'], weights_only=True))
batch_size = 10
elif net_id == "d2d":
model = NetD2d(pretrain_dir=config["d2d_dir"])
batch_size = 1
print_init_params(model, string="get_train_config, instance model", logger=logger)
if net_id not in ["s3d", "resnet3d", "d2d"]:
init_modules(model)
if net_id == "2d3d" and node_id == "2021_2031":
step_file = "/df_lung/ai-project/cls_train/cls_ckpt/2d3d_2021_2031_20241207/train_epoch_1_net_2d3d.pt"
model.load_state_dict(torch.load(step_file, weights_only=True))
logger.info(f"net_id: {net_id}, node_id: {node_id}, step_file: {step_file}")
elif net_id == "2d3d" and node_id == "2041_2031":
step_file = "/df_lung/ai-project/cls_train/cls_ckpt/2d3d_2041_2031_20241207/train_epoch_1_step_17000_net_2d3d.pt"
model.load_state_dict(torch.load(step_file, weights_only=True))
logger.info(f"net_id: {net_id}, node_id: {node_id}, step_file: {step_file}")
elif net_id == "2d3d" and node_id == "1010_1020_2011_2021_2041_2031":
step_file = "/df_lung/ai-project/cls_train/cls_ckpt/2d3d_1010_1020_2011_2021_2041_2031_20241207/train_epoch_1_step_17000_net_2d3d.pt"
model.load_state_dict(torch.load(step_file, weights_only=True))
logger.info(f"net_id: {net_id}, node_id: {node_id}, step_file: {step_file}")
elif net_id == "2d3d" and node_id == "1010_1020_2011_2021_2031_2041":
step_file = "/df_lung/ai-project/cls_train/cls_ckpt/2d3d_1010_1020_2011_2021_2031_2041_20241207/train_epoch_1_step_17000_net_2d3d.pt"
model.load_state_dict(torch.load(step_file, weights_only=True))
logger.info(f"net_id: {net_id}, node_id: {node_id}, step_file: {step_file}")
Path(train_log_dir).mkdir(parents=True, exist_ok=True)
Path(save_dir).mkdir(parents=True, exist_ok=True)
config['net_id'] = net_id
config['local_rank_index'] = local_rank_index
config['lr'] = 1e-5
config['weight_decay'] = 1e-4
config['logger'] = logger
config['criterion'] = nn.BCELoss()
config['step_size'] = 200
config['learning_rate_drop'] = 0.1
config['optimizer'] = Adam(model.parameters(), lr=config['lr'])
config['scheduler'] = lr_scheduler.StepLR(config['optimizer'], step_size=config['step_size'], gamma=config['learning_rate_drop'])
config['num_epochs'] = epochs
config['device_index'] = [config['local_rank_index']] if config['local_rank_index'] is not None else [0, 1, 2, 3]
config['num_workers'] = 1
config['train_batch_size'] = batch_size
config['val_batch_size'] = batch_size
config['save_dir'] = save_dir
config['val_interval'] = 1000000000000000  # set very large so interval-based validation effectively never triggers
config['pos_label'] = 'positive'
config['neg_label'] = 'negative'
config['train_2d3d_data_2d_csv_file'] = train_2d3d_data_2d_csv_file
config['val_2d3d_data_2d_csv_file'] = val_2d3d_data_2d_csv_file
config['test_2d3d_data_2d_csv_file'] = test_2d3d_data_2d_csv_file
config['train_2d3d_data_3d_csv_file'] = train_2d3d_data_3d_csv_file
config['val_2d3d_data_3d_csv_file'] = val_2d3d_data_3d_csv_file
config['test_2d3d_data_3d_csv_file'] = test_2d3d_data_3d_csv_file
config['train_csv_file'] = train_csv_file
config['val_csv_file'] = val_csv_file
config['test_csv_file'] = test_csv_file
config['val_metric'] = 'f1-score'
config['threshold'] = 0.5
config['logger_train_cls_report_flag'] = False
config['train_on_error_data_flag'] = False
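# Each entry below schedules an extra pass over mispredicted samples tied to the given main-training epoch
# (used only when train_on_error_data_flag is True), with its own probability thresholds and error-training batch size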
config['train_on_error_data_epoch_dict_list'] = [
{"train_epoch": config['num_epochs']//2,
"error_train_epoch": 1,
"positive_threshold": 0.3,
"negative_threshold": 0.7,
"error_train_batch_size": 1
},
{"train_epoch": config['num_epochs']-10,
"error_train_epoch": 1,
"positive_threshold": 0.5,
"negative_threshold": 0.5,
"error_train_batch_size": 1}
]
return model, config
if __name__ == "__main__":
try:
model, config = get_train_config()
parser = argparse.ArgumentParser()
parser.add_argument("--local-rank", type=int, default=-1, dest="local_rank")
parser.add_argument("--use_zero", action="store_true")
args = parser.parse_args()
print(f"args: {args}")
train_ddp(model, config, args)
except Exception as e:
print(f"error in main: {traceback.format_exc()}")
raise e
finally:
if torch.distributed.is_initialized():
print("finished, destroy process group")
torch.distributed.destroy_process_group()
# python -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 pytorch_train/train_2d3d.py --use_zero
# python pytorch_train/train_2d3d.py
node_net_train_file_dict = {
("2021_2031", "3d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_3d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_3d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_3d_test.csv",
"batch_size": 100,
},
("2021_2031", "2d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_2d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_2d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_2d_test.csv",
"batch_size": 100,
},
("2021_2031", "2d3d"):{
"2d_train_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_2d3d_data_2d_train_error_data_epoch_1.csv",
# "2d_train_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_2d3d_data_2d_train.csv",
"2d_val_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_2d3d_data_2d_val.csv",
"2d_test_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_2d3d_data_2d_test.csv",
"3d_train_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_2d3d_data_3d_train_error_data_epoch_1.csv",
# "3d_train_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_2d3d_data_3d_train.csv",
"3d_val_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_2d3d_data_3d_val.csv",
"3d_test_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_2d3d_data_3d_test.csv",
"batch_size": 100,
},
("2021_2031", "s3d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_s3d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_s3d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_s3d_test.csv",
"batch_size": 100,
},
("2021_2031", "d2d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_d2d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_d2d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2021_2031_net_id_d2d_test.csv",
"batch_size": 100,
},
("2041_2031", "3d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2041_2031_net_id_3d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2041_2031_net_id_3d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2041_2031_net_id_3d_test.csv",
"batch_size": 100,
},
("2041_2031", "2d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2041_2031_net_id_2d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2041_2031_net_id_2d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2041_2031_net_id_2d_test.csv",
"batch_size": 100,
},
("2041_2031", "2d3d"):{
"2d_train_file": "/df_lung/cls_train_data/train_csv_data/2041_2031_net_id_2d3d_data_2d_train.csv",
"2d_val_file": "/df_lung/cls_train_data/train_csv_data/2041_2031_net_id_2d3d_data_2d_val.csv",
"2d_test_file": "/df_lung/cls_train_data/train_csv_data/2041_2031_net_id_2d3d_data_2d_test.csv",
"3d_train_file": "/df_lung/cls_train_data/train_csv_data/2041_2031_net_id_2d3d_data_3d_train.csv",
"3d_val_file": "/df_lung/cls_train_data/train_csv_data/2041_2031_net_id_2d3d_data_3d_val.csv",
"3d_test_file": "/df_lung/cls_train_data/train_csv_data/2041_2031_net_id_2d3d_data_3d_test.csv",
"batch_size": 100,
},
("2041_2031", "s3d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2041_2031_net_id_s3d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2041_2031_net_id_s3d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2041_2031_net_id_s3d_test.csv",
"batch_size": 100,
},
("2041_2031", "d2d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2041_2031_net_id_d2d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2041_2031_net_id_d2d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2041_2031_net_id_d2d_test.csv",
"batch_size": 100,
},
("1010_1020_2011_2021_2041_2031", "3d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2041_2031_net_id_3d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2041_2031_net_id_3d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2041_2031_net_id_3d_test.csv",
"batch_size": 100,
},
("1010_1020_2011_2021_2041_2031", "2d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2041_2031_net_id_2d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2041_2031_net_id_2d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2041_2031_net_id_2d_test.csv",
"batch_size": 100,
},
("1010_1020_2011_2021_2041_2031", "2d3d"):{
"2d_train_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2041_2031_net_id_2d3d_data_2d_train.csv",
"2d_val_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2041_2031_net_id_2d3d_data_2d_val.csv",
"2d_test_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2041_2031_net_id_2d3d_data_2d_test.csv",
"3d_train_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2041_2031_net_id_2d3d_data_3d_train.csv",
"3d_val_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2041_2031_net_id_2d3d_data_3d_val.csv",
"3d_test_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2041_2031_net_id_2d3d_data_3d_test.csv",
"batch_size": 100,
},
("1010_1020_2011_2021_2041_2031", "s3d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2041_2031_net_id_s3d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2041_2031_net_id_s3d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2041_2031_net_id_s3d_test.csv",
"batch_size": 100,
},
("1010_1020_2011_2021_2041_2031", "d2d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2041_2031_net_id_d2d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2041_2031_net_id_d2d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2041_2031_net_id_d2d_test.csv",
"batch_size": 100,
},
("1010_1020_2011_2021_2031_2041", "3d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2031_2041_net_id_3d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2031_2041_net_id_3d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2031_2041_net_id_3d_test.csv",
"batch_size": 100,
},
("1010_1020_2011_2021_2031_2041", "2d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2031_2041_net_id_2d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2031_2041_net_id_2d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2031_2041_net_id_2d_test.csv",
"batch_size": 100,
},
("1010_1020_2011_2021_2031_2041", "2d3d"):{
"2d_train_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2031_2041_net_id_2d3d_data_2d_train.csv",
"2d_val_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2031_2041_net_id_2d3d_data_2d_val.csv",
"2d_test_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2031_2041_net_id_2d3d_data_2d_test.csv",
"3d_train_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2031_2041_net_id_2d3d_data_3d_train.csv",
"3d_val_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2031_2041_net_id_2d3d_data_3d_val.csv",
"3d_test_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2031_2041_net_id_2d3d_data_3d_test.csv",
"batch_size": 100,
},
("1010_1020_2011_2021_2031_2041", "s3d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2031_2041_net_id_s3d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2031_2041_net_id_s3d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2031_2041_net_id_s3d_test.csv",
"batch_size": 100,
},
("1010_1020_2011_2021_2031_2041", "d2d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2031_2041_net_id_d2d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2031_2041_net_id_d2d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/1010_1020_2011_2021_2031_2041_net_id_d2d_test.csv",
"batch_size": 100,
},
("1020_2011_2021_2031", "3d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/1020_2011_2021_2031_net_id_3d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/1020_2011_2021_2031_net_id_3d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/1020_2011_2021_2031_net_id_3d_test.csv",
"batch_size": 100,
},
("1020_2011_2021_2031", "2d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/1020_2011_2021_2031_net_id_2d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/1020_2011_2021_2031_net_id_2d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/1020_2011_2021_2031_net_id_2d_test.csv",
"batch_size": 100,
},
("1020_2011_2021_2031", "2d3d"):{
"2d_train_file": "/df_lung/cls_train_data/train_csv_data/1020_2011_2021_2031_net_id_2d3d_data_2d_train.csv",
"2d_val_file": "/df_lung/cls_train_data/train_csv_data/1020_2011_2021_2031_net_id_2d3d_data_2d_val.csv",
"2d_test_file": "/df_lung/cls_train_data/train_csv_data/1020_2011_2021_2031_net_id_2d3d_data_2d_test.csv",
"3d_train_file": "/df_lung/cls_train_data/train_csv_data/1020_2011_2021_2031_net_id_2d3d_data_3d_train.csv",
"3d_val_file": "/df_lung/cls_train_data/train_csv_data/1020_2011_2021_2031_net_id_2d3d_data_3d_val.csv",
"3d_test_file": "/df_lung/cls_train_data/train_csv_data/1020_2011_2021_2031_net_id_2d3d_data_3d_test.csv",
"batch_size": 100,
},
("1020_2011_2021_2031", "s3d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/1020_2011_2021_2031_net_id_s3d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/1020_2011_2021_2031_net_id_s3d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/1020_2011_2021_2031_net_id_s3d_test.csv",
"batch_size": 100,
},
("1020_2011_2021_2031", "d2d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/1020_2011_2021_2031_net_id_d2d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/1020_2011_2021_2031_net_id_d2d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/1020_2011_2021_2031_net_id_d2d_test.csv",
"batch_size": 100,
},
("2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2031_2021_2041", "3d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2021_2031_2041_net_id_3d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2021_2031_2041_net_id_3d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2021_2031_2041_net_id_3d_test.csv",
"batch_size": 100,
},
("2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2031_2021_2041", "2d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2021_2031_2041_net_id_2d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2021_2031_2041_net_id_2d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2021_2031_2041_net_id_2d_test.csv",
"batch_size": 100,
},
("2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2031_2021_2041", "2d3d"):{
"2d_train_file": "/df_lung/cls_train_data/train_csv_data/2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2021_2031_2041_net_id_2d3d_data_2d_train.csv",
"2d_val_file": "/df_lung/cls_train_data/train_csv_data/2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2021_2031_2041_net_id_2d3d_data_2d_val.csv",
"2d_test_file": "/df_lung/cls_train_data/train_csv_data/2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2021_2031_2041_net_id_2d3d_data_2d_test.csv",
"3d_train_file": "/df_lung/cls_train_data/train_csv_data/2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2021_2031_2041_net_id_2d3d_data_3d_train.csv",
"3d_val_file": "/df_lung/cls_train_data/train_csv_data/2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2021_2031_2041_net_id_2d3d_data_3d_val.csv",
"3d_test_file": "/df_lung/cls_train_data/train_csv_data/2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2021_2031_2041_net_id_2d3d_data_3d_test.csv",
"batch_size": 100,
},
("2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2031_2021_2041", "s3d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2021_2031_2041_net_id_s3d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2021_2031_2041_net_id_s3d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2021_2031_2041_net_id_s3d_test.csv",
"batch_size": 100,
},
("2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2031_2021_2041", "d2d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2021_2031_2041_net_id_d2d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2021_2031_2041_net_id_d2d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2046_2047_2048_2060_2061_2062_3001_4001_5001_6001_2021_2031_2041_net_id_d2d_test.csv",
"batch_size": 100,
},
("2011_2021_2031_2041_1010_1020_1016", "3d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_1010_1020_1016_net_id_3d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_1010_1020_1016_net_id_3d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_1010_1020_1016_net_id_3d_test.csv",
"batch_size": 100,
},
("2011_2021_2031_2041_1010_1020_1016", "2d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_1010_1020_1016_net_id_2d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_1010_1020_1016_net_id_2d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_1010_1020_1016_net_id_2d_test.csv",
"batch_size": 100,
},
("2011_2021_2031_2041_1010_1020_1016", "2d3d"):{
"2d_train_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_1010_1020_1016_net_id_2d3d_data_2d_train.csv",
"2d_val_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_1010_1020_1016_net_id_2d3d_data_2d_val.csv",
"2d_test_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_1010_1020_1016_net_id_2d3d_data_2d_test.csv",
"3d_train_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_1010_1020_1016_net_id_2d3d_data_3d_train.csv",
"3d_val_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_1010_1020_1016_net_id_2d3d_data_3d_val.csv",
"3d_test_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_1010_1020_1016_net_id_2d3d_data_3d_test.csv",
"batch_size": 100,
},
("2011_2021_2031_2041_1010_1020_1016", "s3d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_1010_1020_1016_net_id_s3d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_1010_1020_1016_net_id_s3d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_1010_1020_1016_net_id_s3d_test.csv",
"batch_size": 100,
},
("2011_2021_2031_2041_1010_1020_1016", "d2d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_1010_1020_1016_net_id_d2d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_1010_1020_1016_net_id_d2d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_1010_1020_1016_net_id_d2d_test.csv",
"batch_size": 100,
},
("2011_2021_2031_2041", "3d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_net_id_3d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_net_id_3d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_net_id_3d_test.csv",
"batch_size": 100,
},
("2011_2021_2031_2041", "2d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_net_id_2d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_net_id_2d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_net_id_2d_test.csv",
"batch_size": 100,
},
("2011_2021_2031_2041", "2d3d"):{
"2d_train_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_net_id_2d3d_data_2d_train.csv",
"2d_val_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_net_id_2d3d_data_2d_val.csv",
"2d_test_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_net_id_2d3d_data_2d_test.csv",
"3d_train_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_net_id_2d3d_data_3d_train.csv",
"3d_val_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_net_id_2d3d_data_3d_val.csv",
"3d_test_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_net_id_2d3d_data_3d_test.csv",
"batch_size": 100,
},
("2011_2021_2031_2041", "s3d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_net_id_s3d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_net_id_s3d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_net_id_s3d_test.csv",
"batch_size": 100,
},
("2011_2021_2031_2041", "d2d"):{
"train_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_net_id_d2d_train.csv",
"val_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_net_id_d2d_val.csv",
"test_file": "/df_lung/cls_train_data/train_csv_data/2011_2021_2031_2041_net_id_d2d_test.csv",
"batch_size": 100,
},
}
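# Illustrative only: a minimal lookup sketch (the helper name get_csv_files is hypothetical and is not
# used by the training code). The dict is keyed by (node_id, net_id); "2d3d" entries carry paired
# 2d/3d csv paths, while every other net_id carries a single train/val/test set.
def get_csv_files(node_id, net_id):
    entry = node_net_train_file_dict[(node_id, net_id)]
    if net_id == "2d3d":
        # fusion net: paired 2d and 3d csv files per split
        return {
            "train": (entry["2d_train_file"], entry["3d_train_file"]),
            "val": (entry["2d_val_file"], entry["3d_val_file"]),
            "test": (entry["2d_test_file"], entry["3d_test_file"]),
        }
    # single-branch nets: one csv per split
    return {"train": entry["train_file"], "val": entry["val_file"], "test": entry["test_file"]}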