support multinode multigpu

update
This commit is contained in:
Yifan Yeung 2024-08-10 22:07:58 +08:00
parent 8e296b7047
commit f26dd3ba17
4 changed files with 172 additions and 60 deletions

View File

@ -1,22 +0,0 @@
export PYTHONPATH=$(pwd)/../../..
./zipformer/pretrain.py \
--world-size 8 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp_pretrain \
--max-duration 650 \
--quadratic-duration 512 \
--accum-grad 1 \
--do-normalize 1 \
--mask-prob 0.8 \
--extractor-mode "layer_norm" \
--dropout-input 0.0 \
--dropout-features 0.0 \
--feature-grad-mult 1.0 \
--num-encoder-layers 2,2,3,4,3,2 \
--feedforward-dim 512,768,1024,1536,1024,768 \
--encoder-dim 192,256,448,768,448,192 \
--encoder-unmasked-dim 192,192,256,256,256,192 \
--base-lr 0.045

View File

@ -0,0 +1,116 @@
#!/usr/bin/env bash
#
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Yifan Yang)
#
# This script is the entry point to start model training
# with multi-node multi-GPU.
#
# Read the usage instructions below for how to run this script.
set -e
# DDP related parameters
master_addr=
node_rank=
num_nodes=4
master_port=12354
. shared/parse_options.sh
function usage() {
echo "Usage: "
echo ""
echo " $0 \\"
echo " --master-addr <IP of master> \\"
echo " --master-port <Port of master> \\"
echo " --node-rank <rank of this node> \\"
echo " --num-nodes <Number of node>"
echo ""
echo " --master-addr The ip address of the master node."
echo " --master-port The port of the master node."
echo " --node-rank Rank of this node."
echo " --num-nodes Number of nodes in DDP training."
echo ""
echo "Usage example:"
echo "Suppose you want to use DDP with two machines:"
echo " (1) Machine 1 has 4 GPUs. You want to use"
echo " GPU 0, 1, and 3 for training"
echo " IP of machine 1 is: 10.177.41.71"
echo " (2) Machine 2 has 4 GPUs. You want to use"
echo " GPU 0, 2, and 3 for training"
echo " IP of machine 2 is: 10.177.41.72"
echo "You want to select machine 1 as the master node and"
echo "assume that the port 1234 is free on machine 1."
echo ""
echo "On machine 1, you run:"
echo ""
echo " export CUDA_VISIBLE_DEVICES=\"0,1,3\""
echo " ./run_multi_node_multi_gpu.sh --master-addr 10.177.41.71 --master-port 1234 --node-rank 0 --num-nodes 2"
echo ""
echo "On machine 2, you run:"
echo ""
echo " export CUDA_VISIBLE_DEVICES=\"0,2,3\""
echo " ./run_multi_node_multi_gpu.sh --master-addr 10.177.41.71 --master-port 1234 --node-rank 1 --num-nodes 2"
echo ""
echo "Note 1:"
echo " You use CUDA_VISIBLE_DEVICES to decide which GPUs are used for training."
echo ""
echo "Note 2:"
echo " If you use torch < 1.9.0, then every node has to use the same number of GPUs for training."
echo " If you use torch >= 1.9.0, different nodes can have a different number of GPUs for training."
exit 1
}
default='\033[0m'
bold='\033[1m'
red='\033[31m'
function error() {
printf "${bold}${red}[ERROR]${default} $1\n"
}
[ ! -z $CUDA_VISIBLE_DEVICES ] || ( echo; error "Please set CUDA_VISIBLE_DEVICES"; echo; usage )
[ ! -z $master_addr ] || ( echo; error "Please set --master-addr"; echo; usage )
[ ! -z $master_port ] || ( echo; error "Please set --master-port"; echo; usage )
[ ! -z $node_rank ] || ( echo; error "Please set --node-rank"; echo; usage )
[ ! -z $num_nodes ] || ( echo; error "Please set --num-nodes"; echo; usage )
# Number of GPUs this node has
num_gpus=$(python3 -c "s=\"$CUDA_VISIBLE_DEVICES\"; print(len(s.split(',')))")
echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
echo "num_gpus: $num_gpus"
echo "master_addr: $master_addr"
export MASTER_ADDR=$master_addr
export MASTER_PORT=$master_port
set -x
torchrun \
--nproc_per_node $num_gpus \
--nnodes $num_nodes \
--node_rank $node_rank \
--master_addr $master_addr \
--master_port $master_port \
zipformer/pretrain.py \
--use-multi-node 1 \
--master-port $master_port \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp_pretrain \
--max-duration 600 \
--quadratic-duration 1024 \
--accum-grad 1 \
--do-normalize 1 \
--mask-prob 0.8 \
--dropout-input 0.0 \
--dropout-features 0.0 \
--feature-grad-mult 1.0 \
--num-encoder-layers 2,2,3,4,3,2 \
--feedforward-dim 512,768,1024,1536,1024,768 \
--encoder-dim 192,256,448,768,448,192 \
--encoder-unmasked-dim 192,192,256,256,256,192 \
--base-lr 0.045

