Initial commit and v1.0
demucs/distrib.py (Normal file, 100 lines)
@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Distributed training utilities."""
import logging
import pickle

import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Subset
from torch.nn.parallel.distributed import DistributedDataParallel

from dora import distrib as dora_distrib

logger = logging.getLogger(__name__)
rank = 0
world_size = 1


def init():
    """Initialize distributed training through Dora and record the module-level `rank` and `world_size`."""
    global rank, world_size
    if not torch.distributed.is_initialized():
        dora_distrib.init()
    rank = dora_distrib.rank()
    world_size = dora_distrib.world_size()
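

# A minimal usage sketch, assuming the script is launched through Dora or another
# torch.distributed launcher and that `main()` is a hypothetical entry point.
# `init()` must run before the other helpers, which all read the module-level
# `rank` and `world_size`:
#
#     def main():
#         init()
#         if rank == 0:
#             logger.info("Training with %d workers", world_size)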


def average(metrics, count=1.):
    """Average `metrics` (a dict or sequence of numbers) over all workers, weighted by `count`."""
    if isinstance(metrics, dict):
        keys, values = zip(*sorted(metrics.items()))
        values = average(values, count)
        return dict(zip(keys, values))
    if world_size == 1:
        return metrics
    tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
    tensor *= count
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()
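

# A sketch of how `average` might be called at the end of a validation epoch
# (`valid_loss` and `num_batches` are hypothetical local values). Each worker
# contributes `count * value` together with `count` itself; the all_reduce sums
# both, so the final division yields a mean weighted by each worker's `count`:
#
#     metrics = {'loss': valid_loss}
#     metrics = average(metrics, count=num_batches)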


def wrap(model):
    """Wrap `model` in DistributedDataParallel when running with more than one worker."""
    if world_size == 1:
        return model
    else:
        return DistributedDataParallel(
            model,
            # find_unused_parameters=True,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device())
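

# A sketch of wrapping a model for multi-GPU training (`MyModel` is a hypothetical
# module). When distributed, the original module stays reachable as `dmodel.module`;
# in the single-process case `wrap` simply returns the model unchanged:
#
#     model = MyModel().cuda()
#     dmodel = wrap(model)
#     # train with dmodel; save or evaluate with the original `model`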


def barrier():
    """Block until all workers reach this point (no-op on a single process)."""
    if world_size > 1:
        torch.distributed.barrier()
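

# A sketch of a typical barrier pattern ("checkpoint.th" and `state` are
# hypothetical): rank 0 writes a file and every worker waits before reading it.
#
#     if rank == 0:
#         torch.save(state, "checkpoint.th")
#     barrier()
#     state = torch.load("checkpoint.th")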


def share(obj=None, src=0):
    """Broadcast a picklable `obj` from rank `src` to every worker and return it on all ranks."""
    if world_size == 1:
        return obj
    size = torch.empty(1, device='cuda', dtype=torch.long)
    if rank == src:
        dump = pickle.dumps(obj)
        size[0] = len(dump)
    torch.distributed.broadcast(size, src=src)
    # `size` is now set to the length of the pickled obj in all processes.

    if rank == src:
        buffer = torch.from_numpy(np.frombuffer(dump, dtype=np.uint8).copy()).cuda()
    else:
        buffer = torch.empty(size[0].item(), device='cuda', dtype=torch.uint8)
    torch.distributed.broadcast(buffer, src=src)
    # `buffer` now contains the pickled obj in all processes.

    if rank != src:
        obj = pickle.loads(buffer.cpu().numpy().tobytes())
    logger.debug(f"Shared object of size {len(buffer)}")
    return obj
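

# A sketch of sharing a Python object computed on a single worker, e.g. a track
# list built by rank 0 only (`build_track_list` is a hypothetical helper). Every
# rank receives the same unpickled copy:
#
#     tracks = None
#     if rank == 0:
#         tracks = build_track_list()
#     tracks = share(tracks, src=0)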


def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """
    Create a dataloader that behaves correctly under distributed training.
    If a gradient is going to be computed (i.e. for training), you must set `shuffle=True`.
    """
    if world_size == 1:
        return klass(dataset, *args, shuffle=shuffle, **kwargs)

    if shuffle:
        # Training (a backward pass will be computed): use DistributedSampler.
        sampler = DistributedSampler(dataset)
        # We ignore the `shuffle` flag here, as DistributedSampler already shuffles.
        return klass(dataset, *args, **kwargs, sampler=sampler)
    else:
        # Shard manually, as DistributedSampler would otherwise replicate some examples.
        dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
        return klass(dataset, *args, shuffle=shuffle, **kwargs)
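

# A sketch of how `loader` might be used for training and validation (`train_set`,
# `valid_set` and the batch sizes are hypothetical). With `shuffle=True` the data is
# sharded and shuffled by a DistributedSampler; when running distributed, the returned
# loader exposes it as `.sampler`, so `train_loader.sampler.set_epoch(epoch)` can be
# called each epoch for a fresh shuffle order. With `shuffle=False` the dataset is
# sharded deterministically with `Subset`, so every example is seen exactly once
# across workers:
#
#     train_loader = loader(train_set, batch_size=8, shuffle=True, num_workers=4)
#     valid_loader = loader(valid_set, batch_size=1, shuffle=False)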