'''
Distributed training utilities: process-group initialization (direct and
SLURM-based), parameter broadcasting, and gradient/variable averaging.

@Author: Jiamin Ren
@Date: 2020-04-29 09:56:57
'''
import os
import torch
import torch.multiprocessing as mp
import torch.distributed as dist

__all__ = [
    'init_dist', 'init_dist_slurm', 'broadcast_params', 'average_gradients',
    'average_variable'
]

def init_dist(backend='nccl',
              master_ip='127.0.0.1',
              port=29500):
    """Initialize torch.distributed from launcher-set env vars
    (RANK, WORLD_SIZE, LOCAL_RANK)."""
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')

    os.environ['MASTER_ADDR'] = master_ip
    os.environ['MASTER_PORT'] = str(port)
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    num_gpus = torch.cuda.device_count()
    # Bind this process to one GPU on its node.
    local_rank = int(os.environ['LOCAL_RANK'])
    device_id = local_rank % num_gpus
    torch.cuda.set_device(device_id)

    print(f'dist settings: local_rank {local_rank}, rank {rank}, '
          f'world_size {world_size}, gpus {num_gpus}, device_id {device_id}')

    dist.init_process_group(backend=backend)
    return rank, world_size
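
# Usage sketch (an assumption, not part of the original file): with a
# torchrun-style launcher, RANK / WORLD_SIZE / LOCAL_RANK are already set in
# each process's environment, so the call reduces to:
#
#   torchrun --nproc_per_node=4 train.py
#   # inside a hypothetical train.py:
#   rank, world_size = init_dist(backend='nccl', port=29500)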

def init_dist_slurm(backend='nccl', port=29500):
    """Initialize torch.distributed from SLURM env vars
    (SLURM_PROCID, SLURM_NTASKS, SLURM_NODELIST)."""
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)

    # Keep only the first hostname from a compressed node list such as
    # 'prefix-10-5-30-[62,63]' by cutting at the first '-' or ',' that
    # follows the '['.
    if '[' in node_list:
        beg = node_list.find('[')
        pos1 = node_list.find('-', beg)
        if pos1 < 0:
            pos1 = len(node_list)
        pos2 = node_list.find(',', beg)
        if pos2 < 0:
            pos2 = len(node_list)
        node_list = node_list[:min(pos1, pos2)].replace('[', '')
    # Assumes an 8-character site prefix followed by a dash-separated IP,
    # e.g. 'SH-IDC1-10-5-30-62' -> '10.5.30.62'.
    addr = node_list[8:].replace('-', '.')

    os.environ['MASTER_PORT'] = str(port)
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    return rank, world_size
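
# Usage sketch under SLURM (an assumption, not part of the original file):
# every task launched by srun calls init_dist_slurm() itself, e.g.
#
#   srun -N2 --ntasks-per-node=8 --gres=gpu:8 python train.py
#   # inside a hypothetical train.py:
#   rank, world_size = init_dist_slurm(port=29500)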
    
def average_gradients(model):
    """All-reduce gradients across all processes and divide by world size."""
    size = float(dist.get_world_size())
    for name, param in model.named_parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= size
        else:
            print(f'param {name} grad is None')
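
# Typical placement in a training step (a sketch with hypothetical `loss`,
# `model`, and `optimizer` objects; not from the original file):
#
#   loss.backward()            # compute local gradients
#   average_gradients(model)   # all-reduce, then divide by world size
#   optimizer.step()           # every rank applies the same update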

def broadcast_params(model):
    """Broadcast all parameters and buffers from rank 0 to every process."""
    for p in model.state_dict().values():
        dist.broadcast(p, 0)
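
# Usage note (an assumption): call broadcast_params(model) once right after
# model construction so every rank starts from rank 0's exact weights:
#
#   model = build_model().cuda()   # hypothetical constructor
#   broadcast_params(model)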

def average_variable(var):
    """All-reduce a tensor in place and return its average over all ranks."""
    size = float(dist.get_world_size())
    dist.all_reduce(var, op=dist.ReduceOp.SUM)
    var /= size
    return var
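
# Example (a sketch; `loss` is a hypothetical floating-point CUDA scalar):
# average a per-rank metric so every process logs the same value:
#
#   loss_avg = average_variable(loss.detach().clone())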