# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
"""Torch distributed utilities."""
from typing import Dict, Iterable, List

import torch
from torch import distributed as dist


def rank():
    """Return the rank of the current worker (0 if torch.distributed is not initialized)."""
    if dist.is_initialized():
        return dist.get_rank()
    else:
        return 0


def world_size():
    """Return the number of workers (1 if torch.distributed is not initialized)."""
    if dist.is_initialized():
        return dist.get_world_size()
    else:
        return 1


def is_distributed():
    """Return True when running with more than one worker."""
    return world_size() > 1


def all_reduce(tensor: torch.Tensor, op=dist.ReduceOp.SUM):
    """All-reduce `tensor` across workers; a no-op when not distributed."""
    if is_distributed():
        return dist.all_reduce(tensor, op)


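# Note: the helpers above let the same training code run with or without
# torch.distributed being initialized: rank() and world_size() fall back to 0
# and 1, and all_reduce() becomes a no-op. A minimal sketch of that pattern
# (the `loss` tensor below is hypothetical, not part of this module):
#
#     loss_sum = loss.detach().clone()
#     all_reduce(loss_sum)                 # sum across workers; no-op single-process
#     loss_avg = loss_sum / world_size()   # world_size() is 1 when not initialized

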
def _is_complex_or_float(tensor):
    return torch.is_floating_point(tensor) or torch.is_complex(tensor)


def _check_number_of_params(params: List[torch.Tensor]):
    # Utility function to check that the number of params is the same on all
    # workers, and thus avoid a deadlock in the distributed all-reduce.
    if not is_distributed() or not params:
        return
    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(tensor)
    if tensor.item() != len(params) * world_size():
        # If the workers do not all have the same number of params, this check
        # will fail on at least one of them.
        raise RuntimeError(
            f"Mismatch in number of params: ours is {len(params)}, "
            "at least one worker has a different one."
        )


def broadcast_tensors(tensors: Iterable[torch.Tensor], src: int = 0):
    """Broadcast the given tensors from worker `src` to all workers.
    This can be used to ensure that all workers have the same model to start with.
    """
    if not is_distributed():
        return
    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
    _check_number_of_params(tensors)
    handles = []
    for tensor in tensors:
        handle = dist.broadcast(tensor.data, src=src, async_op=True)
        handles.append(handle)
    for handle in handles:
        handle.wait()


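# Usage sketch for broadcast_tensors(): make every worker start from the same
# weights as rank 0 before training. `model` is a hypothetical torch.nn.Module,
# not part of this module; non-float buffers (e.g. integer counters) are
# filtered out by _is_complex_or_float() above.
#
#     broadcast_tensors(model.parameters())
#     broadcast_tensors(model.buffers())

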
def sync_buffer(buffers, average=True):
    """
    Sync buffers across workers. If average is False, broadcast from rank 0
    instead of averaging.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        if torch.is_floating_point(buffer.data):
            if average:
                handle = dist.all_reduce(
                    buffer.data, op=dist.ReduceOp.SUM, async_op=True
                )
            else:
                handle = dist.broadcast(buffer.data, src=0, async_op=True)
            handles.append((buffer, handle))
    for buffer, handle in handles:
        handle.wait()
        if average:
            buffer.data /= world_size()


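# Usage sketch for sync_buffer(): keep running statistics such as BatchNorm's
# running_mean/running_var consistent across workers. `model` is hypothetical.
#
#     sync_buffer(model.buffers(), average=True)   # average float buffers
#     sync_buffer(model.buffers(), average=False)  # or copy rank 0's values

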
def sync_grad(params):
    """
    A simpler alternative to DistributedDataParallel that doesn't rely
    on any black magic. For simple models it can also be as fast.
    Just call this on your model parameters after the call to backward!
    """
    if not is_distributed():
        return
    handles = []
    for p in params:
        if p.grad is not None:
            handle = dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM, async_op=True)
            handles.append((p, handle))
    for p, handle in handles:
        handle.wait()
        p.grad.data /= world_size()


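# Usage sketch for sync_grad(): a manual data-parallel training step without
# DistributedDataParallel. `model`, `optimizer`, `loss_fn`, `batch` and `target`
# are hypothetical; each worker is assumed to process its own shard of the data.
#
#     optimizer.zero_grad()
#     loss = loss_fn(model(batch), target)
#     loss.backward()
#     sync_grad(model.parameters())   # average gradients across workers
#     optimizer.step()

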
def average_metrics(metrics: Dict[str, float], count=1.0):
    """Average a dictionary of metrics across all workers, using the optional
    `count` as unnormalized weight.
    """
    if not is_distributed():
        return metrics
    keys, values = zip(*metrics.items())
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
    tensor *= count
    all_reduce(tensor)
    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
    return dict(zip(keys, averaged))


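# Usage sketch for average_metrics(): report metrics averaged over all workers,
# weighting each worker by the number of samples it actually processed.
# `epoch_loss`, `correct` and `num_samples` are hypothetical local statistics.
#
#     local = {"loss": epoch_loss / num_samples, "acc": correct / num_samples}
#     metrics = average_metrics(local, count=num_samples)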