Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 18:12:19 +00:00
Modified aishell/ASR/conformer_ctc/train.py to implement multi-machine DDP training.
parent 3e4da5f781
commit 7a0c7b7478
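Before this change, train.py spawned its own workers with `mp.spawn`, which only works on a single machine. After it, the script expects to be started once per GPU by a PyTorch distributed launcher that exports `LOCAL_RANK` together with the rendezvous variables (`RANK`, `WORLD_SIZE`, `MASTER_ADDR`, `MASTER_PORT`). As a hedged illustration (the address, port and GPU counts are placeholders, and `--world-size` is assumed to be the recipe's existing flag behind `args.world_size`), a two-machine launch would look like:

# on machine 0 (the rendezvous host)
torchrun --nnodes=2 --node_rank=0 --nproc_per_node=4 \
    --master_addr=10.0.0.1 --master_port=12354 \
    conformer_ctc/train.py --world-size 8

# on machine 1
torchrun --nnodes=2 --node_rank=1 --nproc_per_node=4 \
    --master_addr=10.0.0.1 --master_port=12354 \
    conformer_ctc/train.py --world-size 8

A runnable Python sketch of the matching in-process setup follows the diff.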
@@ -22,9 +22,9 @@ from pathlib import Path
 from shutil import copyfile
 from typing import Optional, Tuple
 
+import os
 import k2
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import AishellAsrDataModule
 from conformer import Conformer
@@ -543,13 +543,9 @@ def train_one_epoch(
             params.best_train_loss = params.train_loss
 
 
-def run(rank, world_size, args):
+def run(world_size, args):
     """
     Args:
-      rank:
-        It is a value between 0 and `world_size-1`, which is
-        passed automatically by `mp.spawn()` in :func:`main`.
-        The node with rank 0 is responsible for saving checkpoint.
       world_size:
         Number of GPUs for DDP training.
       args:
@@ -560,13 +556,14 @@ def run(rank, world_size, args):
 
     fix_random_seed(params.seed)
-    if world_size > 1:
-        setup_dist(rank, world_size, params.master_port)
 
+    setup_dist(use_ddp_launch=True, master_addr=params.master_port)
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
-    logging.info(params)
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    if local_rank == 0:
+        logging.info(params)
 
-    if args.tensorboard and rank == 0:
+    if args.tensorboard and local_rank == 0:
         tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
     else:
         tb_writer = None
@@ -577,7 +574,7 @@ def run(rank, world_size, args):
 
     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", rank)
+        device = torch.device("cuda", local_rank)
 
     graph_compiler = CharCtcTrainingGraphCompiler(
         lexicon=lexicon,
@@ -603,7 +600,8 @@ def run(rank, world_size, args):
 
     model.to(device)
     if world_size > 1:
-        model = DDP(model, device_ids=[rank])
+        torch.distributed.barrier()  # Ensure all processes have the same model parameters
+        model = DDP(model, device_ids=[local_rank])
 
     optimizer = Noam(
         model.parameters(),
@@ -629,7 +627,7 @@ def run(rank, world_size, args):
             tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
-        if rank == 0:
+        if local_rank == 0:
             logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
 
         params.cur_epoch = epoch
@@ -644,12 +642,14 @@ def run(rank, world_size, args):
             tb_writer=tb_writer,
             world_size=world_size,
         )
+        if world_size > 1:
+            torch.distributed.barrier()
 
         save_checkpoint(
             params=params,
             model=model,
             optimizer=optimizer,
-            rank=rank,
+            rank=local_rank,
         )
 
     logging.info("Done!")
@@ -668,10 +668,7 @@ def main():
 
     world_size = args.world_size
     assert world_size >= 1
-    if world_size > 1:
-        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
-    else:
-        run(rank=0, world_size=1, args=args)
+    run(world_size=world_size, args=args)
 
 
 torch.set_num_threads(1)
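To make the pattern above concrete, here is a minimal, self-contained sketch of the `LOCAL_RANK`-driven setup this commit moves to. It assumes a torchrun-style launcher as shown before the diff; the toy linear model and the NCCL backend are illustrative stand-ins, not the recipe's actual Conformer or configuration.

import os

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def run():
    # A torchrun-style launcher exports LOCAL_RANK (rank on this machine),
    # RANK (global rank), WORLD_SIZE and the rendezvous address/port, so
    # init_process_group needs nothing beyond the backend.
    torch.distributed.init_process_group(backend="nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    device = torch.device("cuda", local_rank)
    model = nn.Linear(80, 80).to(device)  # stand-in for the Conformer model

    # Mirror the commit: synchronize before wrapping, so every process has
    # finished building the model before DDP broadcasts rank 0's parameters.
    torch.distributed.barrier()
    model = DDP(model, device_ids=[local_rank])

    # ... training loop; only local_rank == 0 would write logs,
    # TensorBoard events and checkpoints, as in the diff ...

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    run()

Note that DDP already broadcasts rank 0's parameters at construction; the explicit barrier the commit adds is a conservative extra synchronization rather than what makes the parameters match.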