mirror of https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00

support multinode multigpu
update

This commit is contained in:
  parent 8e296b7047
  commit f26dd3ba17
@@ -1,22 +0,0 @@
-export PYTHONPATH=$(pwd)/../../..
-
-./zipformer/pretrain.py \
-  --world-size 8 \
-  --num-epochs 30 \
-  --start-epoch 1 \
-  --use-fp16 1 \
-  --exp-dir zipformer/exp_pretrain \
-  --max-duration 650 \
-  --quadratic-duration 512 \
-  --accum-grad 1 \
-  --do-normalize 1 \
-  --mask-prob 0.8 \
-  --extractor-mode "layer_norm" \
-  --dropout-input 0.0 \
-  --dropout-features 0.0 \
-  --feature-grad-mult 1.0 \
-  --num-encoder-layers 2,2,3,4,3,2 \
-  --feedforward-dim 512,768,1024,1536,1024,768 \
-  --encoder-dim 192,256,448,768,448,192 \
-  --encoder-unmasked-dim 192,192,256,256,256,192 \
-  --base-lr 0.045
egs/librilight/SSL/run_multi_node_multi_gpu.sh (new executable file, 116 lines)
@@ -0,0 +1,116 @@
+#!/usr/bin/env bash
+#
+# Copyright    2024  Xiaomi Corp.  (authors: Fangjun Kuang,
+#                                            Yifan Yang)
+#
+# This script is the entry point to start model training
+# with multi-node multi-GPU.
+#
+# Read the usage instructions below for how to run this script.
+
+set -e
+
+# DDP related parameters
+master_addr=
+node_rank=
+num_nodes=4
+master_port=12354
+
+. shared/parse_options.sh
+
+function usage() {
+  echo "Usage: "
+  echo ""
+  echo "  $0 \\"
+  echo "    --master-addr <IP of master> \\"
+  echo "    --master-port <Port of master> \\"
+  echo "    --node-rank <rank of this node> \\"
+  echo "    --num-nodes <Number of node>"
+  echo ""
+  echo " --master-addr  The ip address of the master node."
+  echo " --master-port  The port of the master node."
+  echo " --node-rank    Rank of this node."
+  echo " --num-nodes    Number of nodes in DDP training."
+  echo ""
+  echo "Usage example:"
+  echo "Suppose you want to use DDP with two machines:"
+  echo "  (1) Machine 1 has 4 GPUs. You want to use"
+  echo "      GPU 0, 1, and 3 for training"
+  echo "      IP of machine 1 is: 10.177.41.71"
+  echo "  (2) Machine 2 has 4 GPUs. You want to use"
+  echo "      GPU 0, 2, and 3 for training"
+  echo "      IP of machine 2 is: 10.177.41.72"
+  echo "You want to select machine 1 as the master node and"
+  echo "assume that the port 1234 is free on machine 1."
+  echo ""
+  echo "On machine 1, you run:"
+  echo ""
+  echo "  export CUDA_VISIBLE_DEVICES=\"0,1,3\""
+  echo "  ./run_multi_node_multi_gpu.sh --master-addr 10.177.41.71 --master-port 1234 --node-rank 0 --num-nodes 2"
+  echo ""
+  echo "On machine 2, you run:"
+  echo ""
+  echo "  export CUDA_VISIBLE_DEVICES=\"0,2,3\""
+  echo "  ./run_multi_node_multi_gpu.sh --master-addr 10.177.41.71 --master-port 1234 --node-rank 1 --num-nodes 2"
+  echo ""
+  echo "Note 1:"
+  echo "  You use CUDA_VISIBLE_DEVICES to decide which GPUs are used for training."
+  echo ""
+  echo "Note 2:"
+  echo "  If you use torch < 1.9.0, then every node has to use the same number of GPUs for training."
+  echo "  If you use torch >= 1.9.0, different nodes can have a different number of GPUs for training."
+  exit 1
+}
+
+default='\033[0m'
+bold='\033[1m'
+red='\033[31m'
+
+function error() {
+  printf "${bold}${red}[ERROR]${default} $1\n"
+}
+
+[ ! -z $CUDA_VISIBLE_DEVICES ] || ( echo; error "Please set CUDA_VISIBLE_DEVICES"; echo; usage )
+[ ! -z $master_addr ] || ( echo; error "Please set --master-addr"; echo; usage )
+[ ! -z $master_port ] || ( echo; error "Please set --master-port"; echo; usage )
+[ ! -z $node_rank ] || ( echo; error "Please set --node-rank"; echo; usage )
+[ ! -z $num_nodes ] || ( echo; error "Please set --num-nodes"; echo; usage )
+
+# Number of GPUs this node has
+num_gpus=$(python3 -c "s=\"$CUDA_VISIBLE_DEVICES\"; print(len(s.split(',')))")
+
+echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
+echo "num_gpus: $num_gpus"
+echo "master_addr: $master_addr"
+
+export MASTER_ADDR=$master_addr
+export MASTER_PORT=$master_port
+
+set -x
+
+torchrun \
+  --nproc_per_node $num_gpus \
+  --nnodes $num_nodes \
+  --node_rank $node_rank \
+  --master_addr $master_addr \
+  --master_port $master_port \
+  zipformer/pretrain.py \
+    --use-multi-node 1 \
+    --master-port $master_port \
+    --num-epochs 30 \
+    --start-epoch 1 \
+    --use-fp16 1 \
+    --exp-dir zipformer/exp_pretrain \
+    --max-duration 600 \
+    --quadratic-duration 1024 \
+    --accum-grad 1 \
+    --do-normalize 1 \
+    --mask-prob 0.8 \
+    --dropout-input 0.0 \
+    --dropout-features 0.0 \
+    --feature-grad-mult 1.0 \
+    --num-encoder-layers 2,2,3,4,3,2 \
+    --feedforward-dim 512,768,1024,1536,1024,768 \
+    --encoder-dim 192,256,448,768,448,192 \
+    --encoder-unmasked-dim 192,192,256,256,256,192 \
+    --base-lr 0.045
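For orientation: the script above launches one process per visible GPU on each node, and torchrun then gives every process a global RANK, a per-node LOCAL_RANK, and the overall WORLD_SIZE. A minimal sketch of the resulting numbering, assuming the two-machine example from the usage text with the same number of GPUs per node (illustrative only, not part of the commit):

    # Illustrative numbering for 2 nodes x 3 visible GPUs each (as in the usage
    # example above). With homogeneous nodes, torchrun assigns
    #   rank = node_rank * nproc_per_node + local_rank
    num_nodes = 2
    nproc_per_node = 3  # len(CUDA_VISIBLE_DEVICES.split(","))
    world_size = num_nodes * nproc_per_node  # 6 processes in total
    for node_rank in range(num_nodes):
        for local_rank in range(nproc_per_node):
            rank = node_rank * nproc_per_node + local_rank
            print(f"node {node_rank}: local_rank {local_rank} -> global rank {rank} of {world_size}")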
@@ -20,23 +20,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-Usage:
-
-export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
-
-# For hubert model pretraining:
-./zipformer/pretrain.py \
-  --world-size 8 \
-  --num-epochs 400 \
-  --start-epoch 1 \
-  --use-fp16 1 \
-  --exp-dir hubert/exp \
-  --full-libri 1 \
-  --max-duration 87.5 \
-  --accum-grad 4
-"""
-
-
 import argparse
 import copy
@@ -46,7 +29,6 @@ from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union

-import lhotse
 import optim
 import torch
 import torch.multiprocessing as mp
@@ -69,7 +51,13 @@ from icefall.checkpoint import (
     save_checkpoint_with_global_batch_idx,
     update_averaged_model,
 )
-from icefall.dist import cleanup_dist, setup_dist
+from icefall.dist import (
+    cleanup_dist,
+    get_local_rank,
+    get_rank,
+    get_world_size,
+    setup_dist,
+)
 from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
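The expanded import pulls in three helpers used later to distinguish the global rank, the per-node rank, and the total number of processes. As a rough mental model only (the real implementations live in icefall/dist.py and may differ), such helpers typically read the environment variables that torchrun exports to each worker:

    import os

    # Sketch of what rank helpers typically return under torchrun
    # (an assumption for illustration; see icefall/dist.py for the real code).
    def get_rank() -> int:        # global rank across all nodes
        return int(os.environ.get("RANK", 0))

    def get_local_rank() -> int:  # rank within the current node; used to pick the GPU
        return int(os.environ.get("LOCAL_RANK", 0))

    def get_world_size() -> int:  # total number of processes across all nodes
        return int(os.environ.get("WORLD_SIZE", 1))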
@@ -405,6 +393,15 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )

+    parser.add_argument(
+        "--use-multi-node",
+        type=str2bool,
+        default=False,
+        help="""True if using multi-node multi-GPU.
+        You are not supposed to set it directly.
+        """,
+    )
+
     parser.add_argument(
         "--world-size",
         type=int,
@@ -572,7 +569,7 @@ def get_parser():
     parser.add_argument(
         "--accum-grad",
         type=int,
-        default=4,
+        default=1,
         help="""update gradient when batch_idx_train % accum_grad == 0.
         """,
     )
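The help string states the accumulation rule: parameters are updated only on batches where batch_idx_train % accum_grad == 0, so one optimizer step covers accum_grad micro-batches. A small self-contained illustration of that rule with a toy model (not icefall's actual training loop):

    import torch

    # Toy illustration of "update gradient when batch_idx_train % accum_grad == 0".
    # Gradients from accum_grad consecutive batches accumulate before one step.
    accum_grad = 4
    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    for batch_idx_train in range(1, 9):
        x, y = torch.randn(8, 10), torch.randn(8, 1)
        loss = torch.nn.functional.mse_loss(model(x), y) / accum_grad  # average over micro-batches
        loss.backward()  # .grad keeps accumulating
        if batch_idx_train % accum_grad == 0:
            optimizer.step()       # one optimizer update per accum_grad batches
            optimizer.zero_grad()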
@@ -1090,8 +1087,15 @@ def run(rank, world_size, args):
     params.update(vars(args))

     fix_random_seed(params.seed)

+    if params.use_multi_node:
+        local_rank = get_local_rank()
+    else:
+        local_rank = rank
+    logging.info(f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}")
+
     if world_size > 1:
-        setup_dist(rank, world_size, params.master_port)
+        setup_dist(rank, world_size, params.master_port, params.use_multi_node)

     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
@@ -1103,8 +1107,8 @@ def run(rank, world_size, args):

     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", rank)
-    logging.info(f"Device: {device}")
+        device = torch.device("cuda", local_rank)
+    logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")
     logging.info(params)

     logging.info("About to create model")
@@ -1127,7 +1131,7 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+        model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)

     optimizer = ScaledAdam(
         get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
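The two hunks above switch device selection and DDP from the global rank to the local rank. On a single node the two coincide, but on the second node of a two-node job the global ranks exceed the number of GPUs that node actually has, so indexing CUDA devices by global rank would fail. A minimal sketch of the corrected pattern (illustrative, outside the real training script):

    import os
    import torch

    # Illustrative device binding: each process uses its *local* rank as the
    # CUDA device index; global ranks on later nodes are larger than the
    # per-node GPU count.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    device = (
        torch.device("cuda", local_rank) if torch.cuda.is_available() else torch.device("cpu")
    )
    print(f"global rank {os.environ.get('RANK', 0)} -> device {device}")
    # With DDP, the same index is passed as device_ids=[local_rank].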
@@ -1358,12 +1362,18 @@ def main():
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)

-    world_size = args.world_size
-    assert world_size >= 1
-    if world_size > 1:
-        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    if args.use_multi_node:
+        rank = get_rank()
+        world_size = get_world_size()
+        args.world_size = world_size
+        run(rank=rank, world_size=world_size, args=args)
     else:
-        run(rank=0, world_size=1, args=args)
+        world_size = args.world_size
+        assert world_size >= 1
+        if world_size > 1:
+            mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+        else:
+            run(rank=0, world_size=1, args=args)

     torch.set_num_threads(1)
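The reworked main() keeps two launch paths: when --use-multi-node is set (by the wrapper script above), torchrun has already created one process per GPU, so each process simply calls run() with the rank it was given; otherwise the original single-node path spawns world_size workers itself. A stripped-down, self-contained illustration of the single-node spawning side (hypothetical, not icefall code):

    import torch.multiprocessing as mp


    def run(rank: int, world_size: int) -> None:
        # mp.spawn passes the worker index as the first argument.
        print(f"worker {rank} of {world_size} started")


    if __name__ == "__main__":
        world_size = 2
        # Single-node path: this parent process creates the workers itself.
        # Under torchrun (multi-node path) this step is unnecessary, because
        # the launcher already started one process per GPU.
        mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)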
@@ -103,7 +103,7 @@ class LibriLightDataModule:
         help="We will draw this many cuts to estimate the duration"
         "bins for creating similar-duration buckets. Larger number"
         "means a better estimate to the data distribution, possibly"
-        "at a longer init cost."
+        "at a longer init cost.",
     )
     group.add_argument(
         "--quadratic-duration",
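The --quadratic-duration option that appears in this hunk (set to 1024 in the new launch script) is forwarded to lhotse's dynamic bucketing sampler, which, if I recall its convention correctly, measures a cut as duration + duration^2 / quadratic_duration when filling a batch up to --max-duration, so very long cuts count for more than their raw length. Treat the formula below as an assumption to be checked against the installed lhotse version:

    # Assumed effective-duration rule for quadratic-duration batching (verify
    # against the installed lhotse): long cuts are penalized so that batches
    # containing long utterances do not exhaust GPU memory.
    def effective_duration(duration: float, quadratic_duration: float) -> float:
        return duration + duration ** 2 / quadratic_duration


    for d in (10.0, 30.0, 60.0):
        print(d, "->", effective_duration(d, quadratic_duration=1024.0))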
@@ -304,28 +304,36 @@
     def medium_cuts(self) -> CutSet:
         logging.info("About to get librilight medium cuts")
         filenames = glob.glob(
-            str(self.args.manifest_dir / "medium_split" / "librilight_cuts_medium.*.jsonl.gz")
+            str(
+                self.args.manifest_dir
+                / "medium_split"
+                / "librilight_cuts_medium.*.jsonl.gz"
+            )
         )
         pattern = re.compile(r"librilight_cuts_medium.([0-9]+).jsonl.gz")
         idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
         idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
         sorted_filenames = [f[1] for f in idx_filenames]
-        logging.info(f"Loading Libri-Light medium {len(sorted_filenames)} splits in lazy mode")
-        return lhotse.combine(
-            lhotse.load_manifest_lazy(p) for p in sorted_filenames
-        )
+        logging.info(
+            f"Loading Libri-Light medium {len(sorted_filenames)} splits in lazy mode"
+        )
+        return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames)

     @lru_cache()
     def large_cuts(self) -> CutSet:
         logging.info("About to get librilight large cuts")
         filenames = glob.glob(
-            str(self.args.manifest_dir / "large_split" / "librilight_cuts_large.*.jsonl.gz")
+            str(
+                self.args.manifest_dir
+                / "large_split"
+                / "librilight_cuts_large.*.jsonl.gz"
+            )
         )
         pattern = re.compile(r"librilight_cuts_large.([0-9]+).jsonl.gz")
         idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
         idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
         sorted_filenames = [f[1] for f in idx_filenames]
-        logging.info(f"Loading Libri-Light large {len(sorted_filenames)} splits in lazy mode")
-        return lhotse.combine(
-            lhotse.load_manifest_lazy(p) for p in sorted_filenames
-        )
+        logging.info(
+            f"Loading Libri-Light large {len(sorted_filenames)} splits in lazy mode"
+        )
+        return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames)
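Both methods glob the sharded Libri-Light manifests and then sort them by the integer captured from the filename rather than lexicographically, since a plain string sort would place shard 10 before shard 2. A standalone illustration of that ordering with made-up filenames:

    import re

    # Made-up shard names; a lexicographic sort would yield 1, 10, 2.
    filenames = [
        "librilight_cuts_medium.10.jsonl.gz",
        "librilight_cuts_medium.2.jsonl.gz",
        "librilight_cuts_medium.1.jsonl.gz",
    ]
    pattern = re.compile(r"librilight_cuts_medium.([0-9]+).jsonl.gz")
    idx_filenames = sorted((int(pattern.search(f).group(1)), f) for f in filenames)
    print([f for _, f in idx_filenames])  # shards 1, 2, 10 in numeric order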