From 774f6643cd661f853b9c86d112ea3ecd0d27cf81 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 23 Nov 2021 11:58:25 +0800
Subject: [PATCH] Start training LM for LibriSpeech.

---
 egs/librispeech/ASR/prepare.sh |  3 ---
 icefall/dist.py                | 46 ++++++++++++++++++++++++++++------
 icefall/utils.py               | 16 +++++++++---
 3 files changed, 51 insertions(+), 14 deletions(-)

diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index 50aa53ca2..76359a52c 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -42,8 +42,6 @@ dl_dir=$PWD/download
 #   data/lang_bpe_yyy if the array contains xxx, yyy
 vocab_sizes=(
   5000
-  2000
-  1000
   500
 )
 
@@ -288,7 +286,6 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
       cat $f | cut -d " " -f 2-
     done > $out_dir/test.txt
   fi
-  exit 0
 
   lang_dir=data/lang_bpe_${vocab_size}
   ./local/prepare_lm_training_data.py \
diff --git a/icefall/dist.py b/icefall/dist.py
index 203c7c563..6334f9c13 100644
--- a/icefall/dist.py
+++ b/icefall/dist.py
@@ -21,14 +21,46 @@
 import torch
 from torch import distributed as dist
 
-def setup_dist(rank, world_size, master_port=None):
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = (
-        "12354" if master_port is None else str(master_port)
-    )
-    dist.init_process_group("nccl", rank=rank, world_size=world_size)
-    torch.cuda.set_device(rank)
+def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False):
+    """
+    rank and world_size are used only if use_ddp_launch is False.
+    """
+    if "MASTER_ADDR" not in os.environ:
+        os.environ["MASTER_ADDR"] = "localhost"
+
+    if "MASTER_PORT" not in os.environ:
+        os.environ["MASTER_PORT"] = (
+            "12354" if master_port is None else str(master_port)
+        )
+
+    if use_ddp_launch is False:
+        dist.init_process_group("nccl", rank=rank, world_size=world_size)
+        torch.cuda.set_device(rank)
+    else:
+        dist.init_process_group("nccl")
 
 
 def cleanup_dist():
     dist.destroy_process_group()
+
+
+def get_world_size():
+    if "WORLD_SIZE" in os.environ:
+        return int(os.environ["WORLD_SIZE"])
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_world_size()
+    else:
+        return 1
+
+
+def get_rank():
+    if "RANK" in os.environ:
+        return int(os.environ["RANK"])
+    elif dist.is_available() and dist.is_initialized():
+        return dist.get_rank()
+    else:
+        return 0
+
+
+def get_local_rank():
+    return int(os.environ.get("LOCAL_RANK", 0))
diff --git a/icefall/utils.py b/icefall/utils.py
index 1d4aabd72..3f1547bee 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -91,7 +91,11 @@ def str2bool(v):
 
 
 def setup_logger(
-    log_filename: Pathlike, log_level: str = "info", use_console: bool = True
+    log_filename: Pathlike,
+    log_level: str = "info",
+    rank: int = 0,
+    world_size: int = 1,
+    use_console: bool = True,
 ) -> None:
     """Setup log level.
 
@@ -101,12 +105,16 @@
       log_level:
         The log level to use, e.g., "debug", "info", "warning", "error",
         "critical"
+      rank:
+        Rank of this process in DDP training.
+      world_size:
+        Number of processes in DDP training.
+      use_console:
+        True to also print logs to the console.
     """
     now = datetime.now()
     date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
-    if dist.is_available() and dist.is_initialized():
-        world_size = dist.get_world_size()
-        rank = dist.get_rank()
+    if world_size > 1:
         formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
         log_filename = f"{log_filename}-{date_time}-{rank}"
     else:
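
Note (not part of the patch): the sketch below shows how the new helpers
are meant to fit together when a script is started via
`python -m torch.distributed.launch`, which exports RANK, WORLD_SIZE, and
(with --use_env) LOCAL_RANK. The file name train.py and the body of
main() are hypothetical; only setup_dist, cleanup_dist, get_rank,
get_world_size, get_local_rank, and setup_logger come from this patch.

    # train.py (hypothetical): minimal entry point for the
    # torch.distributed.launch path.
    import torch

    from icefall.dist import (
        cleanup_dist,
        get_local_rank,
        get_rank,
        get_world_size,
        setup_dist,
    )
    from icefall.utils import setup_logger


    def main():
        rank = get_rank()
        world_size = get_world_size()

        # With use_ddp_launch=True, init_process_group() reads
        # RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT from the environment,
        # and setup_dist() does not call torch.cuda.set_device(), so the
        # caller selects the GPU via LOCAL_RANK.
        setup_dist(rank, world_size, use_ddp_launch=True)
        torch.cuda.set_device(get_local_rank())

        # rank/world_size are passed explicitly, so each rank logs to
        # its own file, e.g. log-train-<date>-<rank>.
        setup_logger("exp/log/log-train", rank=rank, world_size=world_size)

        # ... build the model, wrap it in DDP, run the training loop ...

        cleanup_dist()


    if __name__ == "__main__":
        main()

launched as, for example:

    python -m torch.distributed.launch --use_env --nproc_per_node=4 train.py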
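
A second sketch, also an assumption rather than part of the patch, for
the spawn-based path: setup_logger no longer queries torch.distributed
for the rank, so the caller passes rank/world_size explicitly and may
call it before init_process_group. run() and the world_size of 2 are
hypothetical.

    import torch.multiprocessing as mp

    from icefall.dist import cleanup_dist, setup_dist
    from icefall.utils import setup_logger


    def run(rank: int, world_size: int):
        # use_ddp_launch defaults to False: rank/world_size are passed
        # to init_process_group() and torch.cuda.set_device(rank) is
        # called inside setup_dist().
        setup_dist(rank, world_size, master_port=12354)
        setup_logger("exp/log/log-train", rank=rank, world_size=world_size)
        # ... training loop ...
        cleanup_dist()


    if __name__ == "__main__":
        world_size = 2
        # mp.spawn invokes run(i, world_size) for i in [0, world_size).
        mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)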