Merge a133571dc7d073290b927877b54ca4cad06b0ad9 into f2387fe523c6f89987f3723bfa967095e8de5127

Fangjun Kuang 2021-10-14 20:20:43 +08:00 committed by GitHub
commit da5dc4c504
5 changed files with 224 additions and 20 deletions

View File

@@ -383,7 +383,7 @@ def decode_one_batch(
ans[lm_scale_str] = hyps
else:
for lm_scale in lm_scale_list:
ans[lm_scale_str] = [[] * lattice.shape[0]]
ans[f"{lm_scale}"] = [[] * lattice.shape[0]]
return ans
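For context, a minimal standalone sketch of why the key change matters (hypothetical values: `lm_scale_list` and the utterance count stand in for what decode.py actually receives, and the per-utterance lists are written as a comprehension so each utterance gets its own list object):

# Hypothetical stand-ins, not the actual decode.py inputs.
lm_scale_list = [0.1, 0.2, 0.3]
num_utts = 4  # plays the role of lattice.shape[0]

ans = {}
for lm_scale in lm_scale_list:
    # Keying by f"{lm_scale}" gives every scale its own entry; the old
    # `lm_scale_str` is presumably stale or undefined in this branch.
    ans[f"{lm_scale}"] = [[] for _ in range(num_utts)]

print(sorted(ans))  # ['0.1', '0.2', '0.3'] -> one entry of empty hyps per scale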

View File

@@ -0,0 +1,125 @@
#!/usr/bin/env bash
#
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# This script is the entry point to start model training
# with multi-node multi-GPU.
#
# Read the usage instructions below for how to run this script.
set -e
cur_dir=$(cd $(dirname $BASH_SOURCE) && pwd)
# DDP related parameters
master_addr=
node_rank=
num_nodes=
master_port=1234
# Training script parameters
# You can add more if you like
#
# Use ./conformer_ctc/train.py --help to see more
#
# If you add more parameters here, remember to append them to the
# end of this file.
#
max_duration=200
bucketing_sampler=1
full_libri=1
start_epoch=0
num_epochs=2
exp_dir=conformer_ctc/exp3
lang_dir=data/lang_bpe_500
. $cur_dir/../shared/parse_options.sh
function usage() {
echo "Usage: "
echo ""
echo " $0 \\"
echo " --master-addr <IP of master> \\"
echo " --master-port <Port of master> \\"
echo " --node-rank <rank of this node> \\"
echo " --num-nodes <Number of node>"
echo ""
echo " --master-addr The ip address of the master node."
echo " --master-port The port of the master node."
echo " --node-rank Rank of this node."
echo " --num-nodes Number of nodes in DDP training."
echo ""
echo "Usage example:"
echo "Suppose you want to use DDP with two machines:"
echo " (1) Machine 1 has 4 GPUs. You want to use"
echo " GPU 0, 1, and 3 for training"
echo " IP of machine 1 is: 10.177.41.71"
echo " (2) Machine 2 has 4 GPUs. You want to use"
echo " GPU 0, 2, and 3 for training"
echo " IP of machine 2 is: 10.177.41.72"
echo "You want to select machine 1 as the master node and"
echo "assume that the port 1234 is free on machine 1."
echo ""
echo "On machine 1, you run:"
echo ""
echo " export CUDA_VISIBLE_DEVICES=\"0,1,3\""
echo " ./conformer_ctc/run-multi-node-multi-gpu.sh --master-addr 10.177.41.71 --master-port 1234 --node-rank 0 --num-nodes 2"
echo ""
echo "On machine 2, you run:"
echo ""
echo " export CUDA_VISIBLE_DEVICES=\"0,2,3\""
echo " ./conformer_ctc/run-multi-node-multi-gpu.sh --master-addr 10.177.41.71 --master-port 1234 --node-rank 1 --num-nodes 2"
echo ""
echo "Note 1:"
echo " You use CUDA_VISIBLE_DEVICES to decide which GPUs are used for training."
echo ""
echo "Note 2:"
echo " If you use torch < 1.9.0, then every node has to use the same number of GPUs for training."
echo " If you use torch >= 1.9.0, different nodes can have a different number of GPUs for training."
exit 1
}
default='\033[0m'
bold='\033[1m'
red='\033[31m'
function error() {
printf "${bold}${red}[ERROR]${default} $1\n"
}
[ ! -z $CUDA_VISIBLE_DEVICES ] || ( echo; error "Please set CUDA_VISIBLE_DEVICES"; echo; usage )
[ ! -z $master_addr ] || ( echo; error "Please set --master-addr"; echo; usage )
[ ! -z $master_port ] || ( echo; error "Please set --master-port"; echo; usage )
[ ! -z $node_rank ] || ( echo; error "Please set --node-rank"; echo; usage )
[ ! -z $num_nodes ] || ( echo; error "Please set --num-nodes"; echo; usage )
# Number of GPUs this node has
num_gpus=$(python3 -c "s=\"$CUDA_VISIBLE_DEVICES\"; print(len(s.split(',')))")
echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
echo "num_gpus: $num_gpus"
echo "master_addr: $master_addr"
export MASTER_ADDR=$master_addr
export MASTER_PORT=$master_port
set -x
python -m torch.distributed.launch \
--use_env \
--nproc_per_node $num_gpus \
--nnodes $num_nodes \
--node_rank $node_rank \
--master_addr $master_addr \
--master_port $master_port \
\
$cur_dir/train.py \
--use-multi-node true \
--master-port $master_port \
--max-duration $max_duration \
--bucketing-sampler $bucketing_sampler \
--full-libri $full_libri \
--start-epoch $start_epoch \
--num-epochs $num_epochs \
--exp-dir $exp_dir \
--lang-dir $lang_dir
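To make the launcher arithmetic concrete, here is a small illustrative sketch (plain Python, reusing the two-machine example from the usage text; the rank formula is the standard behavior of torch.distributed.launch when every node runs the same number of processes, as Note 2 requires for torch < 1.9.0):

# Each node exports CUDA_VISIBLE_DEVICES with 3 GPUs, so the launcher
# starts 3 processes per node (--nproc_per_node $num_gpus).
nodes = {0: "0,1,3", 1: "0,2,3"}  # node_rank -> CUDA_VISIBLE_DEVICES

for node_rank, devices in nodes.items():
    num_gpus = len(devices.split(","))  # same computation as the script
    for local_rank in range(num_gpus):
        # With --use_env the launcher exports RANK, LOCAL_RANK, and
        # WORLD_SIZE to each process instead of passing --local_rank.
        rank = node_rank * num_gpus + local_rank
        print(f"node {node_rank}: LOCAL_RANK={local_rank} -> RANK={rank}")

print("WORLD_SIZE =", sum(len(d.split(",")) for d in nodes.values()))  # 6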

View File

@@ -42,7 +42,13 @@ from transformer import Noam
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.dist import (
cleanup_dist,
get_local_rank,
get_rank,
get_world_size,
setup_dist,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@@ -58,6 +64,17 @@ def get_parser():
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--use-multi-node",
type=str2bool,
default=False,
help="""True if using multi-node multi-GPU.
You are not supposed to set it directly.
See ./conformer_ctc/run-multi-node-multi-gpu.sh
for details.
""",
)
parser.add_argument(
"--world-size",
type=int,
@@ -96,6 +113,23 @@ def get_parser():
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved.
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe",
help="""It contains language related input files such as lexicon.txt
""",
)
return parser
@@ -110,12 +144,6 @@ def get_params() -> AttributeDict:
Explanation of options saved in `params`:
- exp_dir: It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
- lang_dir: It contains language related input files such as
"lexicon.txt"
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
@@ -533,9 +561,17 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
if args.use_multi_node:
local_rank = get_local_rank()
else:
local_rank = rank
logging.info(
f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}"
)
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_dist(rank, world_size, params.master_port, args.use_multi_node)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
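A short illustrative sketch of the distinction this hunk introduces (not the verbatim train.py code): under single-node mp.spawn the spawn index serves as both the global rank and the GPU index, while under the multi-node launcher the global rank spans all nodes and LOCAL_RANK picks the GPU on the current node.

import os

def resolve_local_rank(rank: int, use_multi_node: bool) -> int:
    # Mirrors the logic above: multi-node reads the launcher-provided
    # LOCAL_RANK, single-node reuses the mp.spawn index.
    if use_multi_node:
        return int(os.environ.get("LOCAL_RANK", 0))
    return rank

# e.g. on node 1 of the two-machine example, the process with RANK=4
# sees LOCAL_RANK=1 and must use cuda:1, not cuda:4.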
@@ -552,7 +588,8 @@ def run(rank, world_size, args):
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
device = torch.device("cuda", local_rank)
logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
@@ -577,7 +614,7 @@ def run(rank, world_size, args):
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
model = DDP(model, device_ids=[local_rank])
optimizer = Noam(
model.parameters(),
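The two hunks above are the consumers of that local rank; a condensed sketch of the device-binding pattern they implement (illustrative only, assuming CUDA is available and the process group is already initialized):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_ddp(model: torch.nn.Module, local_rank: int) -> torch.nn.Module:
    device = torch.device("cuda", local_rank)  # local index, not global rank
    model = model.to(device)
    # device_ids must also be the local index; with 3 GPUs per node a
    # global rank of 4 is not even a valid CUDA device ordinal.
    return DDP(model, device_ids=[local_rank])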
@@ -638,9 +675,21 @@ def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
if args.use_multi_node:
# for multi-node multi-GPU training
rank = get_rank()
world_size = get_world_size()
args.world_size = world_size
print(f"rank: {rank}, world_size: {world_size}")
run(rank=rank, world_size=world_size, args=args)
return
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
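Pulling the main() hunk together: a condensed paraphrase of the new dispatch logic as it sits inside train.py (where run() is defined earlier in the same file); the branch cut off above is assumed to fall through to a direct single-process run() call, and that assumption is marked in the code.

import torch.multiprocessing as mp
from icefall.dist import get_rank, get_world_size

def main_sketch(args):  # paraphrase of main(), not the verbatim diff
    if args.use_multi_node:
        # Launched once per process by torch.distributed.launch, so the
        # rank and world size come from the environment; no mp.spawn.
        rank = get_rank()
        world_size = get_world_size()
        args.world_size = world_size
        run(rank=rank, world_size=world_size, args=args)
        return
    # Single-node path: this process forks the workers itself.
    world_size = args.world_size
    assert world_size >= 1
    if world_size > 1:
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)  # assumed single-process fallback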

View File

@@ -191,5 +191,3 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
./local/compile_hlg.py --lang-dir $lang_dir
done
fi
cd data && ln -sfv lang_bpe_5000 lang_bpe

View File

@@ -21,14 +21,46 @@ import torch
from torch import distributed as dist
def setup_dist(rank, world_size, master_port=None):
def setup_dist(rank, world_size, master_port=None, is_multi_node=False):
"""
rank and world_size are used only if is_multi_node is False.
"""
if "MASTER_ADDR" not in os.environ:
os.environ["MASTER_ADDR"] = "localhost"
if "MASTER_PORT" not in os.environ:
os.environ["MASTER_PORT"] = (
"12354" if master_port is None else str(master_port)
)
if is_multi_node is False:
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
else:
dist.init_process_group("nccl")
def cleanup_dist():
dist.destroy_process_group()
def get_world_size():
if "WORLD_SIZE" in os.environ:
return int(os.environ["WORLD_SIZE"])
if dist.is_available() and dist.is_initialized():
return dist.get_world_size()
else:
return 1
def get_rank():
if "RANK" in os.environ:
return int(os.environ["RANK"])
elif dist.is_available() and dist.is_initialized():
return dist.get_rank()
else:
return 0
def get_local_rank():
return int(os.environ.get("LOCAL_RANK", 0))
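A quick usage illustration of the three helpers (the environment values below are made up to mimic what torch.distributed.launch --use_env exports for one worker; the resolution order is environment variable first, then the initialized process group, then a single-process default):

import os

# Pretend to be one worker on node 1 of the earlier two-machine example.
os.environ["WORLD_SIZE"] = "6"
os.environ["RANK"] = "4"
os.environ["LOCAL_RANK"] = "1"

from icefall.dist import get_local_rank, get_rank, get_world_size

print(get_world_size())  # 6 -> taken from WORLD_SIZE
print(get_rank())        # 4 -> taken from RANK
print(get_local_rank())  # 1 -> taken from LOCAL_RANK (0 if unset)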