'''
@Author: Jiamin Ren
@Date: 2020-04-29 09:56:57
'''
import os
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
__all__ = ['init_dist', 'init_dist_slurm', 'broadcast_params', 'average_gradients']
def init_dist(backend='nccl', master_ip='127.0.0.1', port=29500):
    """Initialize distributed training from the RANK, WORLD_SIZE and LOCAL_RANK
    environment variables (e.g. as set by torch.distributed.launch / torchrun)."""
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    os.environ['MASTER_ADDR'] = master_ip
    os.environ['MASTER_PORT'] = str(port)
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    num_gpus = torch.cuda.device_count()
    local_rank = int(os.environ['LOCAL_RANK'])
    device_id = local_rank % num_gpus
    torch.cuda.set_device(device_id)
    print(f'dist settings: local_rank {local_rank}, rank {rank}, '
          f'world_size {world_size}, gpus {num_gpus}, device_id {device_id}')
    dist.init_process_group(backend=backend)
    return rank, world_size
def init_dist_slurm(backend='nccl', port=29500):
    """Initialize distributed training under SLURM, deriving rank, world size
    and the master address from the SLURM_* environment variables."""
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)  # override any previously set start method
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    # Reduce a compressed node list such as 'node[1-4,6]' to its first entry.
    if '[' in node_list:
        beg = node_list.find('[')
        pos1 = node_list.find('-', beg)
        if pos1 < 0:
            pos1 = 1000
        pos2 = node_list.find(',', beg)
        if pos2 < 0:
            pos2 = 1000
        node_list = node_list[:min(pos1, pos2)].replace('[', '')
    # NOTE: stripping the first 8 characters and mapping '-' to '.' assumes
    # cluster-specific node names that embed the IP address (e.g. 'SH-IDC1-10-5-30-62').
    addr = node_list[8:].replace('-', '.')
    os.environ['MASTER_PORT'] = str(port)
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    return rank, world_size
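# Usage sketch (assumption, not part of the original file): under SLURM this module
# is typically launched with something like
#     srun -p <partition> -n8 --gres=gpu:8 --ntasks-per-node=8 python train.py
# after which init_dist_slurm() derives RANK / WORLD_SIZE / MASTER_ADDR from the
# SLURM_* environment variables it reads above.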
def average_gradients(model):
    """Average gradients across all workers: all-reduce (sum), then divide by world size."""
    size = float(dist.get_world_size())
    for name, param in model.named_parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= size
        else:
            print(f'param {name} grad is None')
def broadcast_params(model):
    """Broadcast model parameters and buffers from rank 0 to all other workers."""
    for p in model.state_dict().values():
        dist.broadcast(p, 0)
def average_variable(var):
    """Average a tensor across all workers in place and return it."""
    size = float(dist.get_world_size())
    dist.all_reduce(var, op=dist.ReduceOp.SUM)
    var /= size
    return var
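# Minimal training-loop sketch (assumption; MyModel, loader and optimizer are
# hypothetical placeholders, not defined in this repo). It shows how these helpers
# are typically combined once RANK / WORLD_SIZE / LOCAL_RANK are set by the launcher:
#
#     rank, world_size = init_dist(backend='nccl', port=29500)
#     model = MyModel().cuda()
#     broadcast_params(model)              # sync initial weights from rank 0
#     for inputs, targets in loader:
#         loss = model(inputs, targets)
#         loss.backward()
#         average_gradients(model)         # all-reduce and average gradients
#         optimizer.step()
#         optimizer.zero_grad()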