View File

@ -20,23 +20,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
# For hubert model pretraining:
./zipformer/pretrain.py \
--world-size 8 \
--num-epochs 400 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir hubert/exp \
--full-libri 1 \
--max-duration 87.5 \
--accum-grad 4
"""
import argparse
import copy
@ -46,7 +29,6 @@ from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import lhotse
import optim
import torch
import torch.multiprocessing as mp
@ -69,7 +51,13 @@ from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx,
update_averaged_model,
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.dist import (
cleanup_dist,
get_local_rank,
get_rank,
get_world_size,
setup_dist,
)
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
@ -405,6 +393,15 @@ def get_parser():
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--use-multi-node",
type=str2bool,
default=False,
help="""True if using multi-node multi-GPU.
You are not supposed to set it directly.
""",
)
parser.add_argument(
"--world-size",
type=int,
@ -572,7 +569,7 @@ def get_parser():
parser.add_argument(
"--accum-grad",
type=int,
default=4,
default=1,
help="""update gradient when batch_idx_train % accum_grad == 0.
""",
)
@ -1090,8 +1087,15 @@ def run(rank, world_size, args):
params.update(vars(args))
fix_random_seed(params.seed)
if params.use_multi_node:
local_rank = get_local_rank()
else:
local_rank = rank
logging.info(f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}")
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_dist(rank, world_size, params.master_port, params.use_multi_node)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
@ -1103,8 +1107,8 @@ def run(rank, world_size, args):
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
device = torch.device("cuda", local_rank)
logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")
logging.info(params)
logging.info("About to create model")
@ -1127,7 +1131,7 @@ def run(rank, world_size, args):
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
optimizer = ScaledAdam(
get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
@ -1358,12 +1362,18 @@ def main():
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
if args.use_multi_node:
rank = get_rank()
world_size = get_world_size()
args.world_size = world_size
run(rank=rank, world_size=world_size, args=args)
else:
run(rank=0, world_size=1, args=args)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)

View File

@ -103,7 +103,7 @@ class LibriLightDataModule:
help="We will draw this many cuts to estimate the duration"
"bins for creating similar-duration buckets. Larger number"
"means a better estimate to the data distribution, possibly"
"at a longer init cost."
"at a longer init cost.",
)
group.add_argument(
"--quadratic-duration",
@ -304,28 +304,36 @@ class LibriLightDataModule:
def medium_cuts(self) -> CutSet:
logging.info("About to get librilight medium cuts")
filenames = glob.glob(
str(self.args.manifest_dir / "medium_split" / "librilight_cuts_medium.*.jsonl.gz")
str(
self.args.manifest_dir
/ "medium_split"
/ "librilight_cuts_medium.*.jsonl.gz"
)
)
pattern = re.compile(r"librilight_cuts_medium.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
logging.info(f"Loading Libri-Light medium {len(sorted_filenames)} splits in lazy mode")
return lhotse.combine(
lhotse.load_manifest_lazy(p) for p in sorted_filenames
logging.info(
f"Loading Libri-Light medium {len(sorted_filenames)} splits in lazy mode"
)
return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames)
@lru_cache()
def large_cuts(self) -> CutSet:
logging.info("About to get librilight large cuts")
filenames = glob.glob(
str(self.args.manifest_dir / "large_split" / "librilight_cuts_large.*.jsonl.gz")
str(
self.args.manifest_dir
/ "large_split"
/ "librilight_cuts_large.*.jsonl.gz"
)
)
pattern = re.compile(r"librilight_cuts_large.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
logging.info(f"Loading Libri-Light large {len(sorted_filenames)} splits in lazy mode")
return lhotse.combine(
lhotse.load_manifest_lazy(p) for p in sorted_filenames
logging.info(
f"Loading Libri-Light large {len(sorted_filenames)} splits in lazy mode"
)
return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames)