WIP: Support multi-node multi-GPU training.

This commit is contained in:
Fangjun Kuang 2021-09-30 01:08:09 +08:00
parent 707d7017a7
commit c688161f44
4 changed files with 96 additions and 20 deletions

View File

@ -383,7 +383,7 @@ def decode_one_batch(
ans[lm_scale_str] = hyps
else:
for lm_scale in lm_scale_list:
ans[lm_scale_str] = [[] * lattice.shape[0]]
ans[f"{lm_scale}"] = [[] * lattice.shape[0]]
return ans

View File

@ -39,7 +39,13 @@ from transformer import Noam
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.dist import (
cleanup_dist,
get_local_rank,
get_rank,
get_world_size,
setup_dist,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@ -54,6 +60,13 @@ def get_parser():
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--use-multi-node",
type=str2bool,
default=False,
help="True if using multi-node multi-GPU.",
)
parser.add_argument(
"--world-size",
type=int,
@ -92,6 +105,23 @@ def get_parser():
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved.
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe",
help="""It contains language related input files such as lexicon.txt
""",
)
return parser
@ -106,12 +136,6 @@ def get_params() -> AttributeDict:
Explanation of options saved in `params`:
- exp_dir: It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
- lang_dir: It contains language related input files such as
"lexicon.txt"
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
@ -621,9 +645,17 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
if args.use_multi_node:
local_rank = get_local_rank()
else:
local_rank = rank
logging.info(
f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}"
)
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_dist(rank, world_size, params.master_port, args.use_multi_node)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
@ -640,7 +672,8 @@ def run(rank, world_size, args):
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
device = torch.device("cuda", local_rank)
logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
@ -665,7 +698,7 @@ def run(rank, world_size, args):
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
model = DDP(model, device_ids=[local_rank])
optimizer = Noam(
model.parameters(),
@ -726,9 +759,21 @@ def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
if args.use_multi_node:
# for multi-node multi-GPU training
rank = get_rank()
world_size = get_world_size()
args.world_size = world_size
print(f"rank: {rank}, world_size: {world_size}")
run(rank=rank, world_size=world_size, args=args)
return
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:

View File

@ -41,6 +41,7 @@ dl_dir=$PWD/download
# data/lang_bpe_yyy if the array contains xxx, yyy
vocab_sizes=(
5000
500
)
# All files generated by this script are saved in "data".
@ -190,5 +191,3 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
./local/compile_hlg.py --lang-dir $lang_dir
done
fi
cd data && ln -sfv lang_bpe_5000 lang_bpe

View File

@ -21,14 +21,46 @@ import torch
from torch import distributed as dist
def setup_dist(rank, world_size, master_port=None):
def setup_dist(rank, world_size, master_port=None, is_multi_node=False):
"""
rank and world_size are used only if is_multi_node is False.
"""
if "MASTER_ADDR" not in os.environ:
os.environ["MASTER_ADDR"] = "localhost"
if "MASTER_PORT" not in os.environ:
os.environ["MASTER_PORT"] = (
"12354" if master_port is None else str(master_port)
)
if is_multi_node is False:
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
else:
dist.init_process_group("nccl")
def cleanup_dist():
dist.destroy_process_group()
def get_world_size():
if "WORLD_SIZE" in os.environ:
return int(os.environ["WORLD_SIZE"])
if dist.is_available() and dist.is_initialized():
return dist.get_world_size()
else:
return 1
def get_rank():
if "RANK" in os.environ:
return int(os.environ["RANK"])
elif dist.is_available() and dist.is_initialized():
return dist.rank()
else:
return 1
def get_local_rank():
return int(os.environ.get("LOCAL_RANK", 0))