"""Sanity-check that PyTorch can see at least one CUDA device.

When run as a script, exits with a non-zero status if CUDA is unavailable;
otherwise prints the number of GPUs visible to PyTorch.
"""
import torch


def main() -> int:
    """Verify CUDA availability and report the visible GPU count.

    Returns:
        The number of CUDA devices reported by ``torch.cuda.device_count()``.

    Raises:
        SystemExit: if ``torch.cuda.is_available()`` returns False. An
            explicit exit is used instead of ``assert`` so the check is not
            stripped when Python runs with ``-O``.
    """
    if not torch.cuda.is_available():
        raise SystemExit(
            "CUDA is not available: torch.cuda.is_available() returned False"
        )
    num_gpus = torch.cuda.device_count()
    print('num_gpus in check: ', num_gpus)
    return num_gpus


if __name__ == '__main__':
    main()