Modified aishell/ASR/conformer_ctc/train.py to implement multi-machine DDP.

czl66 2024-12-24 11:59:01 +08:00
parent 3e4da5f781
commit 7a0c7b7478


@@ -22,9 +22,9 @@ from pathlib import Path
 from shutil import copyfile
 from typing import Optional, Tuple

+import os
 import k2
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import AishellAsrDataModule
 from conformer import Conformer
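
The import changes track the new launch model: torch.multiprocessing goes away because worker processes are now created by an external launcher rather than by mp.spawn(), and os comes in so the script can read its rank from the environment. As a rough sketch (not part of the commit), a torchrun-style launcher exports these variables for every worker it starts:

import os

# Sketch only: environment variables set by torchrun for each worker process.
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # GPU index on this node
global_rank = int(os.environ.get("RANK", 0))       # rank across all nodes
world_size = int(os.environ.get("WORLD_SIZE", 1))  # total number of workers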
@@ -543,13 +543,9 @@ def train_one_epoch(
         params.best_train_loss = params.train_loss


-def run(rank, world_size, args):
+def run(world_size, args):
     """
     Args:
-      rank:
-        It is a value between 0 and `world_size-1`, which is
-        passed automatically by `mp.spawn()` in :func:`main`.
-        The node with rank 0 is responsible for saving checkpoint.
       world_size:
         Number of GPUs for DDP training.
       args:
@@ -560,13 +556,14 @@ def run(rank, world_size, args):
     fix_random_seed(params.seed)

     if world_size > 1:
-        setup_dist(rank, world_size, params.master_port)
+        setup_dist(use_ddp_launch=True, master_addr=params.master_port)

     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
-    logging.info(params)
-
-    if args.tensorboard and rank == 0:
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    if local_rank == 0:
+        logging.info(params)
+    if args.tensorboard and local_rank == 0:
         tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
     else:
         tb_writer = None
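
The call setup_dist(use_ddp_launch=True, ...) hands rendezvous over to the launcher instead of wiring up rank and world_size by hand. setup_dist itself lives in icefall/dist.py and is not shown in this diff; the sketch below only illustrates the assumed behavior of such an env-driven setup, not icefall's actual implementation:

import os
import torch

def setup_dist_sketch() -> None:
    # Assumed behavior: with torchrun, MASTER_ADDR, MASTER_PORT, RANK and
    # WORLD_SIZE are already in the environment, so init_process_group()
    # needs no explicit rank or world_size arguments.
    torch.distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))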
@@ -577,7 +574,7 @@ def run(rank, world_size, args):

     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", rank)
+        device = torch.device("cuda", local_rank)

     graph_compiler = CharCtcTrainingGraphCompiler(
         lexicon=lexicon,
@@ -603,7 +600,8 @@ def run(rank, world_size, args):

     model.to(device)
     if world_size > 1:
-        model = DDP(model, device_ids=[rank])
+        torch.distributed.barrier()  # Ensure all processes have the same model parameters
+        model = DDP(model, device_ids=[local_rank])

     optimizer = Noam(
         model.parameters(),
@@ -629,7 +627,7 @@ def run(rank, world_size, args):
             tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

-        if rank == 0:
+        if local_rank == 0:
             logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))

         params.cur_epoch = epoch
@@ -644,12 +642,14 @@ def run(rank, world_size, args):
             tb_writer=tb_writer,
             world_size=world_size,
         )

+        if world_size > 1:
+            torch.distributed.barrier()
         save_checkpoint(
             params=params,
             model=model,
             optimizer=optimizer,
-            rank=rank,
+            rank=local_rank,
         )

     logging.info("Done!")
@@ -668,10 +668,7 @@ def main():
     world_size = args.world_size
     assert world_size >= 1

-    if world_size > 1:
-        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
-    else:
-        run(rank=0, world_size=1, args=args)
+    run(world_size=world_size, args=args)


 torch.set_num_threads(1)
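
With mp.spawn() removed, main() calls run() exactly once, and starting one process per GPU on every node becomes the launcher's job. A hypothetical two-node, four-GPU-per-node launch might look like this (addresses, ports, and GPU counts are illustrative, not taken from the commit):

# Node 0 (hosts the rendezvous):
torchrun --nnodes=2 --nproc_per_node=4 --node_rank=0 \
    --master_addr=192.0.2.1 --master_port=12354 \
    conformer_ctc/train.py --world-size 8

# Node 1:
torchrun --nnodes=2 --nproc_per_node=4 --node_rank=1 \
    --master_addr=192.0.2.1 --master_port=12354 \
    conformer_ctc/train.py --world-size 8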