Support DDP training.

Fangjun Kuang 2021-07-25 21:40:09 +08:00
parent 4a66712406
commit 8055bf31a0
3 changed files with 105 additions and 11 deletions


@@ -12,3 +12,11 @@ cd $PWD/..
 ./tdnn_lstm_ctc/train.py
 ```
+
+If you have 4 GPUs and want to use GPU 1 and GPU 3 for DDP training,
+you can do the following:
+
+```
+export CUDA_VISIBLE_DEVICES="1,3"
+./tdnn_lstm_ctc/train.py --world-size=2
+```
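Note (not part of the commit): `--world-size` should equal the number of GPUs listed in `CUDA_VISIBLE_DEVICES`. PyTorch renumbers the visible devices from 0, so in the example above DDP rank 0 runs on physical GPU 1 and rank 1 on physical GPU 3. A minimal sanity check one could run before launching training:

```python
# Illustrative only -- not part of this commit.
import torch

world_size = 2  # the value passed via --world-size
assert torch.cuda.is_available(), "DDP training here assumes CUDA devices"
assert world_size <= torch.cuda.device_count(), (
    f"--world-size={world_size} exceeds the {torch.cuda.device_count()} "
    "GPU(s) visible through CUDA_VISIBLE_DEVICES"
)
```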

tdnn_lstm_ctc/train.py

@@ -10,9 +10,13 @@ from typing import Optional
 import k2
 import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
+from lhotse.utils import fix_random_seed
 from model import TdnnLstm
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_value_
 from torch.optim.lr_scheduler import StepLR
 from torch.utils.tensorboard import SummaryWriter
@@ -20,9 +24,15 @@ from torch.utils.tensorboard import SummaryWriter
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dataset.librispeech import LibriSpeechAsrDataModule
+from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, encode_supervisions, setup_logger
+from icefall.utils import (
+    AttributeDict,
+    encode_supervisions,
+    setup_logger,
+    str2bool,
+)


 def get_parser():
@@ -43,6 +53,14 @@ def get_parser():
         default=12354,
         help="Master port to use for DDP training.",
     )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
     # TODO: add extra arguments and support DDP training.
     # Currently, only single GPU training is implemented. Will add
     # DDP training once single GPU training is finished.
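The `--tensorboard` flag is parsed with `str2bool` (imported from `icefall.utils` above) so that values such as `--tensorboard=false` work on the command line. A rough sketch of what such a helper usually looks like follows; the actual `icefall.utils.str2bool` may differ in detail:

```python
# Illustrative sketch of an argparse-friendly str2bool helper; the real
# icefall.utils.str2bool may be implemented differently.
import argparse


def str2bool(v):
    """Map common string spellings of a boolean onto True/False."""
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
```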
@@ -186,6 +204,7 @@ def save_checkpoint(
     model: nn.Module,
     optimizer: torch.optim.Optimizer,
     scheduler: torch.optim.lr_scheduler._LRScheduler,
+    rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -202,6 +221,7 @@ def save_checkpoint(
         params=params,
         optimizer=optimizer,
         scheduler=scheduler,
+        rank=rank,
     )

     if params.best_train_epoch == params.cur_epoch:
@@ -290,6 +310,7 @@ def compute_validation_loss(
     model: nn.Module,
     graph_compiler: CtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
 ) -> None:
     """Run the validation process. The validation loss
     is saved in `params.valid_loss`.
@@ -312,6 +333,13 @@ def compute_validation_loss(
         tot_loss += loss_cpu
         tot_frames += params.valid_frames

+    if world_size > 1:
+        s = torch.tensor([tot_loss, tot_frames], device=loss.device)
+        dist.all_reduce(s, op=dist.ReduceOp.SUM)
+        s = s.cpu().tolist()
+        tot_loss = s[0]
+        tot_frames = s[1]
+
     params.valid_loss = tot_loss / tot_frames

     if params.valid_loss < params.best_valid_loss:
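With DDP, each rank only processes its own shard of the validation set, so the per-rank loss and frame totals are summed across processes with `dist.all_reduce` before the average is computed; every rank then ends up with the same `params.valid_loss`. The toy program below (not from the repository, with made-up numbers and the CPU-only `gloo` backend) shows the same aggregation pattern in isolation:

```python
# Minimal sketch of the all_reduce aggregation used above, runnable on CPU
# with the "gloo" backend; the loss/frame numbers are made up for illustration.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"  # arbitrary free port for this toy
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Pretend each rank accumulated a different loss / frame count.
    tot_loss, tot_frames = (10.0, 100.0) if rank == 0 else (30.0, 300.0)

    s = torch.tensor([tot_loss, tot_frames])
    dist.all_reduce(s, op=dist.ReduceOp.SUM)  # every rank now holds the global sums
    tot_loss, tot_frames = s.tolist()

    # Both ranks print the same value: (10 + 30) / (100 + 300) = 0.1
    print(f"rank {rank}: valid loss = {tot_loss / tot_frames}")

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2, join=True)
```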
@@ -327,6 +355,7 @@ def train_one_epoch(
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
 ) -> None:
     """Train the model for one epoch.
@@ -349,6 +378,8 @@ def train_one_epoch(
       Dataloader for the validation dataset.
     tb_writer:
       Writer to write log messages to tensorboard.
+    world_size:
+      Number of nodes in DDP training. If it is 1, DDP is disabled.
     """
     model.train()
@@ -394,11 +425,12 @@ def train_one_epoch(
                 model=model,
                 graph_compiler=graph_compiler,
                 valid_dl=valid_dl,
+                world_size=world_size,
             )
             model.train()
             logging.info(
-                f"Epoch {params.cur_epoch}, valid loss {params.valid_loss}, "
-                f"best valid loss: {params.best_valid_loss:.4f} "
+                f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
+                f" best valid loss: {params.best_valid_loss:.4f} "
                 f"best valid epoch: {params.best_valid_epoch}"
             )
@@ -409,26 +441,40 @@ def train_one_epoch(
         params.best_train_loss = params.train_loss


-def main():
-    parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
-    args = parser.parse_args()
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
     params = get_params()
     params.update(vars(args))

+    fix_random_seed(42)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
     logging.info(params)

-    tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None

     lexicon = Lexicon(params.lang_dir)
     max_phone_id = max(lexicon.tokens)

     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", 0)
+        device = torch.device("cuda", rank)

     graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)
@@ -438,6 +484,8 @@ def main():
         subsampling_factor=params.subsampling_factor,
     )
     model.to(device)
+    if world_size > 1:
+        model = DDP(model, device_ids=[rank])

     optimizer = optim.AdamW(
         model.parameters(),
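Once the model is wrapped in `DistributedDataParallel`, the underlying network is reachable via `model.module`. The commit delegates checkpoint handling to `icefall.checkpoint.save_checkpoint`, whose internals are not shown in this diff; the sketch below only illustrates the usual unwrapping pattern and is not code from the repository:

```python
# Sketch only: how a DDP-wrapped model is usually unwrapped before saving,
# assuming the checkpoint helper expects a plain nn.Module state_dict.
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def state_dict_for_saving(model: nn.Module) -> dict:
    """Return the state_dict of the underlying module if model is DDP-wrapped."""
    if isinstance(model, DDP):
        model = model.module
    return model.state_dict()
```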
@@ -478,15 +526,36 @@ def main():
             train_dl=train_dl,
             valid_dl=valid_dl,
             tb_writer=tb_writer,
+            world_size=world_size,
         )

         scheduler.step()

         save_checkpoint(
-            params=params, model=model, optimizer=optimizer, scheduler=scheduler
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            rank=rank,
         )

     logging.info("Done!")

+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
 if __name__ == "__main__":

icefall/dist.py (new file, 17 lines)

@@ -0,0 +1,17 @@
+import os
+
+import torch
+from torch import distributed as dist
+
+
+def setup_dist(rank, world_size, master_port=None):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = (
+        "12354" if master_port is None else str(master_port)
+    )
+    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+    torch.cuda.set_device(rank)
+
+
+def cleanup_dist():
+    dist.destroy_process_group()
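`setup_dist` and `cleanup_dist` are intended to bracket each training process, exactly as `run()` in `train.py` uses them. A toy usage sketch follows (not from the repository; it assumes one CUDA device per rank, since `setup_dist` initializes the NCCL backend):

```python
# Toy usage sketch (not from the repo): each spawned worker calls setup_dist()
# before building its DDP model and cleanup_dist() when it is done.
# Requires as many CUDA devices as `world_size`.
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from icefall.dist import cleanup_dist, setup_dist


def worker(rank: int, world_size: int):
    setup_dist(rank, world_size, master_port=12354)
    model = nn.Linear(10, 2).to(rank)      # each rank owns one GPU
    model = DDP(model, device_ids=[rank])  # gradients are synced across ranks
    # ... training loop would go here ...
    cleanup_dist()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)
```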