Mirror of https://github.com/k2-fsa/icefall.git
WIP: Support multi-node multi-GPU training.

commit c688161f44 (parent 707d7017a7)

The per-file headers of the diff were lost in the page export; the hunks below touch, in order, the CTC decoding script (decode_one_batch), the conformer_ctc training script (get_parser, get_params, run, main), the data preparation script prepare.sh, and icefall/dist.py.
@@ -383,7 +383,7 @@ def decode_one_batch(
             ans[lm_scale_str] = hyps
     else:
         for lm_scale in lm_scale_list:
-            ans[lm_scale_str] = [[] * lattice.shape[0]]
+            ans[f"{lm_scale}"] = [[] * lattice.shape[0]]
     return ans
 
 
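A side note on the replaced line above (an observation, not part of the commit): the change only swaps the stale dictionary key lm_scale_str for an f-string built from the loop variable lm_scale; the right-hand side is untouched. In Python, [[] * n] always evaluates to a one-element list containing a single empty list, so if the intent is one empty hypothesis list per utterance in the batch, a list comprehension is the usual form. A minimal sketch, with n standing in for lattice.shape[0]:

# Illustration only; n stands in for lattice.shape[0] (the batch size).
n = 3

a = [[] * n]                # [] * 3 == [], so this is just [[]]
b = [[] for _ in range(n)]  # n independent empty lists: [[], [], []]

assert a == [[]]
assert b == [[], [], []]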
@@ -39,7 +39,13 @@ from transformer import Noam
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.dist import cleanup_dist, setup_dist
+from icefall.dist import (
+    cleanup_dist,
+    get_local_rank,
+    get_rank,
+    get_world_size,
+    setup_dist,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -54,6 +60,13 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
+    parser.add_argument(
+        "--use-multi-node",
+        type=str2bool,
+        default=False,
+        help="True if using multi-node multi-GPU.",
+    )
+
     parser.add_argument(
         "--world-size",
         type=int,
@@ -92,6 +105,23 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved.
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe",
+        help="""It contains language related input files such as lexicon.txt
+        """,
+    )
+
     return parser
 
 
@@ -106,12 +136,6 @@ def get_params() -> AttributeDict:
 
     Explanation of options saved in `params`:
 
-        - exp_dir: It specifies the directory where all training related
-                   files, e.g., checkpoints, log, etc, are saved
-
-        - lang_dir: It contains language related input files such as
-                    "lexicon.txt"
-
         - best_train_loss: Best training loss so far. It is used to select
                            the model that has the lowest training loss. It is
                            updated during the training.
@@ -621,9 +645,17 @@ def run(rank, world_size, args):
     params = get_params()
     params.update(vars(args))
 
+    if args.use_multi_node:
+        local_rank = get_local_rank()
+    else:
+        local_rank = rank
+    logging.info(
+        f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}"
+    )
+
     fix_random_seed(42)
     if world_size > 1:
-        setup_dist(rank, world_size, params.master_port)
+        setup_dist(rank, world_size, params.master_port, args.use_multi_node)
 
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
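For context on the hunk above: in a multi-node job every process has a global rank (unique across all nodes) and a local rank (its GPU index on its own node), and world_size is the total number of processes. With --use-multi-node the local rank is read from the environment via get_local_rank(); in the single-node mp.spawn case the spawned process index serves as both. A minimal sketch of the relationship, assuming a launcher that exports RANK, WORLD_SIZE, and LOCAL_RANK (as torchrun and torch.distributed.launch --use_env do); the concrete numbers are made up for illustration:

import os

# Hypothetical environment for process 5 of a 2-node x 4-GPU job (8 processes).
os.environ.setdefault("WORLD_SIZE", "8")  # total number of processes
os.environ.setdefault("RANK", "5")        # global rank, unique in [0, WORLD_SIZE)
os.environ.setdefault("LOCAL_RANK", "1")  # GPU index on this node (5 % 4 == 1)

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])
print(f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}")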
@@ -640,7 +672,8 @@ def run(rank, world_size, args):
 
     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", rank)
+        device = torch.device("cuda", local_rank)
+    logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")
 
     graph_compiler = BpeCtcTrainingGraphCompiler(
         params.lang_dir,
@@ -665,7 +698,7 @@ def run(rank, world_size, args):
 
     model.to(device)
     if world_size > 1:
-        model = DDP(model, device_ids=[rank])
+        model = DDP(model, device_ids=[local_rank])
 
     optimizer = Noam(
         model.parameters(),
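The device_ids change above matters because DDP's device_ids must name a GPU that exists on the local machine: the global rank is only a valid CUDA index on a single node, whereas the local rank is always in [0, GPUs per node). A self-contained sketch of the pattern, assuming NCCL, one process per GPU, and an already-initialized process group; build_model() is a placeholder here, not an icefall function:

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def build_model() -> torch.nn.Module:
    # Placeholder; the real training script builds a Conformer here.
    return torch.nn.Linear(80, 500)


def wrap_model_for_ddp() -> torch.nn.Module:
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    device = torch.device("cuda", local_rank)  # local GPU index, not the global rank
    torch.cuda.set_device(device)

    model = build_model().to(device)
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        model = DDP(model, device_ids=[local_rank])
    return model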
@@ -726,9 +759,21 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+
+    if args.use_multi_node:
+        # for multi-node multi-GPU training
+        rank = get_rank()
+        world_size = get_world_size()
+        args.world_size = world_size
+        print(f"rank: {rank}, world_size: {world_size}")
+        run(rank=rank, world_size=world_size, args=args)
+        return
+
     world_size = args.world_size
     assert world_size >= 1
 
     if world_size > 1:
         mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
     else:
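The new branch in main() reflects a different launch model: with --use-multi-node an external launcher has already started one process per GPU on every node, so there is nothing left to spawn; each process only reads its global rank and the world size from the environment and calls run() directly. A condensed sketch of that dispatch logic, with hypothetical helper names rather than the commit's code:

import os

import torch.multiprocessing as mp


def dispatch(args, run_fn):
    if args.use_multi_node:
        # One process per GPU already exists; the launcher exported RANK/WORLD_SIZE.
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        run_fn(rank=rank, world_size=world_size, args=args)
        return

    # Single-node path: spawn the per-GPU processes ourselves.
    if args.world_size > 1:
        mp.spawn(run_fn, args=(args.world_size, args), nprocs=args.world_size, join=True)
    else:
        run_fn(rank=0, world_size=1, args=args)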
@@ -41,6 +41,7 @@ dl_dir=$PWD/download
 # data/lang_bpe_yyy if the array contains xxx, yyy
 vocab_sizes=(
   5000
+  500
 )
 
 # All files generated by this script are saved in "data".
@@ -190,5 +191,3 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
     ./local/compile_hlg.py --lang-dir $lang_dir
   done
 fi
-
-cd data && ln -sfv lang_bpe_5000 lang_bpe
@@ -21,14 +21,46 @@ import torch
 from torch import distributed as dist
 
 
-def setup_dist(rank, world_size, master_port=None):
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = (
-        "12354" if master_port is None else str(master_port)
-    )
-    dist.init_process_group("nccl", rank=rank, world_size=world_size)
-    torch.cuda.set_device(rank)
+def setup_dist(rank, world_size, master_port=None, is_multi_node=False):
+    """
+    rank and world_size are used only if is_multi_node is False.
+    """
+    if "MASTER_ADDR" not in os.environ:
+        os.environ["MASTER_ADDR"] = "localhost"
+
+    if "MASTER_PORT" not in os.environ:
+        os.environ["MASTER_PORT"] = (
+            "12354" if master_port is None else str(master_port)
+        )
+
+    if is_multi_node is False:
+        dist.init_process_group("nccl", rank=rank, world_size=world_size)
+        torch.cuda.set_device(rank)
+    else:
+        dist.init_process_group("nccl")
 
 
 def cleanup_dist():
     dist.destroy_process_group()
+
+
+def get_world_size():
+    if "WORLD_SIZE" in os.environ:
+        return int(os.environ["WORLD_SIZE"])
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_world_size()
+    else:
+        return 1
+
+
+def get_rank():
+    if "RANK" in os.environ:
+        return int(os.environ["RANK"])
+    elif dist.is_available() and dist.is_initialized():
+        return dist.rank()
+    else:
+        return 1
+
+
+def get_local_rank():
+    return int(os.environ.get("LOCAL_RANK", 0))
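Two hedged observations on the new helpers in icefall/dist.py (suggested fixes, not part of the commit): torch.distributed exposes dist.get_rank(), not dist.rank(), so the elif branch of get_rank() would raise AttributeError if it were ever reached, and the conventional fallback for a lone, non-distributed process is rank 0 rather than 1. A sketch of the helpers with those two points addressed:

import os

import torch.distributed as dist


def get_world_size() -> int:
    # Prefer the launcher-provided value, then the initialized process group,
    # then a single-process default of 1.
    if "WORLD_SIZE" in os.environ:
        return int(os.environ["WORLD_SIZE"])
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1


def get_rank() -> int:
    if "RANK" in os.environ:
        return int(os.environ["RANK"])
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()  # note: get_rank(), there is no dist.rank()
    return 0  # a single, non-distributed process is conventionally rank 0


def get_local_rank() -> int:
    # LOCAL_RANK is exported by torchrun / torch.distributed.launch --use_env.
    return int(os.environ.get("LOCAL_RANK", 0))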