WIP: Support multi-node multi-GPU training.

Fangjun Kuang 2021-09-30 01:08:09 +08:00
parent 707d7017a7
commit c688161f44
4 changed files with 96 additions and 20 deletions

View File

@@ -383,7 +383,7 @@ def decode_one_batch(
             ans[lm_scale_str] = hyps
     else:
         for lm_scale in lm_scale_list:
-            ans[lm_scale_str] = [[] * lattice.shape[0]]
+            ans[f"{lm_scale}"] = [[] * lattice.shape[0]]
     return ans
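
For context, a small self-contained sketch (not part of this commit; lm_scale_list and num_utts below are illustrative stand-ins) of the shape this else branch is meant to produce: one entry per LM scale, keyed by its stringified value, holding one hypothesis list per utterance. Note that [[] * n] always evaluates to [[]] because [] * n == [], so a per-utterance list needs a comprehension:

# Illustrative only: mirrors the fixed else branch above.
lm_scale_list = [0.1, 0.2, 0.3]
num_utts = 4  # stands in for lattice.shape[0]

ans = {}
for lm_scale in lm_scale_list:
    # one (empty) hypothesis list per utterance in the batch
    ans[f"{lm_scale}"] = [[] for _ in range(num_utts)]

assert sorted(ans.keys()) == ["0.1", "0.2", "0.3"]
assert all(len(hyps) == num_utts for hyps in ans.values())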

View File

@@ -39,7 +39,13 @@ from transformer import Noam
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.dist import cleanup_dist, setup_dist
+from icefall.dist import (
+    cleanup_dist,
+    get_local_rank,
+    get_rank,
+    get_world_size,
+    setup_dist,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -54,6 +60,13 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )

+    parser.add_argument(
+        "--use-multi-node",
+        type=str2bool,
+        default=False,
+        help="True if using multi-node multi-GPU.",
+    )
+
     parser.add_argument(
         "--world-size",
         type=int,
@@ -92,6 +105,23 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved.
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe",
+        help="""It contains language related input files such as lexicon.txt
+        """,
+    )
+
     return parser
@@ -106,12 +136,6 @@ def get_params() -> AttributeDict:
     Explanation of options saved in `params`:

-        - exp_dir: It specifies the directory where all training related
-                   files, e.g., checkpoints, log, etc, are saved
-
-        - lang_dir: It contains language related input files such as
-                    "lexicon.txt"
-
         - best_train_loss: Best training loss so far. It is used to select
                            the model that has the lowest training loss. It is
                            updated during the training.
@@ -621,9 +645,17 @@ def run(rank, world_size, args):
     params = get_params()
     params.update(vars(args))

+    if args.use_multi_node:
+        local_rank = get_local_rank()
+    else:
+        local_rank = rank
+    logging.info(
+        f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}"
+    )
+
     fix_random_seed(42)
     if world_size > 1:
-        setup_dist(rank, world_size, params.master_port)
+        setup_dist(rank, world_size, params.master_port, args.use_multi_node)

     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
@@ -640,7 +672,8 @@ def run(rank, world_size, args):

     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", rank)
+        device = torch.device("cuda", local_rank)
+    logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")

     graph_compiler = BpeCtcTrainingGraphCompiler(
         params.lang_dir,
@@ -665,7 +698,7 @@ def run(rank, world_size, args):

     model.to(device)
     if world_size > 1:
-        model = DDP(model, device_ids=[rank])
+        model = DDP(model, device_ids=[local_rank])

     optimizer = Noam(
         model.parameters(),
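
The device index and the DDP device_ids now use the local rank rather than the global rank. A minimal sketch, under an assumed 2-node, 4-GPUs-per-node layout (the commit itself reads LOCAL_RANK from the launcher rather than computing it), of how the two differ:

# Illustrative only: with 2 nodes x 4 GPUs, global ranks run 0..7 while each
# node only exposes CUDA devices 0..3, so torch.device("cuda", rank) would be
# invalid on the second node; the local rank is the correct device index.
num_nodes = 2
gpus_per_node = 4

for rank in range(num_nodes * gpus_per_node):
    node_id = rank // gpus_per_node
    local_rank = rank % gpus_per_node  # what LOCAL_RANK / get_local_rank() yields
    print(f"global rank {rank}: node {node_id}, device cuda:{local_rank}")
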
@@ -726,9 +759,21 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+
+    if args.use_multi_node:
+        # for multi-node multi-GPU training
+        rank = get_rank()
+        world_size = get_world_size()
+        args.world_size = world_size
+        print(f"rank: {rank}, world_size: {world_size}")
+        run(rank=rank, world_size=world_size, args=args)
+        return

     world_size = args.world_size
     assert world_size >= 1
+
     if world_size > 1:
         mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
     else:

View File

@@ -41,6 +41,7 @@ dl_dir=$PWD/download
 # data/lang_bpe_yyy if the array contains xxx, yyy
 vocab_sizes=(
   5000
+  500
 )

 # All files generated by this script are saved in "data".
@@ -190,5 +191,3 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
     ./local/compile_hlg.py --lang-dir $lang_dir
   done
 fi
-
-cd data && ln -sfv lang_bpe_5000 lang_bpe

View File

@@ -21,14 +21,46 @@ import torch
 from torch import distributed as dist


-def setup_dist(rank, world_size, master_port=None):
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = (
-        "12354" if master_port is None else str(master_port)
-    )
-    dist.init_process_group("nccl", rank=rank, world_size=world_size)
-    torch.cuda.set_device(rank)
+def setup_dist(rank, world_size, master_port=None, is_multi_node=False):
+    """
+    rank and world_size are used only if is_multi_node is False.
+    """
+    if "MASTER_ADDR" not in os.environ:
+        os.environ["MASTER_ADDR"] = "localhost"
+
+    if "MASTER_PORT" not in os.environ:
+        os.environ["MASTER_PORT"] = (
+            "12354" if master_port is None else str(master_port)
+        )
+
+    if is_multi_node is False:
+        dist.init_process_group("nccl", rank=rank, world_size=world_size)
+        torch.cuda.set_device(rank)
+    else:
+        dist.init_process_group("nccl")


 def cleanup_dist():
     dist.destroy_process_group()
+
+
+def get_world_size():
+    if "WORLD_SIZE" in os.environ:
+        return int(os.environ["WORLD_SIZE"])
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_world_size()
+    else:
+        return 1
+
+
+def get_rank():
+    if "RANK" in os.environ:
+        return int(os.environ["RANK"])
+    elif dist.is_available() and dist.is_initialized():
+        return dist.get_rank()
+    else:
+        return 0
+
+
+def get_local_rank():
+    return int(os.environ.get("LOCAL_RANK", 0))
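
A minimal usage sketch of the new helpers (assumptions: icefall is importable, and the environment values below are placeholders for what a multi-node launcher such as torchrun or torch.distributed.launch would normally export):

import os

# Normally exported by the launcher on each process; set by hand here
# only to illustrate what the helpers read.
os.environ["RANK"] = "5"         # global rank across all nodes
os.environ["WORLD_SIZE"] = "8"   # total number of processes
os.environ["LOCAL_RANK"] = "1"   # rank within the current node

from icefall.dist import get_local_rank, get_rank, get_world_size

print(get_rank())        # 5
print(get_world_size())  # 8
print(get_local_rank())  # 1

# In the multi-node branch, setup_dist(..., is_multi_node=True) calls
# dist.init_process_group("nccl") without an explicit rank/world_size, so
# PyTorch's default env:// initialization reads RANK, WORLD_SIZE,
# MASTER_ADDR and MASTER_PORT from the environment instead.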