From 7a0c7b747861249776923d3a0a53f1d4b428df5b Mon Sep 17 00:00:00 2001 From: czl66 <1479822106@qq.com> Date: Tue, 24 Dec 2024 11:59:01 +0800 Subject: [PATCH] Modified aishell/ASR/conformer_ctc/train.py, which implemented multi-machine DDP. --- egs/aishell/ASR/conformer_ctc/train.py | 33 ++++++++++++-------------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py index c2cbe6e3b..df52cffca 100755 --- a/egs/aishell/ASR/conformer_ctc/train.py +++ b/egs/aishell/ASR/conformer_ctc/train.py @@ -22,9 +22,9 @@ from pathlib import Path from shutil import copyfile from typing import Optional, Tuple +import os import k2 import torch -import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import AishellAsrDataModule from conformer import Conformer @@ -543,13 +543,9 @@ def train_one_epoch( params.best_train_loss = params.train_loss -def run(rank, world_size, args): +def run(world_size, args): """ Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. world_size: Number of GPUs for DDP training. args: @@ -560,13 +556,14 @@ def run(rank, world_size, args): fix_random_seed(params.seed) if world_size > 1: - setup_dist(rank, world_size, params.master_port) - + setup_dist(use_ddp_launch=True, master_addr=params.master_port) setup_logger(f"{params.exp_dir}/log/log-train") logging.info("Training started") - logging.info(params) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + if local_rank == 0: + logging.info(params) - if args.tensorboard and rank == 0: + if args.tensorboard and local_rank == 0: tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") else: tb_writer = None @@ -577,7 +574,7 @@ def run(rank, world_size, args): device = torch.device("cpu") if torch.cuda.is_available(): - device = torch.device("cuda", rank) + device = torch.device("cuda", local_rank) graph_compiler = CharCtcTrainingGraphCompiler( lexicon=lexicon, @@ -603,7 +600,8 @@ def run(rank, world_size, args): model.to(device) if world_size > 1: - model = DDP(model, device_ids=[rank]) + torch.distributed.barrier() # Ensure all processes have the same model parameters + model = DDP(model, device_ids=[local_rank]) optimizer = Noam( model.parameters(), @@ -629,7 +627,7 @@ def run(rank, world_size, args): tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - if rank == 0: + if local_rank == 0: logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) params.cur_epoch = epoch @@ -644,12 +642,14 @@ def run(rank, world_size, args): tb_writer=tb_writer, world_size=world_size, ) + if world_size > 1: + torch.distributed.barrier() save_checkpoint( params=params, model=model, optimizer=optimizer, - rank=rank, + rank=local_rank, ) logging.info("Done!") @@ -668,10 +668,7 @@ def main(): world_size = args.world_size assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) + run(world_size=world_size, args=args) torch.set_num_threads(1)