Merge a133571dc7d073290b927877b54ca4cad06b0ad9 into f2387fe523c6f89987f3723bfa967095e8de5127

commit da5dc4c504
Fangjun Kuang, 2021-10-14 20:20:43 +08:00, committed by GitHub
5 changed files with 224 additions and 20 deletions

File: conformer_ctc/decode.py

@@ -383,7 +383,7 @@ def decode_one_batch(
             ans[lm_scale_str] = hyps
     else:
         for lm_scale in lm_scale_list:
-            ans[lm_scale_str] = [[] * lattice.shape[0]]
+            ans[f"{lm_scale}"] = [[] * lattice.shape[0]]

     return ans
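
Aside (editor's note, not part of the diff): the old line reused lm_scale_str, the key variable left over from the branch above, so every iteration of this loop wrote to the same dictionary entry; keying by f"{lm_scale}" gives each scale its own entry. A minimal Python sketch of the fixed keying, with hypothetical stand-ins for lm_scale_list and lattice.shape[0]:

    # Editor's sketch; lm_scale_list and n are illustrative stand-ins.
    lm_scale_list = [0.1, 0.2, 0.3]
    n = 4  # stands in for lattice.shape[0]

    ans = {}
    for lm_scale in lm_scale_list:
        # One entry per scale, as in the fixed line. Note that [[] * n]
        # always equals [[]], since [] * n is []; n empty lists would be
        # written [[] for _ in range(n)].
        ans[f"{lm_scale}"] = [[] * n]

    print(sorted(ans))  # ['0.1', '0.2', '0.3']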

File: conformer_ctc/run-multi-node-multi-gpu.sh (new file)

@@ -0,0 +1,125 @@
#!/usr/bin/env bash
#
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# This script is the entry point to start model training
# with multi-node multi-GPU.
#
# Read the usage instructions below for how to run this script.
set -e
cur_dir=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# DDP related parameters
master_addr=
node_rank=
num_nodes=
master_port=1234
# Training script parameters
# You can add more if you like
#
# Use ./conformer_ctc/train.py --help to see more
#
# If you add more parameters here, remember to append them to the
# end of this file.
#
max_duration=200
bucketing_sampler=1
full_libri=1
start_epoch=0
num_epochs=2
exp_dir=conformer_ctc/exp3
lang_dir=data/lang_bpe_500
. $cur_dir/../shared/parse_options.sh
function usage() {
echo "Usage: "
echo ""
echo " $0 \\"
echo " --master-addr <IP of master> \\"
echo " --master-port <Port of master> \\"
echo " --node-rank <rank of this node> \\"
echo " --num-nodes <Number of node>"
echo ""
echo " --master-addr The ip address of the master node."
echo " --master-port The port of the master node."
echo " --node-rank Rank of this node."
echo " --num-nodes Number of nodes in DDP training."
echo ""
echo "Usage example:"
echo "Suppose you want to use DDP with two machines:"
echo " (1) Machine 1 has 4 GPUs. You want to use"
echo " GPU 0, 1, and 3 for training"
echo " IP of machine 1 is: 10.177.41.71"
echo " (2) Machine 2 has 4 GPUs. You want to use"
echo " GPU 0, 2, and 3 for training"
echo " IP of machine 2 is: 10.177.41.72"
echo "You want to select machine 1 as the master node and"
echo "assume that the port 1234 is free on machine 1."
echo ""
echo "On machine 1, you run:"
echo ""
echo " export CUDA_VISIBLE_DEVICES=\"0,1,3\""
echo " ./conformer_ctc/run-multi-node-multi-gpu.sh --master-addr 10.177.41.71 --master-port 1234 --node-rank 0 --num-nodes 2"
echo ""
echo "On machine 2, you run:"
echo ""
echo " export CUDA_VISIBLE_DEVICES=\"0,2,3\""
echo " ./conformer_ctc/run-multi-node-multi-gpu.sh --master-addr 10.177.41.71 --master-port 1234 --node-rank 1 --num-nodes 2"
echo ""
echo "Note 1:"
echo " You use CUDA_VISIBLE_DEVICES to decide which GPUs are used for training."
echo ""
echo "Note 2:"
echo " If you use torch < 1.9.0, then every node has to use the same number of GPUs for training."
echo " If you use torch >= 1.9.0, different nodes can have a different number of GPUs for training."
exit 1
}
default='\033[0m'
bold='\033[1m'
red='\033[31m'
function error() {
printf "${bold}${red}[ERROR]${default} $1\n"
}
[ -n "$CUDA_VISIBLE_DEVICES" ] || ( echo; error "Please set CUDA_VISIBLE_DEVICES"; echo; usage )
[ -n "$master_addr" ] || ( echo; error "Please set --master-addr"; echo; usage )
[ -n "$master_port" ] || ( echo; error "Please set --master-port"; echo; usage )
[ -n "$node_rank" ] || ( echo; error "Please set --node-rank"; echo; usage )
[ -n "$num_nodes" ] || ( echo; error "Please set --num-nodes"; echo; usage )
# Number of GPUs this node has
num_gpus=$(python3 -c "s=\"$CUDA_VISIBLE_DEVICES\"; print(len(s.split(',')))")
echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
echo "num_gpus: $num_gpus"
echo "master_addr: $master_addr"
export MASTER_ADDR=$master_addr
export MASTER_PORT=$master_port
set -x
python -m torch.distributed.launch \
--use_env \
--nproc_per_node $num_gpus \
--nnodes $num_nodes \
--node_rank $node_rank \
--master_addr $master_addr \
--master_port $master_port \
\
$cur_dir/train.py \
--use-multi-node true \
--master-port $master_port \
--max-duration $max_duration \
--bucketing-sampler $bucketing_sampler \
--full-libri $full_libri \
--start-epoch $start_epoch \
--num-epochs $num_epochs \
--exp-dir $exp_dir \
--lang-dir $lang_dir
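
Background (editor's note, not part of the commit): when invoked with --use_env, torch.distributed.launch exports MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK, and LOCAL_RANK into each worker's environment instead of passing a --local_rank argument; that is what lets train.py recover its coordinates through the new helpers in icefall/dist.py. A minimal sketch of a worker consuming those variables (worker.py is a hypothetical file name, not part of this commit):

    # worker.py -- hypothetical example; launch on each node with e.g.:
    #   python -m torch.distributed.launch --use_env --nproc_per_node 2 worker.py
    import os

    import torch
    import torch.distributed as dist


    def main():
        # All three values are exported by the launcher under --use_env.
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ["LOCAL_RANK"])

        # MASTER_ADDR/MASTER_PORT are also in the environment, so no
        # explicit rank/world_size needs to be passed here.
        dist.init_process_group("nccl")
        torch.cuda.set_device(local_rank)
        print(f"global rank {rank}/{world_size}, local rank {local_rank}")
        dist.destroy_process_group()


    if __name__ == "__main__":
        main()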

File: conformer_ctc/train.py

@@ -42,7 +42,13 @@ from transformer import Noam
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.dist import cleanup_dist, setup_dist
+from icefall.dist import (
+    cleanup_dist,
+    get_local_rank,
+    get_rank,
+    get_world_size,
+    setup_dist,
+)
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -58,6 +64,17 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )

+    parser.add_argument(
+        "--use-multi-node",
+        type=str2bool,
+        default=False,
+        help="""True if using multi-node multi-GPU.
+        You are not supposed to set it directly.
+        See ./conformer_ctc/run-multi-node-multi-gpu.sh
+        for details.
+        """,
+    )
+
     parser.add_argument(
         "--world-size",
         type=int,
@@ -96,6 +113,23 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved.
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe",
+        help="""It contains language related input files such as
+        lexicon.txt.
+        """,
+    )
+
     return parser
@@ -110,12 +144,6 @@ def get_params() -> AttributeDict:
     Explanation of options saved in `params`:

-    - exp_dir: It specifies the directory where all training related
-               files, e.g., checkpoints, log, etc, are saved
-
-    - lang_dir: It contains language related input files such as
-                "lexicon.txt"
-
     - best_train_loss: Best training loss so far. It is used to select
                        the model that has the lowest training loss. It is
                        updated during the training.
@@ -533,9 +561,17 @@ def run(rank, world_size, args):
     params = get_params()
     params.update(vars(args))

+    if args.use_multi_node:
+        local_rank = get_local_rank()
+    else:
+        local_rank = rank
+
+    logging.info(
+        f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}"
+    )
+
     fix_random_seed(42)
     if world_size > 1:
-        setup_dist(rank, world_size, params.master_port)
+        setup_dist(rank, world_size, params.master_port, args.use_multi_node)

     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
@@ -552,7 +588,8 @@ def run(rank, world_size, args):
     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", rank)
+        device = torch.device("cuda", local_rank)
+    logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")

     graph_compiler = BpeCtcTrainingGraphCompiler(
         params.lang_dir,
@@ -577,7 +614,7 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
-        model = DDP(model, device_ids=[rank])
+        model = DDP(model, device_ids=[local_rank])

     optimizer = Noam(
         model.parameters(),
@@ -638,9 +675,21 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+
+    if args.use_multi_node:
+        # for multi-node multi-GPU training
+        rank = get_rank()
+        world_size = get_world_size()
+        args.world_size = world_size
+        print(f"rank: {rank}, world_size: {world_size}")
+        run(rank=rank, world_size=world_size, args=args)
+        return

     world_size = args.world_size
     assert world_size >= 1
     if world_size > 1:
         mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
     else:
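
Worth spelling out (editor's note): the two branches of main() differ in who assigns ranks. mp.spawn(run, args=(world_size, args), nprocs=world_size) calls run(i, world_size, args) with i ranging over 0..world_size-1, so the rank is the implicit first argument; in the multi-node branch the launcher has already assigned ranks through the environment, and get_rank()/get_world_size() simply read them back. A stripped-down sketch of the two paths, with run() stubbed out:

    # Editor's sketch; names mirror the diff but bodies are stubs.
    import os

    import torch.multiprocessing as mp


    def run(rank, world_size, args=None):
        print(f"run(rank={rank}, world_size={world_size})")


    def main(use_multi_node=False, world_size=2):
        if use_multi_node:
            # Launcher-driven: RANK/WORLD_SIZE were exported by
            # torch.distributed.launch before this process started.
            run(int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]))
            return
        # Single-node: mp.spawn passes each child its index (its rank)
        # as the first positional argument.
        mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)


    if __name__ == "__main__":
        main()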

File: prepare.sh

@@ -191,5 +191,3 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
     ./local/compile_hlg.py --lang-dir $lang_dir
   done
 fi
-
-cd data && ln -sfv lang_bpe_5000 lang_bpe

File: icefall/dist.py

@@ -21,14 +21,46 @@ import torch
 from torch import distributed as dist


-def setup_dist(rank, world_size, master_port=None):
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = (
-        "12354" if master_port is None else str(master_port)
-    )
-    dist.init_process_group("nccl", rank=rank, world_size=world_size)
-    torch.cuda.set_device(rank)
+def setup_dist(rank, world_size, master_port=None, is_multi_node=False):
+    """
+    rank and world_size are used only if is_multi_node is False.
+    """
+    if "MASTER_ADDR" not in os.environ:
+        os.environ["MASTER_ADDR"] = "localhost"
+
+    if "MASTER_PORT" not in os.environ:
+        os.environ["MASTER_PORT"] = (
+            "12354" if master_port is None else str(master_port)
+        )
+
+    if is_multi_node is False:
+        dist.init_process_group("nccl", rank=rank, world_size=world_size)
+        torch.cuda.set_device(rank)
+    else:
+        dist.init_process_group("nccl")


 def cleanup_dist():
     dist.destroy_process_group()
+
+
+def get_world_size():
+    if "WORLD_SIZE" in os.environ:
+        return int(os.environ["WORLD_SIZE"])
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_world_size()
+    else:
+        return 1
+
+
+def get_rank():
+    if "RANK" in os.environ:
+        return int(os.environ["RANK"])
+    elif dist.is_available() and dist.is_initialized():
+        return dist.get_rank()
+    else:
+        # A single, non-distributed process has rank 0.
+        return 0
+
+
+def get_local_rank():
+    return int(os.environ.get("LOCAL_RANK", 0))
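
Finally, a quick check (editor's addition) of how the new helpers degrade outside distributed training: with none of RANK, WORLD_SIZE, or LOCAL_RANK set and no process group initialized, they fall back to single-process defaults, so the same train.py code path also works for a single GPU:

    # Hypothetical standalone check, not part of the commit.
    from icefall.dist import get_local_rank, get_rank, get_world_size

    # In a plain Python process (no launcher, no init_process_group):
    print(get_rank())        # 0
    print(get_world_size())  # 1
    print(get_local_rank())  # 0

    # Under python -m torch.distributed.launch --use_env, the same calls
    # return the launcher-assigned values read from the environment.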