support multinode multigpu

update
Yifan Yeung 2024-08-10 22:07:58 +08:00
parent 8e296b7047
commit f26dd3ba17
4 changed files with 172 additions and 60 deletions

View File

@@ -1,22 +0,0 @@
export PYTHONPATH=$(pwd)/../../..
./zipformer/pretrain.py \
--world-size 8 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp_pretrain \
--max-duration 650 \
--quadratic-duration 512 \
--accum-grad 1 \
--do-normalize 1 \
--mask-prob 0.8 \
--extractor-mode "layer_norm" \
--dropout-input 0.0 \
--dropout-features 0.0 \
--feature-grad-mult 1.0 \
--num-encoder-layers 2,2,3,4,3,2 \
--feedforward-dim 512,768,1024,1536,1024,768 \
--encoder-dim 192,256,448,768,448,192 \
--encoder-unmasked-dim 192,192,256,256,256,192 \
--base-lr 0.045

View File

@@ -0,0 +1,116 @@
#!/usr/bin/env bash
#
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Yifan Yang)
#
# This script is the entry point for starting multi-node multi-GPU
# model training.
#
# Read the usage instructions below for how to run this script.
set -e
# DDP related parameters
master_addr=
node_rank=
num_nodes=4
master_port=12354
. shared/parse_options.sh
function usage() {
  echo "Usage: "
  echo ""
  echo " $0 \\"
  echo " --master-addr <IP of master> \\"
  echo " --master-port <Port of master> \\"
  echo " --node-rank <rank of this node> \\"
  echo " --num-nodes <Number of nodes>"
  echo ""
  echo " --master-addr The IP address of the master node."
  echo " --master-port The port of the master node."
  echo " --node-rank Rank of this node."
  echo " --num-nodes Number of nodes in DDP training."
  echo ""
  echo "Usage example:"
  echo "Suppose you want to use DDP with two machines:"
  echo " (1) Machine 1 has 4 GPUs. You want to use"
  echo " GPU 0, 1, and 3 for training"
  echo " IP of machine 1 is: 10.177.41.71"
  echo " (2) Machine 2 has 4 GPUs. You want to use"
  echo " GPU 0, 2, and 3 for training"
  echo " IP of machine 2 is: 10.177.41.72"
  echo "You want to select machine 1 as the master node and"
  echo "assume that the port 1234 is free on machine 1."
  echo ""
  echo "On machine 1, you run:"
  echo ""
  echo " export CUDA_VISIBLE_DEVICES=\"0,1,3\""
  echo " ./run_multi_node_multi_gpu.sh --master-addr 10.177.41.71 --master-port 1234 --node-rank 0 --num-nodes 2"
  echo ""
  echo "On machine 2, you run:"
  echo ""
  echo " export CUDA_VISIBLE_DEVICES=\"0,2,3\""
  echo " ./run_multi_node_multi_gpu.sh --master-addr 10.177.41.71 --master-port 1234 --node-rank 1 --num-nodes 2"
  echo ""
  echo "Note 1:"
  echo " You use CUDA_VISIBLE_DEVICES to decide which GPUs are used for training."
  echo ""
  echo "Note 2:"
  echo " If you use torch < 1.9.0, then every node has to use the same number of GPUs for training."
  echo " If you use torch >= 1.9.0, different nodes can have a different number of GPUs for training."

  exit 1
}
default='\033[0m'
bold='\033[1m'
red='\033[31m'
function error() {
  printf "${bold}${red}[ERROR]${default} $1\n"
}
[ ! -z $CUDA_VISIBLE_DEVICES ] || ( echo; error "Please set CUDA_VISIBLE_DEVICES"; echo; usage )
[ ! -z $master_addr ] || ( echo; error "Please set --master-addr"; echo; usage )
[ ! -z $master_port ] || ( echo; error "Please set --master-port"; echo; usage )
[ ! -z $node_rank ] || ( echo; error "Please set --node-rank"; echo; usage )
[ ! -z $num_nodes ] || ( echo; error "Please set --num-nodes"; echo; usage )
# Number of GPUs this node has
num_gpus=$(python3 -c "s=\"$CUDA_VISIBLE_DEVICES\"; print(len(s.split(',')))")
echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
echo "num_gpus: $num_gpus"
echo "master_addr: $master_addr"
export MASTER_ADDR=$master_addr
export MASTER_PORT=$master_port
set -x
torchrun \
--nproc_per_node $num_gpus \
--nnodes $num_nodes \
--node_rank $node_rank \
--master_addr $master_addr \
--master_port $master_port \
zipformer/pretrain.py \
--use-multi-node 1 \
--master-port $master_port \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp_pretrain \
--max-duration 600 \
--quadratic-duration 1024 \
--accum-grad 1 \
--do-normalize 1 \
--mask-prob 0.8 \
--dropout-input 0.0 \
--dropout-features 0.0 \
--feature-grad-mult 1.0 \
--num-encoder-layers 2,2,3,4,3,2 \
--feedforward-dim 512,768,1024,1536,1024,768 \
--encoder-dim 192,256,448,768,448,192 \
--encoder-unmasked-dim 192,192,256,256,256,192 \
--base-lr 0.045
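
For context: with --use-multi-node 1, pretrain.py does not spawn worker processes itself; each process launched by torchrun initializes the process group from the environment (torchrun exports RANK, LOCAL_RANK, and WORLD_SIZE, and the launcher above exports MASTER_ADDR and MASTER_PORT). Below is a minimal Python sketch of that env:// initialization path, written for illustration only; the actual icefall.dist.setup_dist may be implemented differently.

import os

import torch.distributed as dist


def setup_dist_sketch(rank, world_size, master_port, use_multi_node):
    # Hypothetical sketch only; the real icefall.dist.setup_dist may differ.
    if use_multi_node:
        # torchrun (plus the launcher's exports) already put RANK, WORLD_SIZE,
        # MASTER_ADDR, and MASTER_PORT into the environment, so the env://
        # rendezvous needs no explicit arguments.
        dist.init_process_group(backend="nccl", init_method="env://")
    else:
        # Single-node path: processes are spawned locally by mp.spawn, so the
        # rank and world size are passed in explicitly.
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", str(master_port))
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)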

View File

@@ -20,23 +20,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-Usage:
-
-export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
-
-# For hubert model pretraining:
-./zipformer/pretrain.py \
-  --world-size 8 \
-  --num-epochs 400 \
-  --start-epoch 1 \
-  --use-fp16 1 \
-  --exp-dir hubert/exp \
-  --full-libri 1 \
-  --max-duration 87.5 \
-  --accum-grad 4
-"""
-
 
 import argparse
 import copy
@@ -46,7 +29,6 @@ from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union
 
-import lhotse
 import optim
 import torch
 import torch.multiprocessing as mp
@@ -69,7 +51,13 @@ from icefall.checkpoint import (
     save_checkpoint_with_global_batch_idx,
     update_averaged_model,
 )
-from icefall.dist import cleanup_dist, setup_dist
+from icefall.dist import (
+    cleanup_dist,
+    get_local_rank,
+    get_rank,
+    get_world_size,
+    setup_dist,
+)
 from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
@@ -405,6 +393,15 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
+    parser.add_argument(
+        "--use-multi-node",
+        type=str2bool,
+        default=False,
+        help="""True if using multi-node multi-GPU.
+        You are not supposed to set it directly.
+        """,
+    )
+
     parser.add_argument(
         "--world-size",
         type=int,
@@ -572,7 +569,7 @@ def get_parser():
     parser.add_argument(
         "--accum-grad",
         type=int,
-        default=4,
+        default=1,
         help="""update gradient when batch_idx_train % accum_grad == 0.
         """,
     )
@@ -1090,8 +1087,15 @@ def run(rank, world_size, args):
     params.update(vars(args))
 
     fix_random_seed(params.seed)
+
+    if params.use_multi_node:
+        local_rank = get_local_rank()
+    else:
+        local_rank = rank
+
+    logging.info(f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}")
     if world_size > 1:
-        setup_dist(rank, world_size, params.master_port)
+        setup_dist(rank, world_size, params.master_port, params.use_multi_node)
 
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
@@ -1103,8 +1107,8 @@ def run(rank, world_size, args):
     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", rank)
-    logging.info(f"Device: {device}")
+        device = torch.device("cuda", local_rank)
+    logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")
 
     logging.info(params)
 
     logging.info("About to create model")
@@ -1127,7 +1131,7 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+        model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
 
     optimizer = ScaledAdam(
         get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
@@ -1358,12 +1362,18 @@ def main():
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
 
-    world_size = args.world_size
-    assert world_size >= 1
-    if world_size > 1:
-        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    if args.use_multi_node:
+        rank = get_rank()
+        world_size = get_world_size()
+        args.world_size = world_size
+        run(rank=rank, world_size=world_size, args=args)
     else:
-        run(rank=0, world_size=1, args=args)
+        world_size = args.world_size
+        assert world_size >= 1
+        if world_size > 1:
+            mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+        else:
+            run(rank=0, world_size=1, args=args)
 
 
 torch.set_num_threads(1)
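
The multi-node branch in main() above relies on rank information that torchrun publishes through environment variables. A minimal sketch of what get_rank, get_local_rank, and get_world_size presumably read; the exact implementations in icefall.dist may differ.

import os


def get_rank() -> int:
    # Global rank of this process across all nodes; torchrun exports RANK.
    return int(os.environ.get("RANK", 0))


def get_local_rank() -> int:
    # Rank within the current node; used to pick the CUDA device.
    return int(os.environ.get("LOCAL_RANK", 0))


def get_world_size() -> int:
    # Total number of processes across all nodes.
    return int(os.environ.get("WORLD_SIZE", 1))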

View File

@@ -103,7 +103,7 @@ class LibriLightDataModule:
             help="We will draw this many cuts to estimate the duration"
             "bins for creating similar-duration buckets. Larger number"
             "means a better estimate to the data distribution, possibly"
-            "at a longer init cost."
+            "at a longer init cost.",
         )
         group.add_argument(
             "--quadratic-duration",
@@ -304,28 +304,36 @@ class LibriLightDataModule:
     def medium_cuts(self) -> CutSet:
         logging.info("About to get librilight medium cuts")
         filenames = glob.glob(
-            str(self.args.manifest_dir / "medium_split" / "librilight_cuts_medium.*.jsonl.gz")
+            str(
+                self.args.manifest_dir
+                / "medium_split"
+                / "librilight_cuts_medium.*.jsonl.gz"
+            )
         )
         pattern = re.compile(r"librilight_cuts_medium.([0-9]+).jsonl.gz")
         idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
         idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
         sorted_filenames = [f[1] for f in idx_filenames]
-        logging.info(f"Loading Libri-Light medium {len(sorted_filenames)} splits in lazy mode")
-        return lhotse.combine(
-            lhotse.load_manifest_lazy(p) for p in sorted_filenames
+        logging.info(
+            f"Loading Libri-Light medium {len(sorted_filenames)} splits in lazy mode"
         )
+        return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames)
 
     @lru_cache()
     def large_cuts(self) -> CutSet:
         logging.info("About to get librilight large cuts")
         filenames = glob.glob(
-            str(self.args.manifest_dir / "large_split" / "librilight_cuts_large.*.jsonl.gz")
+            str(
+                self.args.manifest_dir
+                / "large_split"
+                / "librilight_cuts_large.*.jsonl.gz"
+            )
        )
         pattern = re.compile(r"librilight_cuts_large.([0-9]+).jsonl.gz")
         idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
         idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
         sorted_filenames = [f[1] for f in idx_filenames]
-        logging.info(f"Loading Libri-Light large {len(sorted_filenames)} splits in lazy mode")
-        return lhotse.combine(
-            lhotse.load_manifest_lazy(p) for p in sorted_filenames
+        logging.info(
+            f"Loading Libri-Light large {len(sorted_filenames)} splits in lazy mode"
        )
+        return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames)
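
A short usage sketch for the two helpers above; the constructor call and variable names are hypothetical, assuming LibriLightDataModule is constructed from the parsed command-line args.

import lhotse

# Hypothetical usage; both helpers return lazily loaded CutSets, so combining
# the medium and large subsets does not materialize the manifests in memory.
librilight = LibriLightDataModule(args)
train_cuts = lhotse.combine(librilight.medium_cuts(), librilight.large_cuts())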