mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

WIP: Support multi-node multi-GPU training.

commit c688161f44 (parent 707d7017a7)
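Note: the multi-node code path in this commit assumes an external launcher (for example PyTorch's torch.distributed.launch) starts one process per GPU on every node and exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT to each process. A minimal sketch of what each process would see on a hypothetical 2-node, 2-GPU-per-node job (the concrete values are made up for illustration):

    import os

    # Hypothetical environment exported by the launcher for a 2-node x 2-GPU job:
    #   node 0 -> RANK=0,1   LOCAL_RANK=0,1
    #   node 1 -> RANK=2,3   LOCAL_RANK=0,1
    #   all processes -> WORLD_SIZE=4, MASTER_ADDR=<address of node 0>, MASTER_PORT=<free port>
    rank = int(os.environ["RANK"])              # global index across all nodes
    local_rank = int(os.environ["LOCAL_RANK"])  # GPU index within this node
    world_size = int(os.environ["WORLD_SIZE"])  # total number of processes
    print(f"rank: {rank}, local_rank: {local_rank}, world_size: {world_size}")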
@@ -383,7 +383,7 @@ def decode_one_batch(
             ans[lm_scale_str] = hyps
     else:
         for lm_scale in lm_scale_list:
-            ans[lm_scale_str] = [[] * lattice.shape[0]]
+            ans[f"{lm_scale}"] = [[] * lattice.shape[0]]
     return ans
@@ -39,7 +39,13 @@ from transformer import Noam
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.dist import cleanup_dist, setup_dist
+from icefall.dist import (
+    cleanup_dist,
+    get_local_rank,
+    get_rank,
+    get_world_size,
+    setup_dist,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -54,6 +60,13 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )

+    parser.add_argument(
+        "--use-multi-node",
+        type=str2bool,
+        default=False,
+        help="True if using multi-node multi-GPU.",
+    )
+
     parser.add_argument(
         "--world-size",
         type=int,
@@ -92,6 +105,23 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved.
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe",
+        help="""It contains language related input files such as lexicon.txt
+        """,
+    )
+
     return parser

@@ -106,12 +136,6 @@ def get_params() -> AttributeDict:

     Explanation of options saved in `params`:

-    - exp_dir: It specifies the directory where all training related
-               files, e.g., checkpoints, log, etc, are saved
-
-    - lang_dir: It contains language related input files such as
-               "lexicon.txt"
-
     - best_train_loss: Best training loss so far. It is used to select
                        the model that has the lowest training loss. It is
                        updated during the training.
@@ -621,9 +645,17 @@ def run(rank, world_size, args):
     params = get_params()
     params.update(vars(args))

+    if args.use_multi_node:
+        local_rank = get_local_rank()
+    else:
+        local_rank = rank
+    logging.info(
+        f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}"
+    )
+
     fix_random_seed(42)
     if world_size > 1:
-        setup_dist(rank, world_size, params.master_port)
+        setup_dist(rank, world_size, params.master_port, args.use_multi_node)

     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
@@ -640,7 +672,8 @@ def run(rank, world_size, args):

     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", rank)
+        device = torch.device("cuda", local_rank)
+    logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")

     graph_compiler = BpeCtcTrainingGraphCompiler(
         params.lang_dir,
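Why the CUDA device index switches from rank to local_rank: on a single node the spawned process index and the GPU index coincide, but across nodes the global rank keeps growing while each node only exposes GPU indices 0..N-1. A small sketch under the same hypothetical 2-node, 2-GPU layout as above:

    import torch

    # Hypothetical process on node 1 of a 2-node x 2-GPU job.
    rank, local_rank = 3, 1

    device = torch.device("cpu")
    if torch.cuda.is_available():
        # torch.device("cuda", rank) would point at cuda:3, which does not
        # exist on a 2-GPU node, so moving tensors there would fail; the
        # node-local index is the right one.
        device = torch.device("cuda", local_rank)
    print(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")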
@@ -665,7 +698,7 @@ def run(rank, world_size, args):

     model.to(device)
     if world_size > 1:
-        model = DDP(model, device_ids=[rank])
+        model = DDP(model, device_ids=[local_rank])

     optimizer = Noam(
         model.parameters(),
@@ -726,9 +759,21 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+
+    if args.use_multi_node:
+        # for multi-node multi-GPU training
+        rank = get_rank()
+        world_size = get_world_size()
+        args.world_size = world_size
+        print(f"rank: {rank}, world_size: {world_size}")
+        run(rank=rank, world_size=world_size, args=args)
+        return
+
     world_size = args.world_size
     assert world_size >= 1

     if world_size > 1:
         mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
     else:
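The main() change gives the script two entry paths: single-node training still spawns its own worker processes, while multi-node training relies on the launcher having already started one process per GPU, so the current process simply reads its identity from the environment and calls run() directly. A condensed sketch of that control flow (main_sketch is a made-up name; the real logic lives in main() above):

    import torch.multiprocessing as mp

    from icefall.dist import get_rank, get_world_size

    def main_sketch(args, run):
        if args.use_multi_node:
            # One process per GPU already exists; just find out who we are.
            run(rank=get_rank(), world_size=get_world_size(), args=args)
        elif args.world_size > 1:
            # Single node: spawn one child process per GPU ourselves.
            mp.spawn(run, args=(args.world_size, args), nprocs=args.world_size, join=True)
        else:
            run(rank=0, world_size=1, args=args)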
@@ -41,6 +41,7 @@ dl_dir=$PWD/download
 # data/lang_bpe_yyy if the array contains xxx, yyy
 vocab_sizes=(
+  5000
   500
 )

 # All files generated by this script are saved in "data".
@@ -190,5 +191,3 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
     ./local/compile_hlg.py --lang-dir $lang_dir
   done
 fi
-
-cd data && ln -sfv lang_bpe_5000 lang_bpe
@@ -21,14 +21,46 @@ import torch
 from torch import distributed as dist


-def setup_dist(rank, world_size, master_port=None):
+def setup_dist(rank, world_size, master_port=None, is_multi_node=False):
+    """
+    rank and world_size are used only if is_multi_node is False.
+    """
     if "MASTER_ADDR" not in os.environ:
         os.environ["MASTER_ADDR"] = "localhost"

     if "MASTER_PORT" not in os.environ:
         os.environ["MASTER_PORT"] = (
             "12354" if master_port is None else str(master_port)
         )

-    dist.init_process_group("nccl", rank=rank, world_size=world_size)
-    torch.cuda.set_device(rank)
+    if is_multi_node is False:
+        dist.init_process_group("nccl", rank=rank, world_size=world_size)
+        torch.cuda.set_device(rank)
+    else:
+        dist.init_process_group("nccl")


 def cleanup_dist():
     dist.destroy_process_group()
+
+
+def get_world_size():
+    if "WORLD_SIZE" in os.environ:
+        return int(os.environ["WORLD_SIZE"])
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_world_size()
+    else:
+        return 1
+
+
+def get_rank():
+    if "RANK" in os.environ:
+        return int(os.environ["RANK"])
+    elif dist.is_available() and dist.is_initialized():
+        return dist.get_rank()
+    else:
+        return 0
+
+
+def get_local_rank():
+    return int(os.environ.get("LOCAL_RANK", 0))
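A minimal usage sketch of the helpers added above, assuming the process was started by a launcher that exported the environment variables described earlier. With is_multi_node=True, dist.init_process_group("nccl") falls back to PyTorch's env:// initialization, which reads RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT from the environment; that is why the docstring says rank and world_size are only used in the single-node case.

    import torch

    from icefall.dist import (
        cleanup_dist,
        get_local_rank,
        get_rank,
        get_world_size,
        setup_dist,
    )

    rank = get_rank()              # RANK from the environment, else torch.distributed, else 0
    world_size = get_world_size()  # WORLD_SIZE from the environment, else torch.distributed, else 1
    local_rank = get_local_rank()  # LOCAL_RANK from the environment, else 0

    setup_dist(rank, world_size, master_port=None, is_multi_node=True)
    torch.cuda.set_device(local_rank)  # pin this process to its node-local GPU

    # ... build the model on torch.device("cuda", local_rank), wrap it in DDP
    # with device_ids=[local_rank], train ...

    cleanup_dist()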