import os

import pynvml

# Initialise NVML once at import time so it stays available for the lifetime
# of the process (NVML initialisation is reference-counted).
pynvml.nvmlInit()


def set_gpu(num_gpu, used_percent=0.7):
    """Pick lightly-loaded GPUs and expose them through CUDA_VISIBLE_DEVICES.

    Every GPU reported by NVML is inspected; a GPU counts as available when
    its used-memory fraction (``used / total``) is strictly below
    ``used_percent``. The first ``num_gpu`` available indices (in ascending
    NVML index order) are joined with commas and written to
    ``os.environ["CUDA_VISIBLE_DEVICES"]``.

    Args:
        num_gpu: Number of GPUs to reserve. Must not be negative. With ``0``
            the environment variable is set to an empty string.
        used_percent: Memory-usage threshold expressed as a fraction in
            ``[0, 1]`` (despite the name it is not a 0-100 percentage).

    Raises:
        ValueError: If ``num_gpu`` is negative, or if fewer than ``num_gpu``
            GPUs are below the usage threshold.

    Note:
        The function returns nothing; its effects are the environment
        variable and two progress lines printed to stdout. The variable only
        influences CUDA if it is set before the CUDA runtime is initialised
        in this process.
    """
    # A negative count would pass the "enough GPUs" check below and then make
    # the slice silently drop GPUs from the end, so reject it explicitly.
    if num_gpu < 0:
        raise ValueError("num_gpu must be non-negative, got %r" % (num_gpu,))

    # Pair this init with a shutdown so repeated calls do not keep bumping
    # NVML's reference count; the module-level init keeps NVML alive.
    pynvml.nvmlInit()
    try:
        device_count = pynvml.nvmlDeviceGetCount()
        print("Found %d GPU(s)" % device_count)

        available_gpus = []
        for index in range(device_count):
            handle = pynvml.nvmlDeviceGetHandleByIndex(index)
            meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
            used = meminfo.used / meminfo.total
            if used < used_percent:
                available_gpus.append(index)
    finally:
        pynvml.nvmlShutdown()

    # NOTE: on host 10.100.37.5 the "6th gpu" was reported to throw a CUDA
    # illegal-address exception; a former workaround removed index 0 from
    # available_gpus here. The workaround is disabled - confirm whether that
    # device still needs to be excluded.
    # NOTE(review): NVML indices and CUDA ordinals may be enumerated in a
    # different order unless CUDA_DEVICE_ORDER=PCI_BUS_ID - verify on
    # heterogeneous machines.

    if len(available_gpus) >= num_gpu:
        gpus = ','.join(str(e) for e in available_gpus[:num_gpu])
        os.environ["CUDA_VISIBLE_DEVICES"] = gpus
        print("Using GPU %s" % gpus)
    else:
        raise ValueError("No GPUs available, current number of available GPU is %d, requested for %d GPU(s)" % (
            len(available_gpus), num_gpu))