mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
Support DDP training.
This commit is contained in:
parent
4a66712406
commit
8055bf31a0
@ -12,3 +12,11 @@ cd $PWD/..
|
|||||||
|
|
||||||
./tdnn_lstm_ctc/train.py
|
./tdnn_lstm_ctc/train.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If you have 4 GPUs and want to use GPU 1 and GPU 3 for DDP training,
|
||||||
|
you can do the following:
|
||||||
|
|
||||||
|
```
|
||||||
|
export CUDA_VISIBLE_DEVICES="1,3"
|
||||||
|
./tdnn_lstm_ctc/train.py --world-size=2
|
||||||
|
```
|
||||||
|
@ -10,9 +10,13 @@ from typing import Optional
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
|
from lhotse.utils import fix_random_seed
|
||||||
from model import TdnnLstm
|
from model import TdnnLstm
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.nn.utils import clip_grad_value_
|
from torch.nn.utils import clip_grad_value_
|
||||||
from torch.optim.lr_scheduler import StepLR
|
from torch.optim.lr_scheduler import StepLR
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
@ -20,9 +24,15 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
|
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
|
||||||
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import AttributeDict, encode_supervisions, setup_logger
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
encode_supervisions,
|
||||||
|
setup_logger,
|
||||||
|
str2bool,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -43,6 +53,14 @@ def get_parser():
|
|||||||
default=12354,
|
default=12354,
|
||||||
help="Master port to use for DDP training.",
|
help="Master port to use for DDP training.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tensorboard",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Should various information be logged in tensorboard.",
|
||||||
|
)
|
||||||
|
|
||||||
# TODO: add extra arguments and support DDP training.
|
# TODO: add extra arguments and support DDP training.
|
||||||
# Currently, only single GPU training is implemented. Will add
|
# Currently, only single GPU training is implemented. Will add
|
||||||
# DDP training once single GPU training is finished.
|
# DDP training once single GPU training is finished.
|
||||||
@ -186,6 +204,7 @@ def save_checkpoint(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
scheduler: torch.optim.lr_scheduler._LRScheduler,
|
scheduler: torch.optim.lr_scheduler._LRScheduler,
|
||||||
|
rank: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Save model, optimizer, scheduler and training stats to file.
|
"""Save model, optimizer, scheduler and training stats to file.
|
||||||
|
|
||||||
@ -202,6 +221,7 @@ def save_checkpoint(
|
|||||||
params=params,
|
params=params,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.best_train_epoch == params.cur_epoch:
|
if params.best_train_epoch == params.cur_epoch:
|
||||||
@ -290,6 +310,7 @@ def compute_validation_loss(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
graph_compiler: CtcTrainingGraphCompiler,
|
graph_compiler: CtcTrainingGraphCompiler,
|
||||||
valid_dl: torch.utils.data.DataLoader,
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
|
world_size: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run the validation process. The validation loss
|
"""Run the validation process. The validation loss
|
||||||
is saved in `params.valid_loss`.
|
is saved in `params.valid_loss`.
|
||||||
@ -312,6 +333,13 @@ def compute_validation_loss(
|
|||||||
tot_loss += loss_cpu
|
tot_loss += loss_cpu
|
||||||
tot_frames += params.valid_frames
|
tot_frames += params.valid_frames
|
||||||
|
|
||||||
|
if world_size > 1:
|
||||||
|
s = torch.tensor([tot_loss, tot_frames], device=loss.device)
|
||||||
|
dist.all_reduce(s, op=dist.ReduceOp.SUM)
|
||||||
|
s = s.cpu().tolist()
|
||||||
|
tot_loss = s[0]
|
||||||
|
tot_frames = s[1]
|
||||||
|
|
||||||
params.valid_loss = tot_loss / tot_frames
|
params.valid_loss = tot_loss / tot_frames
|
||||||
|
|
||||||
if params.valid_loss < params.best_valid_loss:
|
if params.valid_loss < params.best_valid_loss:
|
||||||
@ -327,6 +355,7 @@ def train_one_epoch(
|
|||||||
train_dl: torch.utils.data.DataLoader,
|
train_dl: torch.utils.data.DataLoader,
|
||||||
valid_dl: torch.utils.data.DataLoader,
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
tb_writer: Optional[SummaryWriter] = None,
|
tb_writer: Optional[SummaryWriter] = None,
|
||||||
|
world_size: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Train the model for one epoch.
|
"""Train the model for one epoch.
|
||||||
|
|
||||||
@ -349,6 +378,8 @@ def train_one_epoch(
|
|||||||
Dataloader for the validation dataset.
|
Dataloader for the validation dataset.
|
||||||
tb_writer:
|
tb_writer:
|
||||||
Writer to write log messages to tensorboard.
|
Writer to write log messages to tensorboard.
|
||||||
|
world_size:
|
||||||
|
Number of nodes in DDP training. If it is 1, DDP is disabled.
|
||||||
"""
|
"""
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
@ -394,11 +425,12 @@ def train_one_epoch(
|
|||||||
model=model,
|
model=model,
|
||||||
graph_compiler=graph_compiler,
|
graph_compiler=graph_compiler,
|
||||||
valid_dl=valid_dl,
|
valid_dl=valid_dl,
|
||||||
|
world_size=world_size,
|
||||||
)
|
)
|
||||||
model.train()
|
model.train()
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Epoch {params.cur_epoch}, valid loss {params.valid_loss}, "
|
f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
|
||||||
f"best valid loss: {params.best_valid_loss:.4f} "
|
f" best valid loss: {params.best_valid_loss:.4f} "
|
||||||
f"best valid epoch: {params.best_valid_epoch}"
|
f"best valid epoch: {params.best_valid_epoch}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -409,26 +441,40 @@ def train_one_epoch(
|
|||||||
params.best_train_loss = params.train_loss
|
params.best_train_loss = params.train_loss
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def run(rank, world_size, args):
|
||||||
parser = get_parser()
|
"""
|
||||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
Args:
|
||||||
args = parser.parse_args()
|
rank:
|
||||||
|
It is a value between 0 and `world_size-1`, which is
|
||||||
|
passed automatically by `mp.spawn()` in :func:`main`.
|
||||||
|
The node with rank 0 is responsible for saving checkpoint.
|
||||||
|
world_size:
|
||||||
|
Number of GPUs for DDP training.
|
||||||
|
args:
|
||||||
|
The return value of get_parser().parse_args()
|
||||||
|
"""
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
|
fix_random_seed(42)
|
||||||
|
if world_size > 1:
|
||||||
|
setup_dist(rank, world_size, params.master_port)
|
||||||
|
|
||||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||||
logging.info("Training started")
|
logging.info("Training started")
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
if args.tensorboard and rank == 0:
|
||||||
|
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||||
|
else:
|
||||||
|
tb_writer = None
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
max_phone_id = max(lexicon.tokens)
|
max_phone_id = max(lexicon.tokens)
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", 0)
|
device = torch.device("cuda", rank)
|
||||||
|
|
||||||
graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)
|
graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)
|
||||||
|
|
||||||
@ -438,6 +484,8 @@ def main():
|
|||||||
subsampling_factor=params.subsampling_factor,
|
subsampling_factor=params.subsampling_factor,
|
||||||
)
|
)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
if world_size > 1:
|
||||||
|
model = DDP(model, device_ids=[rank])
|
||||||
|
|
||||||
optimizer = optim.AdamW(
|
optimizer = optim.AdamW(
|
||||||
model.parameters(),
|
model.parameters(),
|
||||||
@ -478,15 +526,36 @@ def main():
|
|||||||
train_dl=train_dl,
|
train_dl=train_dl,
|
||||||
valid_dl=valid_dl,
|
valid_dl=valid_dl,
|
||||||
tb_writer=tb_writer,
|
tb_writer=tb_writer,
|
||||||
|
world_size=world_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
params=params, model=model, optimizer=optimizer, scheduler=scheduler
|
params=params,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Done!")
|
logging.info("Done!")
|
||||||
|
if world_size > 1:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
cleanup_dist()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
world_size = args.world_size
|
||||||
|
assert world_size >= 1
|
||||||
|
if world_size > 1:
|
||||||
|
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
||||||
|
else:
|
||||||
|
run(rank=0, world_size=1, args=args)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
17
icefall/dist.py
Normal file
17
icefall/dist.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import distributed as dist
|
||||||
|
|
||||||
|
|
||||||
|
def setup_dist(rank, world_size, master_port=None):
|
||||||
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
|
os.environ["MASTER_PORT"] = (
|
||||||
|
"12354" if master_port is None else str(master_port)
|
||||||
|
)
|
||||||
|
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||||
|
torch.cuda.set_device(rank)
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_dist():
|
||||||
|
dist.destroy_process_group()
|
Loading…
x
Reference in New Issue
Block a user