# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#!/usr/bin/env python
import os

import torch
from setuptools import find_packages
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension
from torch.utils.cpp_extension import CppExtension
from torch.utils.cpp_extension import CUDAExtension
from torch.utils.cpp_extension import CUDA_HOME


def get_extensions():
    this_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    extensions_dir = os.path.join(this_dir, "csrc")

    # Default to a CPU-only build; switch to CUDAExtension below if CUDA is available.
    extension = CppExtension

    custom_ops_sources = [
        os.path.join(extensions_dir, "custom_ops", "custom_ops.cpp"),
        os.path.join(extensions_dir, "cpu", "nms_3d_cpu.cpp"),
    ]
    custom_ops_sources_cuda = [
        os.path.join(extensions_dir, "cuda", "nms_3d.cu"),
        os.path.join(extensions_dir, "cuda", "overlap_3d.cu"),
        os.path.join(extensions_dir, "cuda", "ROIAlign_3d_cuda.cu"),
        os.path.join(extensions_dir, "cuda", "deformable_conv_3d_cuda.cu"),
        os.path.join(extensions_dir, "cuda", "modulated_deformable_conv_3d_cuda.cu"),
    ]
    custom_ops_libraries = []

    extra_compile_args = {"cxx": []}
    define_macros = []

    if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1":
        extension = CUDAExtension
        custom_ops_sources += custom_ops_sources_cuda
        define_macros += [("WITH_CUDA", None)]
        extra_compile_args["nvcc"] = [
            # Enable half-precision (float16 / fp16) support in the CUDA kernels.
            "-DCUDA_HAS_FP16=1",
            # Use PyTorch's internal half operations instead of the ones
            # provided by the CUDA libraries.
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ]

    include_dirs = [extensions_dir]

    ext_modules = [
        extension(
            "custom_ops",
            sources=custom_ops_sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
            libraries=custom_ops_libraries,
        )
    ]

    return ext_modules


setup(
    name="custom_ops",
    version="0.2",
    author="gupc",
    url="http://git.do.proxima-ai.com/cn.aitrox.ai/grouplung",
    description="nodule detection in pytorch",
    packages=find_packages(exclude=("configs", "tests",)),
    # install_requires=requirements,
    ext_modules=get_extensions(),
    cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
)
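
# Usage sketch (an assumption about this repo's layout, not verified): the
# extension is typically built in-place with
#
#   python setup.py build_ext --inplace
#
# or installed with `pip install .`. Setting FORCE_CUDA=1 forces compilation
# of the CUDA sources even when no GPU is visible at build time (e.g. inside
# a Docker build). Once built, the compiled module can be imported next to
# torch:
#
#   import torch
#   import custom_ops  # module produced by this setup.py
#
#   # Hypothetical call; the actual op names and signatures depend on the
#   # bindings registered in csrc/custom_ops/custom_ops.cpp.
#   # keep = custom_ops.nms_3d(boxes, scores, iou_threshold)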