mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
18 lines
418 B
Python
18 lines
418 B
Python
import os
|
|
|
|
import torch
|
|
from torch import distributed as dist
|
|
|
|
|
|
def setup_dist(rank, world_size, master_port=None):
    """Initialize the NCCL process group for this process.

    Sets the ``MASTER_ADDR`` and ``MASTER_PORT`` environment variables,
    calls ``dist.init_process_group`` with the "nccl" backend, and then
    selects the CUDA device whose index equals ``rank``.

    Args:
      rank:
        Rank of the current process; also passed to
        ``torch.cuda.set_device`` as the device index.
      world_size:
        Number of processes taking part in the group.
      master_port:
        Port written to ``MASTER_PORT``. When it is None, "12354" is used;
        any other value is converted with ``str()``.
    """
    # Both variables are overwritten unconditionally, so values already
    # present in the environment are replaced.
    os.environ["MASTER_ADDR"] = "localhost"

    port = "12354"
    if master_port is not None:
        port = str(master_port)
    os.environ["MASTER_PORT"] = port

    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
|
|
|
|
|
|
def cleanup_dist():
    """Destroy the process group by calling ``dist.destroy_process_group``."""
    dist.destroy_process_group()
|