diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/README.md b/egs/librispeech/ASR/tdnn_lstm_ctc/README.md
index ce6d77294..401f3e319 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/README.md
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/README.md
@@ -12,3 +12,11 @@
 cd $PWD/..
 ./tdnn_lstm_ctc/train.py
 ```
+
+If you have 4 GPUs and want to use GPU 1 and GPU 3 for DDP training,
+you can do the following:
+
+```
+export CUDA_VISIBLE_DEVICES="1,3"
+./tdnn_lstm_ctc/train.py --world-size=2
+```
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index fe50130a2..78f276a13 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -10,9 +10,13 @@ from typing import Optional

 import k2
 import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
+from lhotse.utils import fix_random_seed
 from model import TdnnLstm
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_value_
 from torch.optim.lr_scheduler import StepLR
 from torch.utils.tensorboard import SummaryWriter
@@ -20,9 +24,15 @@ from torch.utils.tensorboard import SummaryWriter
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dataset.librispeech import LibriSpeechAsrDataModule
+from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, encode_supervisions, setup_logger
+from icefall.utils import (
+    AttributeDict,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)


 def get_parser():
@@ -43,6 +53,14 @@ def get_parser():
         default=12354,
         help="Master port to use for DDP training.",
     )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
     # TODO: add extra arguments and support DDP training.
     # Currently, only single GPU training is implemented. Will add
     # DDP training once single GPU training is finished.
@@ -186,6 +204,7 @@ def save_checkpoint(
     model: nn.Module,
     optimizer: torch.optim.Optimizer,
     scheduler: torch.optim.lr_scheduler._LRScheduler,
+    rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.

@@ -202,6 +221,7 @@
         params=params,
         optimizer=optimizer,
         scheduler=scheduler,
+        rank=rank,
     )

     if params.best_train_epoch == params.cur_epoch:
@@ -290,6 +310,7 @@
     model: nn.Module,
     graph_compiler: CtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
 ) -> None:
     """Run the validation process. The validation loss
     is saved in `params.valid_loss`.
@@ -312,6 +333,13 @@
         tot_loss += loss_cpu
         tot_frames += params.valid_frames

+    if world_size > 1:
+        s = torch.tensor([tot_loss, tot_frames], device=loss.device)
+        dist.all_reduce(s, op=dist.ReduceOp.SUM)
+        s = s.cpu().tolist()
+        tot_loss = s[0]
+        tot_frames = s[1]
+
     params.valid_loss = tot_loss / tot_frames

     if params.valid_loss < params.best_valid_loss:
@@ -327,6 +355,7 @@ def train_one_epoch(
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
 ) -> None:
     """Train the model for one epoch.

@@ -349,6 +378,8 @@
         Dataloader for the validation dataset.
       tb_writer:
         Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
     """

     model.train()
@@ -394,11 +425,12 @@
             model=model,
             graph_compiler=graph_compiler,
             valid_dl=valid_dl,
+            world_size=world_size,
         )
         model.train()
         logging.info(
-            f"Epoch {params.cur_epoch}, valid loss {params.valid_loss}, "
-            f"best valid loss: {params.best_valid_loss:.4f} "
+            f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
+            f" best valid loss: {params.best_valid_loss:.4f} "
             f"best valid epoch: {params.best_valid_epoch}"
         )

@@ -409,26 +441,40 @@
     params.best_train_loss = params.train_loss


-def main():
-    parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
-    args = parser.parse_args()
-
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
     params = get_params()
     params.update(vars(args))

+    fix_random_seed(42)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
     logging.info(params)

-    tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None

     lexicon = Lexicon(params.lang_dir)
     max_phone_id = max(lexicon.tokens)

     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", 0)
+        device = torch.device("cuda", rank)

     graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)

@@ -438,6 +484,8 @@
         subsampling_factor=params.subsampling_factor,
     )
     model.to(device)
+    if world_size > 1:
+        model = DDP(model, device_ids=[rank])

     optimizer = optim.AdamW(
         model.parameters(),
@@ -478,15 +526,36 @@
             train_dl=train_dl,
             valid_dl=valid_dl,
             tb_writer=tb_writer,
+            world_size=world_size,
         )

         scheduler.step()

         save_checkpoint(
-            params=params, model=model, optimizer=optimizer, scheduler=scheduler
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            rank=rank,
         )

     logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)


 if __name__ == "__main__":
diff --git a/icefall/dist.py b/icefall/dist.py
new file mode 100644
index 000000000..d314d2a43
--- /dev/null
+++ b/icefall/dist.py
@@ -0,0 +1,17 @@
+import os
+
+import torch
+from torch import distributed as dist
+
+
+def setup_dist(rank, world_size, master_port=None):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = (
+        "12354" if master_port is None else str(master_port)
+    )
+    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+    torch.cuda.set_device(rank)
+
+
+def cleanup_dist():
+    dist.destroy_process_group()
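
Note (not part of the patch): the sketch below is a rough, self-contained illustration of how the pieces introduced by this diff fit together: `mp.spawn` launches one process per GPU, the process group is initialized with the NCCL backend as in `icefall/dist.py`, the model is wrapped in `DistributedDataParallel`, and `dist.all_reduce` sums per-rank statistics the way the validation loss and frame counts are aggregated in `train.py`. The toy model, batch, and default port here are placeholders for illustration, not icefall code.

```
# Standalone sketch of the DDP workflow used in this patch.
# Assumptions: at least one CUDA device is visible; the toy model,
# batch, and default port are placeholders.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_dist(rank: int, world_size: int, master_port: int = 12354) -> None:
    # Same idea as icefall/dist.py: one process per GPU, NCCL backend.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(master_port)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def run(rank: int, world_size: int) -> None:
    setup_dist(rank, world_size)
    device = torch.device("cuda", rank)

    model = nn.Linear(10, 1).to(device)    # placeholder model
    model = DDP(model, device_ids=[rank])  # gradients are synced automatically

    x = torch.randn(8, 10, device=device)  # placeholder batch
    loss = model(x).sum()
    loss.backward()

    # Scalar bookkeeping (e.g. validation loss and frame counts) still needs
    # an explicit all_reduce, as in compute_validation_loss() above.
    s = torch.tensor([loss.item()], device=device)
    dist.all_reduce(s, op=dist.ReduceOp.SUM)
    if rank == 0:
        print(f"loss summed over {world_size} rank(s): {s.item():.4f}")

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    assert world_size >= 1, "this sketch needs at least one visible GPU"
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
```

With `CUDA_VISIBLE_DEVICES="1,3"` set as in the README example above, two processes are spawned and ranks 0 and 1 are mapped to the two visible GPUs, which is how `--world-size=2` is intended to be used.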