From c688161f44057db6e38bc98b91bc849cb70ac40c Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 30 Sep 2021 01:08:09 +0800 Subject: [PATCH] WIP: Support multi-node multi-GPU training. --- egs/librispeech/ASR/conformer_ctc/decode.py | 2 +- egs/librispeech/ASR/conformer_ctc/train.py | 65 +++++++++++++++++---- egs/librispeech/ASR/prepare.sh | 3 +- icefall/dist.py | 46 ++++++++++++--- 4 files changed, 96 insertions(+), 20 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 5a83dd39c..1b5da7cf3 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -383,7 +383,7 @@ def decode_one_batch( ans[lm_scale_str] = hyps else: for lm_scale in lm_scale_list: - ans[lm_scale_str] = [[] * lattice.shape[0]] + ans[f"{lm_scale}"] = [[] * lattice.shape[0]] return ans diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 80b2d924a..f7ac47076 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -39,7 +39,13 @@ from transformer import Noam from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist +from icefall.dist import ( + cleanup_dist, + get_local_rank, + get_rank, + get_world_size, + setup_dist, +) from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -54,6 +60,13 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter ) + parser.add_argument( + "--use-multi-node", + type=str2bool, + default=False, + help="True if using multi-node multi-GPU.", + ) + parser.add_argument( "--world-size", type=int, @@ -92,6 +105,23 @@ def get_parser(): """, ) + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved. + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe", + help="""It contains language related input files such as lexicon.txt + """, + ) + return parser @@ -106,12 +136,6 @@ def get_params() -> AttributeDict: Explanation of options saved in `params`: - - exp_dir: It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - - - lang_dir: It contains language related input files such as - "lexicon.txt" - - best_train_loss: Best training loss so far. It is used to select the model that has the lowest training loss. It is updated during the training. 
@@ -621,9 +645,17 @@ def run(rank, world_size, args): params = get_params() params.update(vars(args)) + if args.use_multi_node: + local_rank = get_local_rank() + else: + local_rank = rank + logging.info( + f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}" + ) + fix_random_seed(42) if world_size > 1: - setup_dist(rank, world_size, params.master_port) + setup_dist(rank, world_size, params.master_port, args.use_multi_node) setup_logger(f"{params.exp_dir}/log/log-train") logging.info("Training started") @@ -640,7 +672,8 @@ def run(rank, world_size, args): device = torch.device("cpu") if torch.cuda.is_available(): - device = torch.device("cuda", rank) + device = torch.device("cuda", local_rank) + logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}") graph_compiler = BpeCtcTrainingGraphCompiler( params.lang_dir, @@ -665,7 +698,7 @@ def run(rank, world_size, args): model.to(device) if world_size > 1: - model = DDP(model, device_ids=[rank]) + model = DDP(model, device_ids=[local_rank]) optimizer = Noam( model.parameters(), @@ -726,9 +759,21 @@ def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + if args.use_multi_node: + # for multi-node multi-GPU training + rank = get_rank() + world_size = get_world_size() + args.world_size = world_size + print(f"rank: {rank}, world_size: {world_size}") + run(rank=rank, world_size=world_size, args=args) + return world_size = args.world_size assert world_size >= 1 + if world_size > 1: mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) else: diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index f06e013f6..dd3f1085a 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -41,6 +41,7 @@ dl_dir=$PWD/download # data/lang_bpe_yyy if the array contains xxx, yyy vocab_sizes=( 5000 + 500 ) # All files generated by this script are saved in "data". @@ -190,5 +191,3 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then ./local/compile_hlg.py --lang-dir $lang_dir done fi - -cd data && ln -sfv lang_bpe_5000 lang_bpe diff --git a/icefall/dist.py b/icefall/dist.py index 203c7c563..9dc903177 100644 --- a/icefall/dist.py +++ b/icefall/dist.py @@ -21,14 +21,46 @@ import torch from torch import distributed as dist -def setup_dist(rank, world_size, master_port=None): - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = ( - "12354" if master_port is None else str(master_port) - ) - dist.init_process_group("nccl", rank=rank, world_size=world_size) - torch.cuda.set_device(rank) +def setup_dist(rank, world_size, master_port=None, is_multi_node=False): + """ + rank and world_size are used only if is_multi_node is False. 
+    """
+    if "MASTER_ADDR" not in os.environ:
+        os.environ["MASTER_ADDR"] = "localhost"
+
+    if "MASTER_PORT" not in os.environ:
+        os.environ["MASTER_PORT"] = (
+            "12354" if master_port is None else str(master_port)
+        )
+
+    if is_multi_node is False:
+        dist.init_process_group("nccl", rank=rank, world_size=world_size)
+        torch.cuda.set_device(rank)
+    else:
+        dist.init_process_group("nccl")
 
 
 def cleanup_dist():
     dist.destroy_process_group()
+
+
+def get_world_size():
+    if "WORLD_SIZE" in os.environ:
+        return int(os.environ["WORLD_SIZE"])
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_world_size()
+    else:
+        return 1
+
+
+def get_rank():
+    if "RANK" in os.environ:
+        return int(os.environ["RANK"])
+    elif dist.is_available() and dist.is_initialized():
+        return dist.get_rank()
+    else:
+        return 0
+
+
+def get_local_rank():
+    return int(os.environ.get("LOCAL_RANK", 0))
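
How the multi-node path can be launched (a sketch, not part of the patch, assuming torchrun, i.e. torch.distributed.run, is the launcher; the host, port, node count and GPU count below are placeholders): torchrun starts one process per GPU and exports RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR and MASTER_PORT, which get_rank(), get_world_size(), get_local_rank() and setup_dist() above read.

  # On node 0 (2 nodes x 4 GPUs, so WORLD_SIZE becomes 8)
  torchrun --nnodes=2 --nproc_per_node=4 --node_rank=0 \
    --master_addr=10.0.0.1 --master_port=12354 \
    ./conformer_ctc/train.py --use-multi-node true

  # On node 1, the same command with --node_rank=1
  torchrun --nnodes=2 --nproc_per_node=4 --node_rank=1 \
    --master_addr=10.0.0.1 --master_port=12354 \
    ./conformer_ctc/train.py --use-multi-node true

Single-node training is unchanged: running ./conformer_ctc/train.py --world-size N without --use-multi-node still uses mp.spawn to start one process per GPU on the local machine.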