Commit 329f1f6c authored by exp

Classification training

import numpy as np
#from matplotlib_inline import backend_inline
from matplotlib import pyplot as plt
import os
import torch


class Accumulator:
    """Accumulate sums over n variables."""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
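
A typical use of Accumulator is tallying a running loss together with a sample count during training; a minimal sketch with made-up numbers:

# Illustrative usage: accumulate (summed loss, sample count) across batches.
metric = Accumulator(2)
for batch_loss, batch_size in [(0.9, 8), (0.7, 8), (0.5, 4)]:
    metric.add(batch_loss * batch_size, batch_size)
print('mean loss:', metric[0] / metric[1])
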
class Animator:
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(5, 5)):
        if legend is None:
            legend = []
        # Only needed to switch the output format to SVG for display in Jupyter.
        #self.use_svg_display()
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        #print(self.axes)
        if nrows * ncols == 1:
            self.axes = [self.axes, ]
            print('success')
            print(self.axes)
        # Use a lambda to capture the axis-configuration arguments.
        print(self.axes[0])
        self.config_axes = lambda: self.set_axes(
            axes=self.axes[0], xlabel=xlabel, ylabel=ylabel,
            xlim=xlim, ylim=ylim, xscale=xscale, yscale=yscale, legend=legend)
        self.X, self.Y, self.fmts = None, None, fmts

    '''
    def use_svg_display(self):
        """Render figures in SVG format inside Jupyter."""
        backend_inline.set_matplotlib_formats('svg')
    '''

    def set_axes(self, axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
        #print(axes)
        axes.set_xlabel(xlabel), axes.set_ylabel(ylabel)
        axes.set_xscale(xscale), axes.set_yscale(yscale)
        axes.set_xlim(xlim), axes.set_ylim(ylim)
        if legend:
            axes.legend(legend)
        axes.grid()
    def add(self, x, y):
        # Append one or more data points to the chart.
        # hasattr(object, name) checks whether an object has the given attribute.
        if not hasattr(y, '__len__'):
            y = [y]
        n = len(y)
        if not hasattr(x, '__len__'):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        # Clear all plot elements from the axes before redrawing.
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            print(x, ' ', y)
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        print('chart updated')
        #self.fig.canvas.draw()
        #self.fig.canvas.flush_events()
        #plt.draw()
        plt.pause(2)
# Load the array stored in an .npy file.
def load_npy_to_data(input_file=None):
    original_data = np.load(input_file)
    return original_data

# Display a fixed .npy array as an image.
def show_img():
    input_file = r'/home/lung/project/ai-project/cls_train/data/train_data/plus_3d_0818/npy_data/cls_2047/495.npy'
    data = load_npy_to_data(input_file)
    #print(data)
    data = data[29].astype(np.float32)
    print(data)
    plt.imshow(data, cmap='gray')
    plt.show()

# The slice at index 24 is the central plane.
def show_img_2d():
    input_file = r'/home/lung/ai-project/cls_train/data/train_data/plus_0512/npy_data/cls_2021/1893_10.npy'
    data = load_npy_to_data(input_file)
    data = data.astype(np.float32)
    plt.imshow(data, cmap='gray')
    plt.show()
# Test data thresholding.
def run():
    test = [[0.1, 0.2], [0.3, 0.6]]
    thred = 0.5
    test = torch.tensor(test)
    result = test > thred
    print(result.float())
    #print(test > thred)

def test_sum():
    test = [[[1, 1, 2], [2, 3, 4]],
            [[2, 0, 0], [9, 9, 9]]]
    test = np.array(test)
    print(np.sum(test, axis=0))

# Get the indices of elements greater than a given value in a multi-dimensional array.
def test():
    mask = [[[ True,  True,  True],
             [False, False, False],
             [ True,  True,  True]],
            [[ True, False,  True],
             [False, False, False],
             [ True,  True,  True]]]
    mask = np.array(mask)
    indices = np.asarray(np.where(mask == 1))
    print(indices)
    print(indices.min(axis=1))
    print(indices.max(axis=1))
# Subtract one image from another and save the difference image for inspection.
def check_image():
    for index in range(22, 37):
        data_1 = load_npy_to_data(f'/home/lung/project/ai-project/cls_train/log/npy/01/{index}.npy')
        data_2 = load_npy_to_data(f'/home/lung/project/ai-project/cls_train/log/npy/03/{index}.npy')
        #data = np.where(data_1 == data_2, data_1, 0)
        data = data_2 - data_1
        plt.imsave(f'/home/lung/project/ai-project/cls_train/log/image/temp_02/{index}.png', data, cmap='gray')

if __name__ == '__main__':
    #test_sum()
    #check_image()
    show_img()
import random
import itertools
import numpy as np


def generate_general_indexs():
    indexes = [[0, 2],
               [1, 3],
               [2, 0],
               [3, 1],
               [4, 6],
               [5, 7],
               [6, 4],
               [7, 5]]
    return indexes


def generate_general_permute_keys():
    rotate_y = 0
    transpose = 0
    keys = [((rotate_y, 0), 0, 0, 0, transpose),
            ((rotate_y, 1), 0, 0, 0, transpose),
            ((rotate_y, 0), 0, 1, 1, transpose),
            ((rotate_y, 1), 0, 1, 1, transpose),
            ((rotate_y, 0), 1, 0, 0, transpose),
            ((rotate_y, 1), 1, 0, 0, transpose),
            ((rotate_y, 0), 1, 1, 1, transpose),
            ((rotate_y, 1), 1, 1, 1, transpose)]
    return keys
def random_permute_key():
    """
    Generates and randomly selects a permute key.
    """
    return random.choice(list(generate_general_permute_keys()))


def generate_permutation_keys():
    """
    This function returns a set of "keys" that represent the 48 unique rotations &
    reflections of a 3D matrix.
    Each item of the set is a tuple:
    ((rotate_y, rotate_z), flip_x, flip_y, flip_z, transpose)
    As an example, ((0, 1), 0, 1, 0, 1) represents a permutation in which the data is
    rotated 90 degrees around the z-axis, then reversed on the y-axis, and then
    transposed.
    48 unique rotations & reflections:
    https://en.wikipedia.org/wiki/Octahedral_symmetry#The_isometries_of_the_cube
    """
    return set(itertools.product(
        itertools.combinations_with_replacement(range(2), 2), range(2), range(2), range(2), range(2)))
def random_permutation_key():
    """
    Generates and randomly selects a permutation key. See the documentation for the
    "generate_permutation_keys" function.
    """
    return random.choice(list(generate_permutation_keys()))
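
As a sanity check on the key space (an illustrative snippet, not part of the original file): combinations_with_replacement(range(2), 2) yields the three rotation pairs (0, 0), (0, 1) and (1, 1), and each of the four binary flags doubles the count, giving 3 * 2**4 = 48 keys.

# Illustrative check that the enumerated key space has 48 entries.
assert len(generate_permutation_keys()) == 48
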
def augment_data(data):
    padded_data = -1000 * np.ones((256, 256, 256))
    padded_data[104:152, :, :] = data
    data = padded_data.transpose(1, 2, 0)
    return data


def permute_data(data, key):
    """
    Permutes the given data according to the specification of the given key. Input data
    must be of shape (n_modalities, x, y, z).
    Input key is a tuple: ((rotate_y, rotate_z), flip_x, flip_y, flip_z, transpose)
    As an example, ((0, 1), 0, 1, 0, 1) represents a permutation in which the data is
    rotated 90 degrees around the z-axis, then reversed on the y-axis, and then
    transposed.
    """
    data = np.copy(data)
    #print(data.shape)
    (rotate_y, rotate_z), flip_x, flip_y, flip_z, transpose = key
    if rotate_y != 0:
        data = np.rot90(data, rotate_y, axes=(1, 3))
    if rotate_z != 0:
        data = np.rot90(data, rotate_z, axes=(2, 3))
    if flip_x:
        data = data[:, ::-1]
    if flip_y:
        data = data[:, :, ::-1]
    if flip_z:
        data = data[:, :, :, ::-1]
    if transpose:
        for i in range(data.shape[0]):
            data[i] = data[i].T
    return data


def random_permutation_data(data):
    key = random_permute_key()
    return permute_data(data, key)
import os
import numpy as np

"""
Describe:
    Crop the data around the given center point.
Parameters:
    data: the full stack of slice data
    crop_start: the per-axis start coordinates of the (48, 256, 256) region selected
                from the slices (derived from the region's center point)
"""
def get_crop_data(data, crop_start, crop_size, mode='constant', constant_values=-1000):
    if data is None:
        return data
    data_shape = np.array(data.shape)
    crop_data = data[
        max(crop_start[0], 0):min(crop_start[0] + crop_size[0], data_shape[0]),
        max(crop_start[1], 0):min(crop_start[1] + crop_size[1], data_shape[1]),
        max(crop_start[2], 0):min(crop_start[2] + crop_size[2], data_shape[2])]
    # pad = np.zeros((3, 2), np.int)
    pad = np.zeros((3, 2), np.int32)
    # np.maximum(x1, x2) compares the inputs element-wise and keeps the larger value.
    # Amount of padding needed before the data on each axis.
    pad[:, 0] = np.maximum(-crop_start, 0)
    pad[:, 1] = np.maximum(crop_start + crop_size - data_shape, 0)
    if not np.all(pad == 0):
        # pad holds the padding length for each dimension.
        if mode == 'edge':
            crop_data = np.pad(crop_data, pad, mode=mode)
        else:
            crop_data = np.pad(crop_data, pad, mode=mode, constant_values=constant_values)
    return crop_data
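
A small usage sketch of get_crop_data with illustrative values: when the crop window extends past the volume boundary, the missing voxels are padded with the fill value, so the output always has exactly crop_size shape.

# Illustrative: crop a (4, 4, 4) volume with a window that starts outside the boundary.
volume = np.arange(64).reshape(4, 4, 4)
patch = get_crop_data(volume, crop_start=np.array([-1, 0, 2]), crop_size=np.array([3, 3, 3]))
print(patch.shape)   # (3, 3, 3); out-of-range voxels are filled with -1000
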
def get_crop_data_padding_2d_opt(ct_data=None, select_box=None, crop_size=None, mode='constant', constant_values=-1000):
    # Get the raw volume.
    data = ct_data.get_raw_image()
    data_shape = np.array(data.shape)
    crop_size = np.array(crop_size)
    # np.int0 was removed in NumPy 2.0; np.intp is the equivalent alias.
    crop_start = np.intp((select_box[:, 0] + select_box[:, 1])//2 - crop_size // 2)
    crop_data = get_crop_data(data=data,
                              crop_start=crop_start,
                              crop_size=crop_size,
                              mode=mode,
                              constant_values=constant_values)
    return crop_data

"""
Keep only a single slice: place that slice at the center and fill every other slice with -1000.
crop_size=[48, 256, 256]
"""
def get_crop_data_padding(ct_data=None, select_box=None, crop_size=None, mode='constant', constant_values=-1000):
    data = ct_data.get_raw_image()
    data_shape = np.array(data.shape)
    crop_size = np.array(crop_size)
    z_index = int(select_box[0][0])
    crop_start = np.intp((select_box[1:, 0] + select_box[1:, 1])//2 - crop_size[1:] // 2)
    crop_data = data[z_index:z_index+1,
                     max(crop_start[0], 0) : min(crop_start[0] + crop_size[1], data_shape[1]),
                     max(crop_start[1], 0) : min(crop_start[1] + crop_size[2], data_shape[2])]
    # pad = np.zeros((3, 2), np.int)
    pad = np.zeros((3, 2), np.int32)
    pad[1:, 0] = np.maximum(-crop_start, 0)
    pad[1:, 1] = np.maximum(crop_start + crop_size[1:] - data_shape[1:], 0)
    if not np.all(pad == 0):
        crop_data = np.pad(crop_data, pad, mode=mode, constant_values=constant_values)
    result_data = np.full(crop_size, -1000)
    result_data[crop_size[0]//2] = crop_data
    return result_data
# From the current (1024, 1024) image, extract the region at the given coordinates;
# if the region is smaller than (256, 256), pad the missing part with -1000.
# select_box=[[z_index, z_index], [y_min, y_max], [x_min, x_max]]
# crop_size=[256, 256]
def get_crop_data_2d(data=None, select_box=None, crop_size=None, mode='constant', constant_values=-1000):
    data_shape = np.array(data.shape)
    crop_size = np.array(crop_size)
    crop_start = np.intp((select_box[1:, 0] + select_box[1:, 1])//2 - crop_size[:]//2)
    crop_data = data[max(crop_start[0], 0) : min(crop_start[0] + crop_size[0], data_shape[0]),
                     max(crop_start[1], 0) : min(crop_start[1] + crop_size[1], data_shape[1])]
    # pad = np.zeros((2, 2), np.int)
    pad = np.zeros((2, 2), np.int32)
    pad[:, 0] = np.maximum(-crop_start, 0)
    pad[:, 1] = np.maximum(crop_start + crop_size[:] - data_shape[:], 0)
    if not np.all(pad == 0):
        crop_data = np.pad(crop_data, pad, mode=mode, constant_values=constant_values)
    """result_data = np.full(crop_size, -1000)
    result_data[crop_size[0]//2] = crop_data"""
    result_data = crop_data
    return result_data
'''
# Process the sliced data further and return data of size (48, 256, 256).
def crop_ct_data(ct_data=None, select_box=None, crop_size=None):
    data = ct_data.get_raw_image()
    #print(data.shape)
    spacing = ct_data.get_raw_spacing()
    z_center = int((select_box[0, 0] + select_box[0, 1]) // 2)
    crop_size = np.array(crop_size)
    # The operations below keep the original box contained, with z_center in the middle.
    #z_diff=crop_size[0] - (select_box[0, 1] - select_box[0, 0])
    z_diff = crop_size[0] - (select_box[0, 1] - select_box[0, 0]) - 1
    # Fewer than 48 slices were selected.
    if z_diff > 0:
        # Number of slices to add in front.
        add_front_z = 0
        # Number of slices to add behind.
        add_behind_z = 0
        # Number of slices between the middle slice of select_box and z_min.
        z_front_num = z_center - select_box[0, 0]
        # Number of slices between the middle slice of select_box and z_max.
        z_behind_num = select_box[0, 1] - z_center
        # Here add_front_z is either 1 or 0.
        add_front_z = min(z_behind_num - z_front_num, z_diff)
        # When z_diff != 1:
        if add_front_z < z_diff:
            add_front_z = add_front_z + (z_diff - add_front_z) // 2
            add_behind_z = z_diff - add_front_z
        else:
            add_front_z = add_front_z
            add_behind_z = 0
        select_box[0, 0] = select_box[0, 0] - add_front_z
        select_box[0, 1] = select_box[0, 1] + add_behind_z
    scale = (1, 1, 1)
    #scale = (1, 1, 1) if spacing[0] > 0.5 else (2, 1, 1)
    crop_size = np.intp(crop_size * scale)
    crop_start = np.intp((select_box[:, 0] + select_box[:, 1])//2 - crop_size // 2)
    # Crop the data to the requested size.
    original_data = get_crop_data(data=data, crop_start=crop_start, crop_size=crop_size)
    return original_data
'''
def crop_ct_data(ct_data=None, select_box=None, crop_size=None):
    data = ct_data.get_raw_image()
    #print(data.shape)
    spacing = ct_data.get_raw_spacing()
    z_center = int((select_box[0, 0] + select_box[0, 1]) // 2)
    crop_size = np.array(crop_size)
    # The operations below keep the original box contained, with z_center in the middle.
    #z_diff=crop_size[0] - (select_box[0, 1] - select_box[0, 0])
    z_diff = crop_size[0] - (select_box[0, 1] - select_box[0, 0])
    # Fewer than 48 slices were selected.
    if z_diff > 0:
        # Number of slices to add in front.
        add_front_z = 0
        # Number of slices to add behind.
        add_behind_z = 0
        # Number of slices between the middle slice of select_box and z_min.
        z_front_num = z_center - select_box[0, 0]
        # Number of slices between the middle slice of select_box and z_max.
        z_behind_num = select_box[0, 1] - z_center
        if z_behind_num >= z_front_num > 0:
            add_front_z = min(z_behind_num - z_front_num, z_diff)
            if add_front_z < z_diff:
                add_front_z = add_front_z + (z_diff - add_front_z) // 2
                add_behind_z = z_diff - add_front_z
        elif z_front_num > z_behind_num > 0:
            add_behind_z = min(z_front_num - z_behind_num, z_diff)
            if add_behind_z < z_diff:
                add_behind_z = add_behind_z + (z_diff - add_behind_z) // 2
                add_front_z = z_diff - add_behind_z
        select_box[0, 0] = select_box[0, 0] - add_front_z
        select_box[0, 1] = select_box[0, 1] + add_behind_z
    scale = (1, 1, 1)
    #scale = (1, 1, 1) if spacing[0] > 0.5 else (2, 1, 1)
    crop_size = np.intp(crop_size * scale)
    crop_start = np.intp((select_box[:, 0] + select_box[:, 1])//2 - crop_size // 2)
    # Crop the data to the requested size.
    original_data = get_crop_data(data=data, crop_start=crop_start, crop_size=crop_size)
    return original_data
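
To make the z-extension in crop_ct_data concrete, a small worked example with illustrative numbers: with crop_size[0] = 48 and a select_box z-range of [10, 20], z_center = 15 and z_diff = 38; since z_front_num = z_behind_num = 5, the first branch runs with add_front_z = min(0, 38) = 0, which is then topped up to 0 + 38 // 2 = 19, leaving add_behind_z = 19, so the box is extended symmetrically to [-9, 39] before get_crop_data pads any out-of-range slices with -1000.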
# For automatically segmented data, treat each slice in turn as the central slice and
# pad the rest to (48, 256, 256) for prediction.
import pathlib
import sys
import os

current_file = pathlib.Path(__file__).resolve()
project_root = current_file.parent
project_dir_name = 'cls_train'
while project_root.name != project_dir_name and project_root != project_root.parent:
    project_root = project_root.parent
if project_root.name != project_dir_name:
    raise Exception(f"Project path not found: {project_dir_name}")
sys.path.append(str(project_root))

import argparse
import os
import sys
from pathlib import Path
from loguru import logger  # import loguru

def get_logger(log_file, rotation="500 MB", compression="zip", enqueue=True):
    log_format = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <4} | [{module}].[{function}]:{line} - {message}"
    logger.remove()
    logger.add(log_file, format=log_format, rotation=rotation, compression=compression, enqueue=enqueue)
    return logger
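
A usage sketch of get_logger; the log path below is a placeholder:

# Hypothetical usage: rotating, compressed log files via loguru.
logger = get_logger('/tmp/train.log')
logger.info('training started')
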
import pathlib
import sys

def setup_project_path(project_name='cls_train'):
    cur_path = pathlib.Path(__file__).parent.resolve()
    # Stop at the filesystem root to avoid an infinite loop when the project is not found.
    while cur_path.name != project_name and cur_path != cur_path.parent:
        cur_path = cur_path.parent
    if cur_path.name != project_name:
        raise ValueError(f"Project root directory not found: {project_name}")
    sys.path.append(str(cur_path))
    return cur_path
import os
import logging
import time
import numpy as np
import scipy.ndimage
import SimpleITK as sitk
import pydicom
from cls_utils.utils import check_and_makedirs
from collections import namedtuple

NoduleBox = namedtuple('NoduleBox', ['z', 'y', 'x', 'diameter', 'uid', 'dicom_path'])

def get_cts(bundle, dicom_path):
    cts = CTSeries()
    start_time = time.time()
    # Read the DICOM files.
    cts.load_dicoms(dicom_path)
    print('start load dicoms time {:.5f}(s)'.format(time.time() - start_time))
    return cts

def get_nodule_y_pixel_length(spacing, nodule_box):
    return int(np.ceil(nodule_box.diameter / spacing[1]))

def get_diameter_pixel_length(spacing, diameter):
    # np.int was removed from NumPy; use a concrete integer dtype instead.
    length = np.zeros(3, np.int32)
    for i in range(3):
        length[i] = int(np.ceil(diameter / spacing[i]))
    return length

def get_nodule_rect(spacing, nodule_box):
    diameter_pixel_length = get_diameter_pixel_length(spacing, nodule_box.diameter)
    center = np.array([nodule_box.z, nodule_box.y, nodule_box.x])
    rect = np.zeros((3, 2), np.int32)
    for i in range(3):
        rect[i][0] = center[i] - diameter_pixel_length[i] // 2
        rect[i][1] = rect[i][0] + diameter_pixel_length[i]
    return rect
def resample(data, spacing, new_spacing=[1.0, 1.0, 1.0], order=1):
    if data is None:
        return None
    new_shape = np.round(data.shape * spacing / new_spacing)
    resample_spacing = spacing * data.shape / new_shape
    resize_factor = new_shape / data.shape
    # scipy.ndimage.interpolation.zoom is deprecated; scipy.ndimage.zoom is the supported path.
    new_data = scipy.ndimage.zoom(data, resize_factor, mode='nearest', order=order)
    return new_data, resample_spacing
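
A sketch of resample in use (the array values are illustrative): a CT volume with 5 mm slices and 0.7 mm in-plane pixels is interpolated to roughly isotropic 1 mm voxels; the function returns both the resampled volume and the spacing actually achieved after rounding.

volume = np.random.randint(-1000, 600, size=(48, 256, 256)).astype(np.float32)
spacing = np.array([5.0, 0.7, 0.7])   # (z, y, x) in mm
new_volume, achieved_spacing = resample(volume, spacing, new_spacing=[1.0, 1.0, 1.0])
print(new_volume.shape, achieved_spacing)   # about (240, 179, 179)
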
def resample_nodule_box(nodule_box, spacing, new_spacing=[1.0, 1.0, 1.0]):
    if nodule_box is None:
        return None
    z = int(np.ceil(nodule_box.z * spacing[0] / new_spacing[0]))
    y = int(np.ceil(nodule_box.y * spacing[1] / new_spacing[1]))
    x = int(np.ceil(nodule_box.x * spacing[2] / new_spacing[2]))
    new_box = NoduleBox(int(z), int(y), int(x), nodule_box.diameter, nodule_box.uid, nodule_box.dicom_path)
    return new_box
def maxip(input, level_num=2):
    # Maximum intensity projection over a sliding window of 2*level_num + 1 slices,
    # clamped at the start of the stack.
    output = np.zeros(input.shape, input.dtype)
    for i in range(len(input)):
        length = level_num if i >= level_num else i
        output[i] = np.max(input[i-length:i+level_num+1], axis=0)
    return output

def minip(input, level_num=2):
    # Minimum intensity projection over the same sliding window.
    output = np.zeros(input.shape, input.dtype)
    for i in range(len(input)):
        length = level_num if i >= level_num else i
        output[i] = np.min(input[i-length:i+level_num+1], axis=0)
    return output
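
A quick illustration of the projection, with made-up data: each output slice is the maximum over its z-neighborhood.

# Illustrative MIP over a tiny stack of 4 "slices" of 3 pixels each.
stack = np.arange(12, dtype=np.int32).reshape(4, 3)
print(maxip(stack, level_num=1))
# row i is max(stack[max(i-1, 0) : i+2], axis=0)
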
class CTSeries(object):
    def __init__(self):
        self._SeriesInstanceUID = None
        self._SOPInstanceUIDs = None
        self._raw_image = None
        self._raw_origin = None
        self._raw_spacing = None
        self._raw_direction = None
        self._dicoms_is_loaded = False

    def load_dicoms(self, folder_path):
        logging.info('{}, Loading dicoms from {}...'.format(
            time.strftime('%Y-%m-%d %H:%M:%S'), folder_path))
        #print(folder_path)
        dicom_names = [f for f in os.listdir(folder_path) if '.xml' not in f]
        dicom_paths = list(map(lambda x: os.path.join(folder_path, x), dicom_names))
        # pydicom.read_file is deprecated; dcmread is the current API.
        dicoms = list(map(lambda x: pydicom.dcmread(x), dicom_paths))
        # slice_locations = list(map(lambda x: float(x.SliceLocation), dicoms))
        # sort slices by their z coordinates from large to small
        # idx_z_sorted = np.argsort(slice_locations)[::-1]
        try:
            slice_locations = list(map(lambda x: float(x.ImagePositionPatient[2]), dicoms))
        except AttributeError:
            try:
                slice_locations = list(map(lambda x: float(x.SliceLocation), dicoms))
            except AttributeError:
                slice_locations = []
                for i in range(len(dicoms)):
                    try:
                        slice_locations.append(float(dicoms[i].ImagePositionPatient[2]))
                    except AttributeError:
                        print(i, dicoms[i].SeriesInstanceUID)
        patient_position = dicoms[0].PatientPosition
        self._SeriesInstanceUID = dicoms[0].SeriesInstanceUID
        # Note: both branches currently sort in descending order.
        if patient_position in ['FFP', 'FFS']:
            idx_z_sorted = np.argsort(slice_locations)[::-1]
        else:
            idx_z_sorted = np.argsort(slice_locations)[::-1]
        # Reorder the DICOM files according to idx_z_sorted.
        dicoms = list(map(lambda x: dicoms[x], idx_z_sorted))
        self._SOPInstanceUIDs = np.array(list(map(lambda x: x.SOPInstanceUID, dicoms)))
        #dicom_path_before = np.array(dicom_paths)
        #print(dicom_path_before)
        dicom_paths = np.array(dicom_paths)[idx_z_sorted]
        #print(dicom_paths)
        #print(self._SOPInstanceUIDs)
        #dicoms = np.array(dicoms)[idx_z_sorted]
        reader = sitk.ImageSeriesReader()
        reader.SetFileNames(list(dicom_paths))
        image_itk = reader.Execute()
        # all in [z, y, x] order
        self._raw_image = sitk.GetArrayFromImage(image_itk)
        self._raw_origin = np.array(list(reversed(image_itk.GetOrigin())))
        self._raw_spacing = np.array(list(reversed(image_itk.GetSpacing())))
        self._raw_direction = image_itk.GetDirection()
        # print('raw_image', self._raw_image.shape, 'raw_spacing', self._raw_spacing)
        # print('raw_origin', self._raw_origin, 'raw_direction', self._raw_direction)
        self._dicoms_is_loaded = True

    def load_single_file(self, file_path):
        logging.info('{}, Loading file from {}...'.format(
            time.strftime('%Y-%m-%d %H:%M:%S'), file_path))
        image_itk = sitk.ReadImage(file_path)
        # all in [z, y, x] order
        self._raw_image = sitk.GetArrayFromImage(image_itk)
        self._raw_origin = np.array(list(reversed(image_itk.GetOrigin())))
        self._raw_spacing = np.array(list(reversed(image_itk.GetSpacing())))
        self._raw_direction = image_itk.GetDirection()
        # print('raw_image', self._raw_image.shape, 'raw_spacing', self._raw_spacing)
        # print('raw_origin', self._raw_origin, 'raw_direction', self._raw_direction)
        self._dicoms_is_loaded = True

    def get_raw_image(self):
        return self._raw_image

    def set_raw_image(self, data):
        self._raw_image = data

    def get_raw_origin(self):
        return self._raw_origin

    def get_raw_spacing(self):
        return self._raw_spacing

    def get_raw_image_affine(self):
        affine = np.diag(list(self._raw_spacing) + [1])
        affine[:, 3][:3] = np.array(list(self._raw_origin))
        return affine

    def save_raw_image(self, output_file):
        if self._dicoms_is_loaded:
            self.save_image(self._raw_image, output_file)

    def save_image(self, data, output_file):
        if self._dicoms_is_loaded:
            image = sitk.GetImageFromArray(data)
            image.SetSpacing(np.array(list(reversed(self._raw_spacing))))
            image.SetOrigin(np.array(list(reversed(self._raw_origin))))
            image.SetDirection(self._raw_direction)
            check_and_makedirs(output_file)
            sitk.WriteImage(image, output_file)

    def transform_file_type(self, input_file, output_file):
        image = sitk.ReadImage(input_file)
        sitk.WriteImage(image, output_file)
import os
import time
from binascii import b2a_hex
import cv2
import numpy as np
#import nibabel as nib
import base64
from Crypto.Cipher import AES

def encrypt(text):
    # pycryptodome expects bytes for the key, IV and plaintext.
    key = b'1234561234561234'
    cryptor = AES.new(key, AES.MODE_CBC, key)
    count = len(text)
    add = len(key) - (count % len(key))
    # Note: padding is computed on character count; non-ASCII text may break
    # 16-byte block alignment after encoding.
    text = text + ('\0' * add)
    ciphertext = cryptor.encrypt(text.encode('utf-8'))
    return b2a_hex(ciphertext)

def decrypt_oralce(text):
    key = '123456'
    aes = AES.new(add_to_16(key), AES.MODE_ECB)
    base64_decrypted = base64.decodebytes(text.encode(encoding='utf-8'))
    decrypted_text = str(aes.decrypt(base64_decrypted), encoding='utf-8').replace('\0', '')
    return decrypted_text

def add_to_16(value):
    while len(value) % 16 != 0:
        value += '\0'
    return str.encode(value)  # returns bytes
def index_to_list(indexs):
    if indexs:
        strList = indexs[:-1].split('],')
        list = []
        for idx, str in enumerate(strList):
            stri = str[1:].split(',')
            map = {'x': stri[0], 'y': stri[1]}
            list.append(map)
        return list

def index_to_meta(indexs):
    list = index_to_list(indexs)
    x = min_by_key(list, 'x')
    y = min_by_key(list, 'y')
    w = max_by_key(list, 'x') - min_by_key(list, 'x') + 1
    h = max_by_key(list, 'y') - min_by_key(list, 'y') + 1
    spot = ''.zfill(w * h)
    spot_len = len(spot)
    for m in list:
        flag = (int(m['y']) - y) * w + (int(m['x']) - x)
        spot = spot[0:flag] + '1' + spot[flag+1:spot_len]
    strs = get_str_list(spot, 32)
    array = '['
    for s in strs:
        array = array + '%d,' % (parse_unsigned_int(s, 2))
    array = array[:-1] + ']'
    meta = '{"x":%s,"y":%s,"w":%s,"h":%s,"delineation":%s}' % (x, y, w, h, array)
    return meta

def parse_unsigned_int(s, radix):
    length = len(s)
    if length > 0:
        firstChar = int(s[0])
        if firstChar > 0:
            newstr = ''
            for str in s[1:]:
                if str == '0':
                    newstr = newstr + '1'
                else:
                    newstr = newstr + '0'
            return -(int(newstr, radix) + 1)
        else:
            return int(s, radix)

def get_str_list(instr, length):
    size = len(instr) / length
    if len(instr) % length != 0:
        s = '00000000000000000000000000000000000'
        instr = instr + s[len(instr) % length:]
        size += 1
    return get_str_list_size(instr, length, size)

def get_str_list_size(instr, length, size):
    list = []
    for i in range(int(size)):
        childStr = substring(instr, i * length, (i + 1) * length)
        list.append(childStr)
    return list

def substring(str, f, t):
    if f > len(str):
        return None
    if t > len(str):
        return str[f:len(str)]
    else:
        return str[f:t]

def sub(str, p, c):
    newstr = []
    for s in str:
        newstr.append(s)
    newstr[p] = c
    return ''.join(newstr)

def min_by_key(coll, key):
    if coll:
        candidate = coll[0][key]
        for map in coll:
            if int(map[key]) < int(candidate):
                candidate = map[key]
        return int(candidate)

def max_by_key(coll, key):
    if coll:
        candidate = coll[0][key]
        for map in coll:
            if int(map[key]) > int(candidate):
                candidate = map[key]
        return int(candidate)
def sigmoid(X):
    return 1 / (1 + np.exp(-X))

'''
def load_image(input_file):
    return nib.load(input_file)

def get_image(data, affine, nib_class=nib.Nifti1Image):
    return nib_class(dataobj=data, affine=affine)
'''

def hu_value_clip(image, hu_min=-1200.0, hu_max=600.0, hu_nan=-2000.0):
    image_new = np.asarray(image)
    image_new[np.isnan(image_new)] = hu_nan
    image_new = np.clip(image_new, hu_min, hu_max)
    return image_new

# Process the raw CT pixels and normalize their values.
def hu_value_to_uint8(original_data, hu_min=-1200.0, hu_max=600.0, hu_nan=-2000.0):
    image_new = np.asarray(original_data)
    image_new[np.isnan(image_new)] = hu_nan
    # normalize to [0, 1]
    image_new = (image_new - hu_min)
    image_new = image_new / (hu_max - hu_min)
    image_new = np.clip(image_new, 0, 1)
    image_new = (image_new * 255).astype(np.float32)
    return image_new

# Scale the data into (-1, 1).
def normalize(image):
    image_new = np.asarray(image)
    image_new = (image_new - 128.0) / 128.0
    return image_new

def hu_normalize(original_data, hu_nan=-1000):
    image_new = np.asarray(original_data)
    image_new[np.isnan(image_new)] = hu_nan
    image_new = image_new / 1000
    image_new = np.clip(image_new, -1, 1)
    image_new = image_new.astype(np.float32)
    return image_new
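
The three normalization helpers target different input ranges; hu_normalize is the one used by the dataset below, mapping HU values to [-1, 1] by dividing by 1000 and clipping. A quick illustration:

# Illustrative: -1000 HU (air) -> -1.0, 0 HU (water) -> 0.0; values beyond ±1000 HU are clipped.
hu = np.array([-1200.0, -1000.0, 0.0, 600.0, np.nan])
print(hu_normalize(hu))   # [-1. -1. 0. 0.6 -1.]
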
def check_and_makedirs(output_file, is_file=True):
    if is_file:
        output_dir = os.path.dirname(output_file)
        # If the target file already exists, remove it first
        # (only files are removed; directories are left in place).
        if os.path.exists(output_file):
            os.remove(output_file)
    else:
        output_dir = output_file
    if output_dir is not None and output_dir != '' and not os.path.exists(output_dir):
        os.makedirs(output_dir)

def get_z_start_end_from_bound(data, z_center, bound_size):
    z_start, z_end = max(z_center - bound_size, 0), min(z_center + bound_size, data.shape[0])
    return z_start, z_end

def get_crop_start_end(start, end, crop_size):
    center = (start + end) // 2
    crop_start = center - np.asarray(crop_size) // 2
    crop_start = np.maximum(crop_start, 0)
    return crop_start, crop_start + crop_size
def print_data_mask(data, mask, filename=None, group=False):
    coords = np.asarray(np.where(mask > 0))
    if len(coords[0]) > 0:
        new_start = coords.min(axis=1)
        new_end = coords.max(axis=1) + 1
        print_mask = mask[new_start[0]:new_end[0], new_start[1]:new_end[1]]
        print_data = data[new_start[0]:new_end[0], new_start[1]:new_end[1]]
        if filename is not None:
            with open(filename, 'w') as f:
                for temp_data in print_data:
                    f.write(np.array2string(temp_data).replace('\n', '') + '\n')
                for temp_mask in print_mask:
                    f.write(np.array2string(temp_mask).replace('\n', '') + '\n')
                if group:
                    for temp_data, temp_mask in zip(print_data, print_mask):
                        temp_group = np.array([str(i) + '(' + str(j) + ')' for i, j in zip(temp_data, temp_mask)])
                        f.write(np.array2string(temp_group).replace('\n', '').replace('\'', '') + '\n')
        else:
            for temp_data in print_data:
                print(np.array2string(temp_data).replace('\n', '') + '\n')
            for temp_mask in print_mask:
                print(np.array2string(temp_mask).replace('\n', '') + '\n')
            if group:
                for temp_data, temp_mask in zip(print_data, print_mask):
                    temp_group = np.array([str(i) + '(' + str(j) + ')' for i, j in zip(temp_data, temp_mask)])
                    print(np.array2string(temp_group).replace('\n', '').replace('\'', '') + '\n')
def base64_to_list(base64_str):
    indexs = ''
    list = []
    img_np = None
    if base64_str:
        time_now = time.time()
        img_data = base64.b64decode(base64_str[22:])
        # np.fromstring is deprecated for binary data; np.frombuffer is the supported API.
        nparr = np.frombuffer(img_data, np.uint8)
        img_np = cv2.imdecode(nparr, 0)
        img_np[img_np != 0] = 1
        point_list = np.where(img_np != 0)
        if len(point_list[0]) > 0:
            y_list = point_list[0]
            x_list = point_list[1]
            indexs = '{'
            for point_idx, x in enumerate(x_list):
                raw_x = x
                raw_y = y_list[point_idx]
                map = {'x': raw_x, 'y': raw_y}
                list.append(map)
                indexs = indexs + '[%s,%s],' % (raw_x, raw_y)
            indexs = indexs[:-1] + '}'
        #print('run time {:.5f}(s)'.format(time.time() - time_now))
    return list, indexs, img_np
[conf]
MODEL_PATH = D:/plus_new/python/checkpoint/
ZIP_DEMO = D:/plus_new/python/service/demo.txt
AI_PATH_PREFIX = D:/plus_new/ai
HTTP_SERVER_PREFIX = http://127.0.0.1/dcm
KAFKA_SERVER = 127.0.0.1:9092
MYSQL_SERVER = mysql+pymysql://lung:lung1qaz2wsx@127.0.0.1:3306/ct_file?charset=utf8
DICOM_RAW_DATA_PATH = D:/plus_new/raw-data
REDIS_SYNC_KEY = sync-server
REDIS_SERVER = 127.0.0.1
REDIS_DB = 5
import argparse
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
from configparser import ConfigParser
import json
#from testing.get_model import get_all_model
"""parser = argparse.ArgumentParser(description='load env')
parser.add_argument('--profile', default='dev', type=str,
help='profile information')
args = parser.parse_args()
profile = args.profile"""
ws_cfg = ConfigParser()
ws_cfg.read(os.path.dirname(os.path.abspath(__file__)) + '/app-qa.properties')
"""if profile == 'qa':
ws_cfg.read(os.path.dirname(os.path.abspath(__file__)) + '/app-qa.properties')
else:
ws_cfg.read((os.path.dirname(os.path.abspath(__file__)) + '/app-dev.properties'))"""
# global var
MYSQL_SERVER = str(ws_cfg.get('conf', 'MYSQL_SERVER'))
AI_PATH_PREFIX = str(ws_cfg.get('conf', 'AI_PATH_PREFIX'))
HTTP_SERVER_PREFIX = str(ws_cfg.get('conf', 'HTTP_SERVER_PREFIX'))
{
"pretrain_ckpt": "./cls_train/cls_ckpt/best_cls_3d_1025/08/cls_234567_2031/cls_234567_203120240815-0048.ckpt",
"csv_path" : "/home/lung/ai-project/cls_train/data/train_data/plus_3d_0818/subject_all_csv/08/cls_234567_2031/train_original.csv",
"node_times" : [[2031], [2011,2021,2041]],
"pretrain_folder": "./cls_train/cls_ckpt/best_cls_replenish/10/cls_1234567_1010-1016-1020",
"dicom_folder": "/opt/lung/ai" ,
"train_crop_size": [48, 256, 256],
"train_crop_size_2d": [256, 256],
"validation_filename": "cls_train/log/validation/cls_234567_2031/20240815/log_validation_20240815.log",
"compute_accuracy_filename": "cls_train/log/accuracy/cls_234567_2031/accuracy.log"
}
{
"pretrain_ckpt": "/df_lung/ai-project/cls_train/cls_ckpt/best_cls_3d_1025/08/cls_234567_2041/cls_1_204120230927-1256.ckpt",
"csv_path" : "/df_lung/ai-project/cls_train/data/train_data/plus_3d_0818/subject_all_csv/08/cls_234567_2041/train.csv",
"node_times" : [[2041]],
"pretrain_folder": "/df_lung/ai-project/cls_train/cls_ckpt/best_cls_replenish/10/cls_1234567_1010-1016-1020",
"dicom_folder": "/opt/lung/ai" ,
"train_crop_size": [48, 256, 256],
"train_crop_size_2d": [256, 256],
"validation_filename": "/df_lung/ai-project/cls_train/log/validation/cls_234567_2041/20241106/log_validation_20241106.log",
"compute_accuracy_filename": "/df_lung/ai-project/cls_train/log/accuracy/cls_234567_2041/20241106/accuracy.log"
}
{
"pretrain_ckpt": "/df_lung/ai-project/cls_train/cls_ckpt/best_cls_3d_1025/08/cls_234567_2041/cls_234567_204120240105-0248.ckpt",
"csv_path" : "/df_lung/ai-project/cls_train/data/train_data/plus_3d_0818/subject_all_csv/08/cls_234567_2041/train.csv",
"node_times" : [[2041]],
"pretrain_folder": "/df_lung/ai-project/cls_train/cls_ckpt/best_cls_replenish/10/cls_1234567_1010-1016-1020",
"dicom_folder": "/opt/lung/ai" ,
"train_crop_size": [48, 256, 256],
"train_crop_size_2d": [256, 256],
"validation_filename": "/df_lung/ai-project/cls_train/log/validation/cls_234567_2041/20241112/log_validation_20241112.log",
"compute_accuracy_filename": "/df_lung/ai-project/cls_train/log/accuracy/cls_234567_2041/20241112/accuracy.log"
}
{
"train_crop_size": [48, 256, 256],
"train_crop_size_2d": [256, 256] ,
"dicom_folder": "/opt/lung/ai" ,
"train_data_path": "./cls_train/data/train_data/plus_3d_0818" ,
"npy_folder": "npy_data",
"csv_path": "subject_all_csv" ,
"subject_all_csv": "subject_all.csv",
"image_path": "train_image",
"n_channels": 1 ,
"n_diff_classes": 1 ,
"n_base_filters": 16 ,
"batch_size": 8,
"lr": 1e-4,
"momentum": 0.9,
"epoch": 1000 ,
"patience": 200 ,
"learning_rate_drop": 0.1,
"early_stop": 1000 ,
"train_csv_file": "" ,
"ckpt_save_path": "./cls_train/cls_ckpt/best_cls_3d_1025/" ,
"ckpt_pretrain_path": "./cls_train/best_cls_3d_1_5001-6001_0824/temp",
"ckpt_file": "cls_1_5001-6001/cls_1_5001-600120230825-0208.ckpt",
"training_filename": "cls_train/log/train/log_training_3d_234567_2031_20240725.log"
}
{
"train_crop_size": [48, 256, 256],
"train_crop_size_2d": [256, 256] ,
"dicom_folder": "/opt/lung/ai" ,
"train_data_path": "./cls_train/data/train_data/plus_3d_1104" ,
"npy_folder": "npy_data",
"csv_path": "subject_all_csv" ,
"subject_all_csv": "subject_all.csv",
"image_path": "train_image",
"n_channels": 1 ,
"n_diff_classes": 1 ,
"n_base_filters": 16 ,
"batch_size": 8,
"lr": 1e-4,
"momentum": 0.9,
"epoch": 1000 ,
"patience": 200 ,
"learning_rate_drop": 0.1,
"early_stop": 1000 ,
"train_csv_file": "" ,
"ckpt_save_path": "./cls_train/cls_ckpt/best_cls_3d_1104/" ,
"ckpt_pretrain_path": "./cls_train/best_cls_3d_1_5001-6001_0824/temp",
"ckpt_file": "cls_1_5001-6001/cls_1_5001-600120230825-0208.ckpt",
"training_filename": "cls_train/log/train/log_training_3d_234567_2031_20241104.log"
}
{
"train_crop_size": [48, 256, 256],
"train_crop_size_2d": [256, 256] ,
"dicom_folder": "/opt/lung/ai" ,
"train_data_path": "data/train_data/plus_3d_0818" ,
"npy_folder": "npy_data",
"csv_path": "subject_all_csv" ,
"subject_all_csv": "subject_all.csv",
"image_path": "train_image",
"n_channels": 1 ,
"n_diff_classes": 1 ,
"n_base_filters": 16 ,
"batch_size": 8,
"lr": 1e-4,
"momentum": 0.9,
"epoch": 5 ,
"patience": 200 ,
"learning_rate_drop": 0.1,
"early_stop": 1000 ,
"train_csv_file": "" ,
"ckpt_save_path": "./cls_ckpt/best_cls_3d_20241112_2041/" ,
"ckpt_pretrain_path": "/df_lung/ai-project/cls_train/cls_ckpt/best_cls_3d_1025/08/cls_234567_2041/cls_234567_204120240105-0248.ckpt",
"ckpt_file": "cls_1_5001-6001/cls_1_5001-600120230825-0208.ckpt",
"training_filename": "./log/train/log_training_3d_234567_2041_20241112.log"
}
# -*- coding: utf-8 -*-
import random
import itertools
import numpy as np


def random_permute_key():
    """
    Generates and randomly selects a permute key.
    """
    return random.choice(list(generate_general_permute_keys()))


def generate_permute_keys():
    """
    This function returns a set of "keys" that represent the 48 unique rotations &
    reflections of a 3D matrix.
    Each item of the set is a tuple:
    ((rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose)
    """
    return set(itertools.product(
        itertools.combinations_with_replacement(range(2), 2), range(2), range(2), range(2), range(2)))


# def generate_general_indexes():
#     indexes = [[0, 1, 2, 3, 4, 5, 6, 7],
#                [1, 2, 3, 0, 5, 6, 7, 4],
#                [2, 3, 0, 1, 6, 7, 4, 5],
#                [3, 0, 1, 2, 7, 4, 5, 6],
#                [4, 5, 6, 7, 0, 1, 2, 3],
#                [5, 6, 7, 4, 1, 2, 3, 0],
#                [6, 7, 4, 5, 2, 3, 0, 1],
#                [7, 4, 5, 6, 3, 0, 1, 2]]
#     return indexes
def generate_general_indexes():
    indexes = [[0, 2],
               [1, 3],
               [2, 0],
               [3, 1],
               [4, 6],
               [5, 7],
               [6, 4],
               [7, 5]]
    return indexes


def generate_general_permute_keys():
    rotate_y = 0
    transpose = 0
    keys = [((rotate_y, 0), 0, 0, 0, transpose),
            ((rotate_y, 1), 0, 0, 0, transpose),
            ((rotate_y, 0), 0, 1, 1, transpose),
            ((rotate_y, 1), 0, 1, 1, transpose),
            ((rotate_y, 0), 1, 0, 0, transpose),
            ((rotate_y, 1), 1, 0, 0, transpose),
            ((rotate_y, 0), 1, 1, 1, transpose),
            ((rotate_y, 1), 1, 1, 1, transpose)]
    return keys


def permute_data(data, key):
    if data is None or key == ((0, 0), 0, 0, 0, 0):
        return data
    data = np.copy(data)
    (rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose = key
    if rotate_y != 0:
        data = np.rot90(data, rotate_y, axes=(0, 2))
    if rotate_z != 0:
        data = np.rot90(data, rotate_z, axes=(1, 2))
    if flip_z:
        data = np.flip(data, axis=0)
    if flip_y:
        data = np.flip(data, axis=1)
    if flip_x:
        data = np.flip(data, axis=2)
    if transpose:
        add_axes = [i for i in range(3, len(data.shape))]
        data = np.transpose(data, axes=[2, 1, 0] + add_axes)
    return data


def reverse_permute_data(data, key):
    if data is None or key == ((0, 0), 0, 0, 0, 0):
        return data
    data = np.copy(data)
    key = reverse_permute_key(key)
    (rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose = key
    if transpose:
        add_axes = [i for i in range(3, len(data.shape))]
        data = np.transpose(data, axes=[2, 1, 0] + add_axes)
    if flip_x:
        data = np.flip(data, axis=2)
    if flip_y:
        data = np.flip(data, axis=1)
    if flip_z:
        data = np.flip(data, axis=0)
    if rotate_z != 0:
        data = np.rot90(data, rotate_z, axes=(1, 2))
    if rotate_y != 0:
        data = np.rot90(data, rotate_y, axes=(0, 2))
    return data


def reverse_permute_key(key):
    rotation = tuple([-rotate for rotate in key[0]])
    return rotation, key[1], key[2], key[3], key[4]
def augment(data, targets=None, mask=None, weight=None):
    key = random_permute_key()
    (rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose = key
    rotate_y = 0
    transpose = 0
    key = (rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose
    aug_data = permute_data(data, key)
    aug_targets = targets
    if targets is not None:
        for i in range(len(aug_targets)):
            aug_targets[i] = permute_data(aug_targets[i], key)
            transform_target(aug_targets[i], key)
    aug_mask = permute_data(mask, key)
    aug_weight = permute_data(weight, key)
    return aug_data, aug_targets, aug_mask, aug_weight


def transform_target(target, key):
    (rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose = key
    if rotate_y:
        z = target[..., 1].copy()
        x = target[..., 3].copy()
        target[..., 1] = -1.0 * x
        target[..., 3] = z
    if rotate_z:
        y = target[..., 2].copy()
        x = target[..., 3].copy()
        target[..., 2] = -1.0 * x
        target[..., 3] = y
    if flip_z:
        target[..., 1] = -1.0 * target[..., 1]
    if flip_y:
        target[..., 2] = -1.0 * target[..., 2]
    if flip_x:
        target[..., 3] = -1.0 * target[..., 3]
    if transpose:
        z = target[..., 1].copy()
        x = target[..., 3].copy()
        target[..., 1] = x
        target[..., 3] = z
def data_test():
    data = np.arange(40).reshape((2, 2, 2, 1, 5))
    print('data', data)
    history = []
    for key in list(generate_general_permute_keys()):
        print('key', key)
        new_data = permute_data(data, key)
        # print('new_data', new_data)
        for old_data in history:
            if np.all(old_data == new_data):
                print('new_data already exists!!!')
                break
        history.append(new_data)
        reverse_data = reverse_permute_data(new_data, key)
        if not np.all(data == reverse_data):
            print('error:', reverse_data)


def data_test2():
    data = np.arange(60).reshape((3, 2, 2, 1, 5))
    for i in range(10000):
        key = random_permute_key()
        print('key', key)
        new_data = permute_data(data, key)
        if new_data.shape != data.shape:
            print('error:', new_data.shape)


def data_test3():
    data = np.arange(40).reshape((2, 2, 2, 1, 5))
    print('data', data)
    key = random_permute_key()
    (rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose = key
    rotate_y = 0
    rotate_z = 1
    flip_z = 0
    flip_y = 0
    flip_x = 0
    transpose = 0
    key = (rotate_y, rotate_z), flip_z, flip_y, flip_x, transpose
    new_data = permute_data(data, key)
    transform_target(new_data, key)
    print('new_data', new_data)


def data_test4():
    data = np.arange(27).reshape((3, 3, 3))
    print('data', data)
    history = []
    for key in list(generate_general_permute_keys()):
        print('key', key)
        new_data = permute_data(data, key)
        print('new_data', new_data)
        for old_data in history:
            if np.all(old_data == new_data):
                print('new_data already exists!!!')
                break
        history.append(new_data)
        reverse_data = reverse_permute_data(new_data, key)
        if not np.all(data == reverse_data):
            print('error:', reverse_data)


if __name__ == '__main__':
    data_test2()
# -*- coding: utf-8 -*-
import numpy as np
from collections import namedtuple

NoduleBox = namedtuple('NoduleBox',
                       ['z', 'y', 'x', 'diameter', 'is_pos', 'probability'])


def nodule_raw2standard(nodule_boxes, raw_spacing, standard_spacing, start=(0, 0, 0)):
    new_nodule_boxes = []
    for nodule_box in nodule_boxes:
        z, y, x, diameter, is_pos, probability = nodule_box
        z = z * raw_spacing[0] / standard_spacing[0] - start[0]
        y = y * raw_spacing[1] / standard_spacing[1] - start[1]
        x = x * raw_spacing[2] / standard_spacing[2] - start[2]
        new_nodule_boxes.append(
            NoduleBox(z, y, x, diameter, is_pos, probability))
    return new_nodule_boxes


def nodule_standard2raw(nodule_boxes, standard_spacing, raw_spacing, start=(0, 0, 0)):
    new_nodule_boxes = []
    for nodule_box in nodule_boxes:
        z, y, x, diameter, is_pos, probability = nodule_box
        z = (z + start[0]) * standard_spacing[0] / raw_spacing[0]
        y = (y + start[1]) * standard_spacing[1] / raw_spacing[1]
        x = (x + start[2]) * standard_spacing[2] / raw_spacing[2]
        new_nodule_boxes.append(
            NoduleBox(z, y, x, diameter, is_pos, probability))
    return new_nodule_boxes


def iou2d(box_a, box_b):
    box_a = np.asarray(box_a)
    box_b = np.asarray(box_b)
    a_start, a_end = box_a[0:2] - box_a[2:4] / 2, box_a[0:2] + box_a[2:4] / 2
    b_start, b_end = box_b[0:2] - box_b[2:4] / 2, box_b[0:2] + box_b[2:4] / 2
    y_overlap = max(0, min(a_end[0], b_end[0]) - max(a_start[0], b_start[0]))
    x_overlap = max(0, min(a_end[1], b_end[1]) - max(a_start[1], b_start[1]))
    intersection = y_overlap * x_overlap
    union = box_a[2] * box_a[3] + box_b[2] * box_b[3] - intersection
    return 1.0 * intersection / union
def iou3d(box_a, box_b, spacing):
    # The radius must be converted to pixels along each of the z, y, x axes.
    r_a = 0.5 * box_a.diameter / spacing
    r_b = 0.5 * box_b.diameter / spacing
    # starting index in each dimension
    z_a_s, y_a_s, x_a_s = box_a.z - r_a[0], box_a.y - r_a[1], box_a.x - r_a[2]
    z_b_s, y_b_s, x_b_s = box_b.z - r_b[0], box_b.y - r_b[1], box_b.x - r_b[2]
    # ending index in each dimension
    z_a_e, y_a_e, x_a_e = box_a.z + r_a[0], box_a.y + r_a[1], box_a.x + r_a[2]
    z_b_e, y_b_e, x_b_e = box_b.z + r_b[0], box_b.y + r_b[1], box_b.x + r_b[2]
    z_overlap = max(0, min(z_a_e, z_b_e) - max(z_a_s, z_b_s))
    y_overlap = max(0, min(y_a_e, y_b_e) - max(y_a_s, y_b_s))
    x_overlap = max(0, min(x_a_e, x_b_e) - max(x_a_s, x_b_s))
    intersection = z_overlap * y_overlap * x_overlap
    union = 8 * r_a[0] * r_a[1] * r_a[2] + 8 * r_b[0] * r_b[1] * r_b[2] - intersection
    return 1.0 * intersection / union


def nms(nodule_boxes, spacing, iou_thred):
    if nodule_boxes is None or len(nodule_boxes) == 0:
        return nodule_boxes
    nodule_boxes_nms = []
    for nodule_box in nodule_boxes:
        overlap = False
        for box in nodule_boxes_nms:
            if iou3d(box, nodule_box, spacing) >= iou_thred:
                overlap = True
                break
        if not overlap:
            nodule_boxes_nms.append(nodule_box)
    return nodule_boxes_nms
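
nms keeps boxes greedily in input order, dropping any box whose 3-D IoU with an already-kept box reaches the threshold, so callers are expected to pass boxes sorted by descending probability. A usage sketch with illustrative values:

# Illustrative: boxes sorted by descending probability before suppression.
spacing = np.array([1.0, 0.7, 0.7])
boxes = [NoduleBox(30, 100, 100, 8.0, 1, 0.9),
         NoduleBox(31, 101, 101, 8.0, 1, 0.8),   # overlaps the first box
         NoduleBox(60, 200, 200, 6.0, 1, 0.7)]
kept = nms(boxes, spacing, iou_thred=0.1)
print(len(kept))   # 2: the second box is suppressed
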
import torch
from torch.utils.data import Dataset
import sys
import os
import numpy as np
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')
from cls_utils.data import load_data_csv, get_data_and_label
from cls_utils.utils import normalize, hu_value_to_uint8, hu_normalize


class SubjectDataset(Dataset):
    def __init__(self, np_path, csv_path, is_train=True, is_2d=False, augment=False, permute=False):
        self._npy_path = np_path
        self._csv_path = csv_path
        #self._cfg = cfg
        self._is_train = is_train
        self._is_2d = is_2d
        self.augment = augment
        self.permute = permute
        self._preprocess()

    def _preprocess(self):
        self.subject_ids = load_data_csv(self._csv_path)

    def __len__(self):
        return self.subject_ids.shape[0]

    def _get_data(self, i):
        idx = i % len(self.subject_ids)
        # Resolve the .npy file path for the requested index.
        subject_id = self.subject_ids.iloc[idx, 0]
        label = int(self.subject_ids.iloc[idx, 1])
        data = get_data_and_label(self._npy_path, subject_id, self._is_2d, self.augment, self.permute)
        #print(data.shape)
        return data, np.asarray([label])

    def __getitem__(self, idx):
        data, label = self._get_data(idx)
        if data is None or label is None:
            data, label = self._get_data(0)
        data = hu_normalize(data)
        #data = normalize(hu_value_to_uint8(data))
        return torch.from_numpy(data.copy()), torch.tensor(label).type(torch.float32)
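
A sketch of how the dataset would typically be consumed; the paths below are placeholders:

# Hypothetical usage with a PyTorch DataLoader.
from torch.utils.data import DataLoader

dataset = SubjectDataset('/path/to/npy_data', '/path/to/train.csv',
                         is_train=True, is_2d=False, augment=True)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
for data, label in loader:
    print(data.shape, label.shape)
    break
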
import torch
import torch.nn as nn

class BCEFocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=0.5, reduction='elementwise_mean') -> None:
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, _input, target):
        pt = torch.sigmoid(_input)
        alpha = self.alpha
        # The original file is truncated here; completed with the standard binary focal loss.
        loss = (-alpha * (1 - pt) ** self.gamma * target * torch.log(pt)
                - (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt))
        if self.reduction == 'elementwise_mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss
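
A quick check of the loss on dummy logits (illustrative values only):

# Illustrative: focal loss on random logits against binary targets.
criterion = BCEFocalLoss(gamma=2, alpha=0.5)
logits = torch.randn(8, 1)
targets = torch.randint(0, 2, (8, 1)).float()
print(criterion(logits, targets))
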
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn


class Normalize(nn.Module):
    def __init__(self):
        super(Normalize, self).__init__()

    def forward(self, data, min_value=-1000, max_value=600):
        new_data = data
        new_data[new_data < min_value] = min_value
        new_data[new_data > max_value] = max_value
        # normalize to [-1, 1]
        new_data = 2.0 * (new_data - min_value) / (max_value - min_value) - 1
        return new_data
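
Because the clipping uses in-place boolean indexing, the module modifies the tensor it receives; pass a clone if the original values are still needed. An illustrative check:

# Illustrative: HU values clipped to [-1000, 600] then mapped to [-1, 1].
norm = Normalize()
hu = torch.tensor([[-1200.0, -1000.0, -200.0, 600.0, 800.0]])
print(norm(hu.clone()))   # tensor([[-1., -1., 0., 1., 1.]])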