Merge fd31ed5b0b0bef24daea22e06bb481b5a0cd519e into 9293edc62f4a3ebf769d66cc037d4e67953440f5

This commit is contained in:
Yifan Yang 2025-07-08 15:21:30 +08:00 committed by GitHub
commit 9a2b5720c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
37 changed files with 1970 additions and 1412 deletions

View File

@ -0,0 +1,88 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from collections import Counter
from pathlib import Path
import torch
from lhotse import CutSet
from tqdm import tqdm
# Torch's multithreaded behavior needs to be disabled; otherwise it
# wastes a lot of CPU and slows things down.
# Do this at import time, outside of main(), so that it also takes effect
# when main() is not invoked (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
    """Parse the command-line options of the codebook analysis script.

    Returns:
      An ``argparse.Namespace`` with ``cuts_path`` (manifest to analyze) and
      ``num_clusters`` (size of the k-means codebook).
    """
    arg_parser = argparse.ArgumentParser()

    # (flag, value type, default) for every supported option.
    option_table = (
        ("--cuts-path", str, "data/kmeans/librispeech_cuts_dev-clean.jsonl.gz"),
        ("--num-clusters", int, 500),
    )
    for flag, value_type, default_value in option_table:
        arg_parser.add_argument(flag, type=value_type, default=default_value)

    return arg_parser.parse_args()
def analyze_codebook(args):
    """Log the utilization rate and entropy of a k-means codebook.

    Every cut in ``args.cuts_path`` is expected to carry a
    ``custom["kmeans"]`` string of whitespace-separated integer labels.
    The labels of all cuts are pooled and two statistics are logged:

    * utilization: fraction of the ``args.num_clusters`` ids that occur
      at least once;
    * entropy (in bits) of the empirical label distribution.

    Labels outside ``[0, args.num_clusters)`` are reported with a warning
    and excluded, so the utilization can never exceed 100% and the
    distribution used for the entropy sums to one.

    Args:
      args:
        Namespace with the attributes ``cuts_path`` and ``num_clusters``.
    """
    cuts_path = Path(args.cuts_path)
    assert cuts_path.is_file(), f"{cuts_path} does not exist"
    assert args.num_clusters > 0, f"num_clusters: {args.num_clusters}"

    logging.info(f"Loading {cuts_path}")
    cut_set = CutSet.from_file(cuts_path)

    cluster_counts = Counter()
    logging.info("Analyzing codebook")
    for cut in tqdm(cut_set):
        cluster_counts.update(int(label) for label in cut.custom["kmeans"].split())

    # Counter returns 0 for ids that never occurred.
    counts = torch.tensor([cluster_counts[i] for i in range(args.num_clusters)])
    total_count = int(counts.sum())

    num_out_of_range = sum(cluster_counts.values()) - total_count
    if num_out_of_range > 0:
        logging.warning(
            f"Ignoring {num_out_of_range} labels outside "
            f"[0, {args.num_clusters}); check --num-clusters"
        )

    if total_count == 0:
        # Without this guard the normalization below is 0 / 0 and the
        # reported entropy would be NaN.
        logging.warning(f"No usable kmeans labels found in {cuts_path}")
        return

    utilized_clusters = int((counts > 0).sum())

    # Clamp so that unused clusters do not produce log(0).
    normalized_counts = (counts / total_count).clamp(min=1e-10)
    # Natural-log entropy, converted from nats to bits via log2(e).
    codebook_entropy = (
        -(normalized_counts * normalized_counts.log()).sum()
        * torch.log2(torch.tensor(torch.e))
    ).item()

    logging.info(
        f"Codebook utilization rate: {utilized_clusters / args.num_clusters:%}"
    )
    logging.info(f"Codebook entropy: {codebook_entropy}")
if __name__ == "__main__":
    # Log lines carry a timestamp, the level and the emitting file/line.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    cli_args = get_args()
    logging.info(vars(cli_args))
    analyze_codebook(cli_args)

View File

@ -0,0 +1,289 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
import os
from pathlib import Path
from typing import Optional
import fairseq
import joblib
import numpy as np
import torch
from lhotse import CutSet, SupervisionSegment
from lhotse.utils import fastcopy
from tqdm import tqdm
# Torch's multithreaded behavior needs to be disabled; otherwise it
# wastes a lot of CPU and slows things down.
# Do this at import time, outside of main(), so that it also takes effect
# when main() is not invoked (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
# Configure PyTorch's CUDA caching allocator (max_split_size_mb option).
# NOTE(review): presumably meant to limit memory fragmentation for the long
# audio windows processed below, and presumably must be set before the first
# CUDA allocation - confirm against the PyTorch CUDA memory-management docs.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
class ApplyKmeans:
    """Map feature frames to the index of their nearest k-means centroid.

    The centroids are read from ``km_model.cluster_centers_`` of the object
    stored at ``km_path`` (presumably a scikit-learn k-means model - confirm).
    Both a numpy and a torch copy of the centroids are kept so that either
    kind of input can be handled.
    """

    def __init__(self, km_path):
        # NOTE(review): joblib.load unpickles the file; only load models
        # from a trusted source.
        self.km_model = joblib.load(km_path)

        # Transposed centroids, so `x @ C` yields one column per cluster.
        centroids = self.km_model.cluster_centers_.transpose()
        # Squared norm of every centroid, kept as a single row.
        centroid_sq_norms = (centroids**2).sum(0, keepdims=True)
        self.C_np = centroids
        self.Cnorm_np = centroid_sq_norms

        torch_centroids = torch.from_numpy(centroids)
        torch_sq_norms = torch.from_numpy(centroid_sq_norms)
        if torch.cuda.is_available():
            torch_centroids = torch_centroids.cuda()
            torch_sq_norms = torch_sq_norms.cuda()
        self.C = torch_centroids
        self.Cnorm = torch_sq_norms

    def __call__(self, x):
        """Return, as a numpy array, the nearest-centroid index of each row of x.

        The squared distance is expanded as |x|^2 - 2 x.c + |c|^2.
        """
        if not isinstance(x, torch.Tensor):
            # numpy path
            sq_dist = (
                (x**2).sum(1, keepdims=True)
                - 2 * np.matmul(x, self.C_np)
                + self.Cnorm_np
            )
            return np.argmin(sq_dist, axis=1)

        # torch path; the result is moved back to the CPU as a numpy array
        sq_dist = (
            x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm
        )
        return sq_dist.argmin(dim=1).cpu().numpy()
def get_args():
    """Parse the command-line options of the k-means extraction script.

    Returns:
      An ``argparse.Namespace`` with the attributes ``subset``,
      ``model_path``, ``kmeans_model_path``, ``start``, ``stop``,
      ``window_duration`` and ``shift_duration``.
    """
    arg_parser = argparse.ArgumentParser()

    # One (flag, keyword-arguments) entry per option, registered in order.
    option_table = (
        ("--subset", dict(type=str, default="small")),
        ("--model-path", dict(type=str, default="download/hubert_base_ls960.pt")),
        (
            "--kmeans-model-path",
            dict(type=str, default="download/hubert_base_ls960_L9_km500.bin"),
        ),
        (
            "--start",
            dict(
                type=int,
                default=0,
                help="Process pieces starting from this number (inclusive).",
            ),
        ),
        (
            "--stop",
            dict(
                type=int,
                default=-1,
                help="Stop processing pieces until this number (exclusive).",
            ),
        ),
        ("--window-duration", dict(type=float, default=300.0)),
        ("--shift-duration", dict(type=float, default=250.0)),
    )
    for flag, spec in option_table:
        arg_parser.add_argument(flag, **spec)

    return arg_parser.parse_args()
@torch.no_grad()
def extract_and_save_one_cuts(
    raw_cuts_path,
    cuts_path,
    model,
    apply_kmeans,
    do_normalize,
    window_duration,
    shift_duration,
):
    """Attach k-means labels to every cut of a manifest and save the result.

    Each recording is processed in windows of ``window_duration`` seconds
    that advance by ``shift_duration`` seconds. For every window after the
    first, the leading labels that correspond to the overlap with the
    previous window are dropped before the labels are concatenated.

    Args:
      raw_cuts_path:
        Manifest to read the cuts from.
      cuts_path:
        Where the manifest with the added ``custom["kmeans"]`` field is written.
      model:
        Feature extractor exposing ``extract_features`` and ``parameters``.
      apply_kmeans:
        Callable mapping a feature matrix to an array of cluster ids.
      do_normalize:
        If true, layer-normalize each audio window before feature extraction.
      window_duration:
        Window length in seconds; must be at least ``shift_duration``.
      shift_duration:
        Hop between consecutive windows in seconds.
    """
    logging.info(f"Loading {raw_cuts_path}")
    cut_set = CutSet.from_file(raw_cuts_path)

    logging.info("Extracting kmeans")
    labelled_cuts = []

    assert window_duration >= shift_duration
    # Durations are converted to samples at 16 kHz (asserted per cut below).
    window_size = int(window_duration * 16000)
    shift_size = int(shift_duration * 16000)
    # Number of output frames covered by the overlap of adjacent windows.
    out_overlap_size = get_out_length(window_size - shift_size)

    for cut in tqdm(cut_set):
        assert cut.sampling_rate == 16000, f"Sampling rate: {cut.sampling_rate}"

        audio = cut.load_audio()
        T = audio.shape[1]

        labels = []
        start = 0
        while start < T:
            # The last window may be shorter than window_size.
            stop = start + min(window_size, T - start)
            x = torch.from_numpy(audio[:, start:stop]).float()
            x = x.to(next(model.parameters()).device)
            if do_normalize:
                x = torch.nn.functional.layer_norm(x, x.shape)

            feature, _ = model.extract_features(
                source=x,
                padding_mask=None,
                mask=False,
                output_layer=9,
            )
            window_labels = apply_kmeans(feature.squeeze(0)).tolist()

            # Only the first window contributes its overlap region.
            num_skipped = 0 if start == 0 else out_overlap_size
            labels.extend(window_labels[num_skipped:])

            # This window already reached the end of the recording.
            if T - start <= window_size:
                break
            start += shift_size

        labelled_cuts.append(
            fastcopy(
                cut,
                custom={"kmeans": " ".join(str(label) for label in labels)},
            )
        )

    labelled_cuts = CutSet(labelled_cuts)
    logging.info(f"Saving to {cuts_path}")
    labelled_cuts.to_file(cuts_path)
def extract_kmeans(args):
    """Extract k-means labels for one LibriLight subset.

    For the ``small`` subset a single manifest in ``data/kmeans`` is
    processed. For ``medium`` and ``large`` the manifest pieces in
    ``data/kmeans/<subset>_split`` with index in ``[args.start, args.stop)``
    are processed one by one. A negative ``args.stop`` (the command-line
    default is -1) means "up to and including the last piece found on disk";
    previously that default always tripped the ``stop > start`` assertion.

    Outputs that already exist and inputs that are missing are skipped
    with a log message.

    Args:
      args:
        Namespace with the attributes ``subset``, ``model_path``,
        ``kmeans_model_path``, ``start``, ``stop``, ``window_duration``
        and ``shift_duration``.

    Raises:
      FileNotFoundError:
        If ``args.stop`` is negative and no raw manifest piece exists.
    """
    assert args.subset in ("small", "medium", "large"), f"{args.subset}"

    output_dir = (
        f"data/kmeans/{args.subset}_split" if args.subset != "small" else "data/kmeans"
    )
    output_dir = Path(output_dir)
    assert output_dir.exists(), f"{output_dir} does not exist!"

    prefix = "librilight"

    # Resolve the piece range before loading any model so that a bad range
    # fails fast.
    start = stop = None
    if args.subset != "small":
        start = args.start
        stop = args.stop
        if stop < 0:
            head = f"{prefix}_cuts_{args.subset}_raw."
            tail = ".jsonl.gz"
            piece_ids = [
                int(p.name[len(head) : -len(tail)])
                for p in output_dir.glob(f"{head}*{tail}")
                if p.name[len(head) : -len(tail)].isdigit()
            ]
            if not piece_ids:
                raise FileNotFoundError(
                    f"No manifest pieces matching {head}*{tail} in {output_dir}"
                )
            stop = max(piece_ids) + 1
            logging.info(f"--stop not given, processing pieces up to {stop - 1}")
        assert stop > start, "stop must be larger than start!"

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    logging.info(f"device: {device}")

    apply_kmeans = ApplyKmeans(args.kmeans_model_path)
    model, _, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [args.model_path]
    )
    model = model[0].eval().to(device)
    do_normalize = task.cfg.normalize

    def process_one(raw_cuts_path, cuts_path):
        # Shared skip/extract logic for the single-file and the split case.
        if cuts_path.is_file():
            logging.info(f"{cuts_path} exists - skipping")
            return
        if not raw_cuts_path.is_file():
            logging.info(f"{raw_cuts_path} does not exist - skipping it")
            return
        extract_and_save_one_cuts(
            raw_cuts_path,
            cuts_path,
            model,
            apply_kmeans,
            do_normalize,
            args.window_duration,
            args.shift_duration,
        )

    if args.subset == "small":
        process_one(
            output_dir / f"{prefix}_cuts_{args.subset}_raw.jsonl.gz",
            output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz",
        )
        return

    num_digits = 8  # num_digits is fixed by lhotse split-lazy
    for i in range(start, stop):
        idx = f"{i}".zfill(num_digits)
        logging.info(f"Processing {idx}/{stop - 1}")
        process_one(
            output_dir / f"{prefix}_cuts_{args.subset}_raw.{idx}.jsonl.gz",
            output_dir / f"{prefix}_cuts_{args.subset}.{idx}.jsonl.gz",
        )
def get_out_length(T):
    """Return the number of output frames for an input of ``T`` samples.

    The length is propagated through seven stacked strided convolutions
    (presumably mirroring the convolutional front end of the feature
    extractor - confirm); an input that is too short yields 0.
    """
    # (kernel_size, stride) of each layer, in order of application.
    kernel_and_stride = [(10, 5)] + [(3, 2)] * 4 + [(2, 2)] * 2

    num_frames = T
    for kernel_size, stride in kernel_and_stride:
        num_frames = math.floor((num_frames - kernel_size) / stride) + 1
    return max(0, num_frames)
if __name__ == "__main__":
    # Log lines carry a timestamp, the level and the emitting file/line.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    cli_args = get_args()
    logging.info(vars(cli_args))
    extract_kmeans(cli_args)

View File

@ -0,0 +1,107 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from pathlib import Path
from typing import Optional
import torch
from lhotse import CutSet
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import str2bool
# Torch's multithreaded behavior needs to be disabled; otherwise it
# wastes a lot of CPU and slows things down.
# Do this at import time, outside of main(), so that it also takes effect
# when main() is not invoked (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
    """Parse the command-line options of the preprocessing script.

    Returns:
      An ``argparse.Namespace`` whose ``dataset`` attribute is the value of
      ``--dataset``, or None when the option is omitted.
    """
    arg_parser = argparse.ArgumentParser()
    dataset_help = """Dataset parts to compute fbank. If None, we will use all"""
    arg_parser.add_argument("--dataset", type=str, help=dataset_help)
    return arg_parser.parse_args()
def preprocess_librilight(
    dataset: Optional[str] = None,
):
    """Turn the cached LibriLight manifests into raw cut manifests.

    For every requested part, the recordings and supervisions cached in
    ``data/manifests`` are combined into a CutSet, trimmed to their
    supervisions and written to
    ``data/kmeans/librilight_cuts_<part>_raw.jsonl.gz``. Parts whose output
    file already exists are skipped.

    Args:
      dataset:
        Space-separated part names. If None, ``small``, ``medium`` and
        ``large`` are all processed.
    """
    src_dir = Path("data/manifests")
    output_dir = Path("data/kmeans")

    dataset_parts = (
        ("small", "medium", "large") if dataset is None else dataset.split(" ", -1)
    )

    prefix, suffix = "librilight", "jsonl.gz"
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
        output_dir=src_dir,
        prefix=prefix,
        suffix=suffix,
    )
    assert manifests is not None

    # Every requested part must have been found in the cache.
    assert len(manifests) == len(dataset_parts), (
        len(manifests),
        len(dataset_parts),
        list(manifests.keys()),
        dataset_parts,
    )

    for partition, m in manifests.items():
        cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}"
        if cuts_path.is_file():
            logging.info(f"{partition} already exists - skipping.")
            continue

        logging.info(f"Processing {partition}")
        cut_set = CutSet.from_manifests(
            recordings=m["recordings"],
            supervisions=m["supervisions"],
        ).trim_to_supervisions(keep_overlapping=False, min_duration=None)

        logging.info(f"Saving to {cuts_path}")
        cut_set.to_file(cuts_path)
if __name__ == "__main__":
    # Log lines carry a timestamp, the level and the emitting file/line.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    cli_args = get_args()
    logging.info(vars(cli_args))
    preprocess_librilight(dataset=cli_args.dataset)

100
egs/librilight/SSL/prepare.sh Executable file
View File

@ -0,0 +1,100 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
# Number of parallel jobs passed to `lhotse prepare librilight -j`.
nj=32
# Run stage 1 to stage 4 by default (this script defines no stage 0,
# so stage=0 simply starts at stage 1).
stage=0
stop_stage=4
# We assume dl_dir (download dir) contains the following
# directories and files. This script does NOT download the
# corpus; it has to be fetched and unpacked beforehand.
#
# - $dl_dir/LibriLight
# - small
# - medium
# - large
#
# You can download them from
# - https://dl.fbaipublicfiles.com/librilight/data/small.tar
# - https://dl.fbaipublicfiles.com/librilight/data/medium.tar
# - https://dl.fbaipublicfiles.com/librilight/data/large.tar
dl_dir=$PWD/download
# Let command-line options (e.g. --stage, --stop-stage) override the
# variables defined above.
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
  # Timestamped logger (adapted from espnet): prints the calling file,
  # line number and function in front of the message.
  local caller_file=${BASH_SOURCE[1]##*/}
  local stamp
  stamp=$(date '+%Y-%m-%d %H:%M:%S')
  echo -e "${stamp} (${caller_file}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "Running prepare.sh"
log "dl_dir: $dl_dir"
if [ "$stage" -le 1 ] && [ "$stop_stage" -ge 1 ]; then
  log "Stage 1: Prepare Libri-Light manifest"
  # We assume that you have downloaded the Libri-Light corpus
  # to $dl_dir/LibriLight
  mkdir -p data/manifests
  # The marker file makes this stage a no-op on reruns.
  if [ ! -e data/manifests/.librilight.done ]; then
    # dl_dir defaults to $PWD/download, so it is quoted to survive
    # directory names that contain spaces.
    lhotse prepare librilight -j "$nj" "$dl_dir/LibriLight" data/manifests
    touch data/manifests/.librilight.done
  fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Preprocess Libri-Light manifest"
  kmeans_dir=data/kmeans
  mkdir -p $kmeans_dir
  # The marker file makes this stage a no-op on reruns.
  if [ ! -f $kmeans_dir/.preprocess_complete ]; then
    python3 ./local/preprocess_librilight.py
    touch $kmeans_dir/.preprocess_complete
  fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Split medium and large subset into pieces"
  num_per_split=10000
  # Both subsets are split the same way; each gets its own marker file.
  for subset in medium large; do
    split_dir=data/kmeans/${subset}_split
    if [ ! -f $split_dir/.split_completed ]; then
      lhotse split-lazy ./data/kmeans/librilight_cuts_${subset}_raw.jsonl.gz $split_dir $num_per_split
      touch $split_dir/.split_completed
    fi
  done
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Extract SSL target for librilight"

  # Feature-extractor checkpoint and k-means model; the download location
  # matches the default paths of local/extract_kmeans.py.
  for model_file in hubert_base_ls960.pt hubert_base_ls960_L9_km500.bin; do
    if [ ! -e download/$model_file ]; then
      wget https://dl.fbaipublicfiles.com/hubert/$model_file -P download
    fi
  done

  for subset in small medium large; do
    if [ -e data/kmeans/.extract_${subset}.done ]; then
      continue
    fi

    if [ "$subset" = small ]; then
      ./local/extract_kmeans.py --subset $subset
    else
      # medium/large were split into numbered pieces in stage 3. Pass the
      # piece range explicitly rather than relying on the script's
      # --start/--stop defaults.
      last_piece=$(find data/kmeans/${subset}_split \
        -name "librilight_cuts_${subset}_raw.*.jsonl.gz" | sort | tail -n 1)
      if [ -z "$last_piece" ]; then
        log "No pieces found in data/kmeans/${subset}_split - run stage 3 first"
        exit 1
      fi
      # Extract the zero-padded index between the last two dots of
      # <name>.<index>.jsonl.gz
      last_idx=${last_piece%.jsonl.gz}
      last_idx=${last_idx##*.}
      # 10# forces base 10; a zero-padded number would otherwise be
      # interpreted as octal.
      num_pieces=$((10#$last_idx + 1))
      ./local/extract_kmeans.py --subset $subset --start 0 --stop $num_pieces
    fi

    touch data/kmeans/.extract_${subset}.done
  done
fi

View File

@ -0,0 +1,116 @@
#!/usr/bin/env bash
#
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Yifan Yang)
#
# This script is the entry point to start model training
# with multi-node multi-GPU.
#
# Read the usage instructions below for how to run this script.
# Abort on the first failing command.
set -e
# DDP related parameters
# master_addr and node_rank have no default and must be given on the
# command line; they are checked for emptiness further below.
master_addr=
node_rank=
num_nodes=4
master_port=12354
# Let command-line options (--master-addr, --node-rank, ...) override the
# variables defined above.
. shared/parse_options.sh
function usage() {
  # Print the help text (one printf argument per output line), then abort
  # with exit status 1.
  printf '%s\n' \
    "Usage: " \
    "" \
    " $0 \\" \
    " --master-addr <IP of master> \\" \
    " --master-port <Port of master> \\" \
    " --node-rank <rank of this node> \\" \
    " --num-nodes <Number of node>" \
    "" \
    " --master-addr The ip address of the master node." \
    " --master-port The port of the master node." \
    " --node-rank Rank of this node." \
    " --num-nodes Number of nodes in DDP training." \
    "" \
    "Usage example:" \
    "Suppose you want to use DDP with two machines:" \
    " (1) Machine 1 has 4 GPUs. You want to use" \
    " GPU 0, 1, and 3 for training" \
    " IP of machine 1 is: 10.177.41.71" \
    " (2) Machine 2 has 4 GPUs. You want to use" \
    " GPU 0, 2, and 3 for training" \
    " IP of machine 2 is: 10.177.41.72" \
    "You want to select machine 1 as the master node and" \
    "assume that the port 1234 is free on machine 1." \
    "" \
    "On machine 1, you run:" \
    "" \
    " export CUDA_VISIBLE_DEVICES=\"0,1,3\"" \
    " ./run_multi_node_multi_gpu.sh --master-addr 10.177.41.71 --master-port 1234 --node-rank 0 --num-nodes 2" \
    "" \
    "On machine 2, you run:" \
    "" \
    " export CUDA_VISIBLE_DEVICES=\"0,2,3\"" \
    " ./run_multi_node_multi_gpu.sh --master-addr 10.177.41.71 --master-port 1234 --node-rank 1 --num-nodes 2" \
    "" \
    "Note 1:" \
    " You use CUDA_VISIBLE_DEVICES to decide which GPUs are used for training." \
    "" \
    "Note 2:" \
    " If you use torch < 1.9.0, then every node has to use the same number of GPUs for training." \
    " If you use torch >= 1.9.0, different nodes can have a different number of GPUs for training."
  exit 1
}
# ANSI escape sequences used to highlight error messages.
default='\033[0m'
bold='\033[1m'
red='\033[31m'

# Print "[ERROR] <message>" in bold red.
#   $1: the message
function error() {
  # The message is passed as an argument instead of being spliced into the
  # format string, so a '%' or backslash in it is printed literally.
  printf "${bold}${red}[ERROR]${default} %s\n" "$1"
}
# Every setting below must be non-empty; otherwise report which one is
# missing and show the usage (usage exits with status 1).
# The values are quoted so that a value containing spaces cannot break the
# test, and a brace group is used instead of a subshell so that the `exit`
# inside usage terminates this script directly.
[ -n "${CUDA_VISIBLE_DEVICES:-}" ] || { echo; error "Please set CUDA_VISIBLE_DEVICES"; echo; usage; }
[ -n "$master_addr" ] || { echo; error "Please set --master-addr"; echo; usage; }
[ -n "$master_port" ] || { echo; error "Please set --master-port"; echo; usage; }
[ -n "$node_rank" ] || { echo; error "Please set --node-rank"; echo; usage; }
[ -n "$num_nodes" ] || { echo; error "Please set --num-nodes"; echo; usage; }
# Number of GPUs this node has, i.e. the number of comma-separated entries
# in CUDA_VISIBLE_DEVICES.
num_gpus=$(python3 -c "s=\"$CUDA_VISIBLE_DEVICES\"; print(len(s.split(',')))")
echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
echo "num_gpus: $num_gpus"
echo "master_addr: $master_addr"
# Export the rendezvous address/port in addition to passing them to
# torchrun below.
export MASTER_ADDR=$master_addr
export MASTER_PORT=$master_port
# Echo the full launch command for the log.
set -x
# Start one process per visible GPU on this node. Everything after
# zipformer/pretrain.py is an option of the training script itself;
# --use-multi-node 1 tells it that it is launched in multi-node mode.
torchrun \
--nproc_per_node $num_gpus \
--nnodes $num_nodes \
--node_rank $node_rank \
--master_addr $master_addr \
--master_port $master_port \
zipformer/pretrain.py \
--use-multi-node 1 \
--master-port $master_port \
--num-epochs 20 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp_pretrain \
--max-duration 350 \
--quadratic-duration 1024 \
--accum-grad 1 \
--do-normalize 1 \
--mask-prob 0.8 \
--dropout-input 0.0 \
--dropout-features 0.0 \
--feature-grad-mult 1.0 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 768,1536,2048,3072,2048,1536 \
--encoder-dim 256,512,768,1024,768,512 \
--encoder-unmasked-dim 256,256,256,320,256,256 \
--base-lr 0.045

1
egs/librilight/SSL/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared/

View File

@ -1 +0,0 @@
../../../librispeech/SSL/zipformer/asr_datamodule.py

File diff suppressed because it is too large Load Diff

149
egs/librilight/SSL/zipformer/pretrain.py Normal file → Executable file
View File

@ -1,11 +1,10 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,
# Zengwei Yao,
# Yifan Yang,
# Daniel Povey)
#
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,
# Zengwei Yao,
# Yifan Yang,
# Daniel Povey)
# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
@ -21,22 +20,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
# For hubert model pretraining:
./zipformer/pretrain.py \
--world-size 8 \
--num-epochs 400 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--max-duration 87.5 \
--accum-grad 4
"""
import argparse
import copy
@ -68,7 +51,13 @@ from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx,
update_averaged_model,
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.dist import (
cleanup_dist,
get_local_rank,
get_rank,
get_world_size,
setup_dist,
)
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
@ -398,19 +387,21 @@ def add_model_arguments(parser: argparse.ArgumentParser):
and the value should be the multiple of 4, for faster computation""",
)
parser.add_argument(
"--untie-final-proj",
type=bool,
default=False,
help="use separate projection for each target",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--use-multi-node",
type=str2bool,
default=False,
help="""True if using multi-node multi-GPU.
You are not supposed to set it directly.
""",
)
parser.add_argument(
"--world-size",
type=int,
@ -481,17 +472,17 @@ def get_parser():
)
parser.add_argument(
"--lr-epochs",
"--lr-hours",
type=float,
default=10.5,
help="""Number of epochs that affects how rapidly the learning rate decreases.
default=10000,
help="""Number of hours that affects how rapidly the learning rate decreases.
""",
)
parser.add_argument(
"--warmup-batches",
type=float,
default=5000,
default=1000,
help="Eden warmup steps",
)
@ -541,7 +532,7 @@ def get_parser():
parser.add_argument(
"--save-every-n",
type=int,
default=100000,
default=10000,
help="""Save checkpoint after processing this number of batches"
periodically. We save checkpoint to exp-dir/ whenever
params.batch_idx_train % save_every_n == 0. The checkpoint filename
@ -554,7 +545,7 @@ def get_parser():
parser.add_argument(
"--keep-last-k",
type=int,
default=30,
default=100000,
help="""Only keep this number of checkpoints on disk.
For instance, if it is 3, there are only 3 checkpoints
in the exp-dir with filenames `checkpoint-xxx.pt`.
@ -578,7 +569,7 @@ def get_parser():
parser.add_argument(
"--accum-grad",
type=int,
default=4,
default=1,
help="""update gradient when batch_idx_train % accum_grad == 0.
""",
)
@ -591,17 +582,24 @@ def get_parser():
)
parser.add_argument(
"--max-sample-size",
type=float,
default=250000,
help="max sample size",
"--max-keep-size",
type=int,
default=1024000,
help="exclude sample longer than this.",
)
parser.add_argument(
"--min-sample-size",
"--min-keep-size",
type=float,
default=32000,
help="min sample size",
help="exclude sample longer less than this.",
)
parser.add_argument(
"--max-sample-size",
type=float,
default=1024000,
help="max sample size to crop to for batching.",
)
add_model_arguments(parser)
@ -953,6 +951,13 @@ def train_one_epoch(
if sub_batch_idx % params.accum_grad == params.accum_grad - 1:
params.batch_idx_train += 1
scheduler.step_batch(params.batch_idx_train)
# Use the number of hours of speech to adjust the learning rate
scheduler.step_epoch(
params.batch_idx_train
* params.max_duration
* params.world_size
/ 3600
)
scaler.step(optimizer)
scaler.update()
@ -960,10 +965,10 @@ def train_one_epoch(
else:
continue
except: # noqa
except Exception as e: # noqa
save_bad_model()
display_and_save_batch(batch, params=params)
raise
raise e
if params.print_diagnostics and batch_idx == 5:
return
@ -1064,7 +1069,7 @@ def train_one_epoch(
tb_writer, "train/valid_", params.batch_idx_train
)
if batch_idx % params.accum_grad != params.accum_grad - 1:
if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
optimizer.zero_grad()
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
@ -1089,8 +1094,15 @@ def run(rank, world_size, args):
params.update(vars(args))
fix_random_seed(params.seed)
if params.use_multi_node:
local_rank = get_local_rank()
else:
local_rank = rank
logging.info(f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}")
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_dist(rank, world_size, params.master_port, params.use_multi_node)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
@ -1102,8 +1114,8 @@ def run(rank, world_size, args):
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
device = torch.device("cuda", local_rank)
logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")
logging.info(params)
logging.info("About to create model")
@ -1126,7 +1138,7 @@ def run(rank, world_size, args):
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
optimizer = ScaledAdam(
get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
@ -1137,7 +1149,7 @@ def run(rank, world_size, args):
scheduler = Eden(
optimizer,
params.lr_batches,
params.lr_epochs,
params.lr_hours,
params.warmup_batches,
params.warmup_start,
)
@ -1165,7 +1177,7 @@ def run(rank, world_size, args):
librilight = LibriLightDataModule(args)
train_cuts = librilight.train_all_shuf_cuts()
train_cuts = librilight.all_shuf_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
@ -1177,11 +1189,11 @@ def run(rank, world_size, args):
# an utterance duration distribution for your dataset to select
# the threshold
if (
c.duration < params.min_sample_size / params.sample_rate
or c.duration > params.max_sample_size / params.sample_rate
c.duration < params.min_keep_size / params.sample_rate
or c.duration > params.max_keep_size / params.sample_rate
):
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
return False
@ -1198,6 +1210,7 @@ def run(rank, world_size, args):
train_dl = librilight.train_dataloaders(
train_cuts,
max_sample_size=params.max_sample_size,
sample_rate=params.sample_rate,
label_rate=params.label_rate,
random_crop=params.random_crop,
@ -1205,6 +1218,8 @@ def run(rank, world_size, args):
num_classes=params.num_classes,
do_normalize=params.do_normalize,
sampler_state_dict=sampler_state_dict,
world_size=world_size,
rank=rank,
)
valid_cuts = librilight.dev_clean_cuts()
@ -1213,12 +1228,15 @@ def run(rank, world_size, args):
valid_dl = librilight.valid_dataloaders(
valid_cuts,
max_sample_size=params.max_sample_size,
sample_rate=params.sample_rate,
label_rate=params.label_rate,
random_crop=params.random_crop,
pad_audio=False,
num_classes=params.num_classes,
do_normalize=params.do_normalize,
world_size=world_size,
rank=rank,
)
if params.sanity_check and not params.print_diagnostics:
@ -1235,7 +1253,6 @@ def run(rank, world_size, args):
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
scheduler.step_epoch(epoch - 1)
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
@ -1339,7 +1356,7 @@ def scan_pessimistic_batches_for_oom(
f"(={crit_values[criterion]}) ..."
)
display_and_save_batch(batch, params=params)
raise
raise e
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
@ -1351,12 +1368,18 @@ def main():
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
if args.use_multi_node:
rank = get_rank()
world_size = get_world_size()
args.world_size = world_size
run(rank=rank, world_size=world_size, args=args)
else:
run(rank=0, world_size=1, args=args)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)

View File

@ -1,5 +1,4 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang)
# Copyright 2024 Xiaomi Corporation (Author: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -24,9 +23,10 @@ from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import lhotse
import torch
from dataset import HubertDataset
from lhotse import CutSet, combine, load_manifest_lazy
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
@ -46,7 +46,7 @@ class LibriLightDataModule:
"""
DataModule for SSL experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
but there can be multiple test dataloaders (e.g. LibriLight test-clean
and test-other).
It contains all the common data pipeline modules used in SSL
@ -63,7 +63,7 @@ class LibriLightDataModule:
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR SSL related options",
title="SSL data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies.",
@ -92,10 +92,29 @@ class LibriLightDataModule:
group.add_argument(
"--num-buckets",
type=int,
default=30,
default=1000,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--num-cuts-for-bins-estimate",
type=int,
default=1000000,
help="We will draw this many cuts to estimate the duration"
"bins for creating similar-duration buckets. Larger number"
"means a better estimate to the data distribution, possibly"
"at a longer init cost.",
)
group.add_argument(
"--quadratic-duration",
type=float,
default=None,
help="When set, it adds an extra penalty that's quadratic"
"in size w.r.t. a cuts duration. This helps get a more"
"even GPU utilization across different input lengths when"
"models have quadratic input complexity. Set between 15"
"and 40 for transformers.",
)
group.add_argument(
"--shuffle",
type=str2bool,
@ -112,7 +131,7 @@ class LibriLightDataModule:
group.add_argument(
"--num-workers",
type=int,
default=2,
default=8,
help="The number of training dataloader workers that "
"collect the batches.",
)
@ -126,12 +145,13 @@ class LibriLightDataModule:
"--random-crop",
type=str2bool,
default=True,
help="audio sample rate",
help="always crop from the beginning if false",
)
def train_dataloaders(
self,
cuts_train: CutSet,
max_sample_size: Optional[int] = None,
sample_rate: float = 16000,
label_rate: float = 50,
random_crop: bool = True,
@ -139,6 +159,8 @@ class LibriLightDataModule:
num_classes: list = [504],
do_normalize: bool = True,
sampler_state_dict: Optional[Dict[str, Any]] = None,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
"""
Args:
@ -149,6 +171,7 @@ class LibriLightDataModule:
"""
logging.info("About to create train dataset")
train = HubertDataset(
max_sample_size=max_sample_size,
sample_rate=sample_rate,
label_rate=label_rate,
random_crop=random_crop,
@ -162,9 +185,14 @@ class LibriLightDataModule:
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
quadratic_duration=self.args.quadratic_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
num_cuts_for_bins_estimate=self.args.num_cuts_for_bins_estimate,
drop_last=self.args.drop_last,
world_size=world_size,
rank=rank,
)
else:
logging.info("Using SimpleCutSampler.")
@ -172,6 +200,8 @@ class LibriLightDataModule:
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
world_size=world_size,
rank=rank,
)
logging.info("About to create train dataloader")
@ -198,15 +228,19 @@ class LibriLightDataModule:
def valid_dataloaders(
self,
cuts_valid: CutSet,
max_sample_size: Optional[int] = None,
sample_rate: float = 16000,
label_rate: float = 50,
random_crop: bool = True,
pad_audio: bool = False,
num_classes: list = [504],
do_normalize: bool = True,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
logging.info("About to create dev dataset")
validate = HubertDataset(
max_sample_size=max_sample_size,
sample_rate=sample_rate,
label_rate=label_rate,
random_crop=random_crop,
@ -217,7 +251,10 @@ class LibriLightDataModule:
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
quadratic_duration=self.args.quadratic_duration,
shuffle=False,
world_size=world_size,
rank=rank,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
@ -230,81 +267,11 @@ class LibriLightDataModule:
return valid_dl
def test_dataloaders(
self,
cuts: CutSet,
sample_rate: float = 16000,
label_rate: float = 50,
random_crop: bool = True,
pad_audio: bool = False,
num_classes: list = [504],
do_normalize: bool = True,
) -> DataLoader:
logging.debug("About to create test dataset")
test = HubertDataset(
sample_rate=sample_rate,
label_rate=label_rate,
random_crop=random_crop,
pad_audio=pad_audio,
num_classes=num_classes,
do_normalize=do_normalize,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def small_cuts(self) -> CutSet:
logging.info("About to get small cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librilight_cuts_small.jsonl.gz"
)
@lru_cache()
def medium_cuts(self) -> CutSet:
logging.info("About to get medium cuts")
filenames = glob.glob(
f"{self.args.manifest_dir}/medium_splits/librilight_cuts_medium.*.jsonl.gz"
)
pattern = re.compile(r"librilight_cuts_medium.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
def all_shuf_cuts(self) -> CutSet:
logging.info(
f"Loading LibriLight medium {len(sorted_filenames)} splits in lazy mode"
"About to get the shuffled librilight small, medium and large cuts"
)
return combine(load_manifest_lazy(p) for p in sorted_filenames)
@lru_cache()
def large_cuts(self) -> CutSet:
logging.info("About to get large cuts")
filenames = glob.glob(
f"{self.args.manifest_dir}/large_splits/librilight_cuts_large.*.jsonl.gz"
)
pattern = re.compile(r"librilight_cuts_large.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
logging.info(
f"Loading LibriLight large {len(sorted_filenames)} splits in lazy mode"
)
return combine(load_manifest_lazy(p) for p in sorted_filenames)
@lru_cache()
def train_all_shuf_cuts(self) -> CutSet:
logging.info("About to get the shuffled small, medium and large cuts")
small_cuts = self.small_cuts()
medium_cuts = self.medium_cuts()
large_cuts = self.large_cuts()
@ -313,22 +280,60 @@ class LibriLightDataModule:
medium_cuts,
large_cuts,
weights=[
122867, # len(small_cuts)
1104071, # len(medium_cuts)
11012085, # len(large_cuts)
229051, # len(small_cuts)
2022949, # len(medium_cuts)
19883414, # len(large_cuts)
],
)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
logging.info("About to get librispeech dev-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
)
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
def small_cuts(self) -> CutSet:
logging.info("About to get librilight small cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
self.args.manifest_dir / "librilight_cuts_small.jsonl.gz"
)
@lru_cache()
def medium_cuts(self) -> CutSet:
logging.info("About to get librilight medium cuts")
filenames = glob.glob(
str(
self.args.manifest_dir
/ "medium_split"
/ "librilight_cuts_medium.*.jsonl.gz"
)
)
pattern = re.compile(r"librilight_cuts_medium.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
logging.info(
f"Loading Libri-Light medium {len(sorted_filenames)} splits in lazy mode"
)
return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames)
@lru_cache()
def large_cuts(self) -> CutSet:
logging.info("About to get librilight large cuts")
filenames = glob.glob(
str(
self.args.manifest_dir
/ "large_split"
/ "librilight_cuts_large.*.jsonl.gz"
)
)
pattern = re.compile(r"librilight_cuts_large.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
logging.info(
f"Loading Libri-Light large {len(sorted_filenames)} splits in lazy mode"
)
return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames)

View File

@ -132,6 +132,8 @@ class LibriSpeechAsrDataModule:
cuts_train: CutSet,
do_normalize: bool,
sampler_state_dict: Optional[Dict[str, Any]] = None,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
"""
Args:
@ -150,7 +152,10 @@ class LibriSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
drop_last=self.args.drop_last,
world_size=world_size,
rank=rank,
)
else:
logging.info("Using SimpleCutSampler.")
@ -158,6 +163,8 @@ class LibriSpeechAsrDataModule:
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
world_size=world_size,
rank=rank,
)
logging.info("About to create train dataloader")
@ -181,13 +188,21 @@ class LibriSpeechAsrDataModule:
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet, do_normalize: bool) -> DataLoader:
def valid_dataloaders(
self,
cuts_valid: CutSet,
do_normalize: bool,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
logging.info("About to create dev dataset")
validate = HubertAsrDataset(do_normalize=do_normalize)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
world_size=world_size,
rank=rank,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(

0
egs/librispeech/SSL/hubert/decode.py Normal file → Executable file
View File

0
egs/librispeech/SSL/hubert/decode_ce.py Normal file → Executable file
View File

4
egs/librispeech/SSL/hubert/finetune.py Normal file → Executable file
View File

@ -1091,6 +1091,8 @@ def run(rank, world_size, args):
train_cuts,
do_normalize=params.do_normalize,
sampler_state_dict=sampler_state_dict,
world_size=world_size,
rank=rank,
)
valid_cuts = librispeech.dev_clean_cuts()
@ -1099,6 +1101,8 @@ def run(rank, world_size, args):
valid_dl = librispeech.valid_dataloaders(
valid_cuts,
do_normalize=params.do_normalize,
world_size=world_size,
rank=rank,
)
if params.sanity_check and not params.print_diagnostics:

4
egs/librispeech/SSL/hubert/finetune_ce.py Normal file → Executable file
View File

@ -1091,6 +1091,8 @@ def run(rank, world_size, args):
train_cuts,
do_normalize=params.do_normalize,
sampler_state_dict=sampler_state_dict,
world_size=world_size,
rank=rank,
)
valid_cuts = librispeech.dev_clean_cuts()
@ -1099,6 +1101,8 @@ def run(rank, world_size, args):
valid_dl = librispeech.valid_dataloaders(
valid_cuts,
do_normalize=params.do_normalize,
world_size=world_size,
rank=rank,
)
if params.sanity_check and not params.print_diagnostics:

0
egs/librispeech/SSL/hubert/pretrain.py Normal file → Executable file
View File

0
egs/librispeech/SSL/hubert/pretrain_ce.py Normal file → Executable file
View File

View File

@ -144,6 +144,8 @@ class LibriSpeechDataModule:
num_classes: list = [504],
do_normalize: bool = True,
sampler_state_dict: Optional[Dict[str, Any]] = None,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
"""
Args:
@ -170,7 +172,10 @@ class LibriSpeechDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
drop_last=self.args.drop_last,
world_size=world_size,
rank=rank,
)
else:
logging.info("Using SimpleCutSampler.")
@ -178,6 +183,8 @@ class LibriSpeechDataModule:
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
world_size=world_size,
rank=rank,
)
logging.info("About to create train dataloader")
@ -211,6 +218,8 @@ class LibriSpeechDataModule:
pad_audio: bool = False,
num_classes: list = [504],
do_normalize: bool = True,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
logging.info("About to create dev dataset")
validate = HubertDataset(
@ -226,6 +235,8 @@ class LibriSpeechDataModule:
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
world_size=world_size,
rank=rank,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(

0
egs/librispeech/SSL/zipformer/decode.py Normal file → Executable file
View File

4
egs/librispeech/SSL/zipformer/finetune.py Normal file → Executable file
View File

@ -1388,6 +1388,8 @@ def run(rank, world_size, args):
train_cuts,
do_normalize=params.do_normalize,
sampler_state_dict=sampler_state_dict,
world_size=world_size,
rank=rank,
)
valid_cuts = librispeech.dev_clean_cuts()
@ -1396,6 +1398,8 @@ def run(rank, world_size, args):
valid_dl = librispeech.valid_dataloaders(
valid_cuts,
do_normalize=params.do_normalize,
world_size=world_size,
rank=rank,
)
if params.sanity_check and not params.print_diagnostics:

View File

@ -296,7 +296,6 @@ class HubertModel(nn.Module):
self.layer_norm = LayerNorm(self.embed)
self.untie_final_proj = cfg.untie_final_proj
self.final_proj = nn.Linear(encoder_output_dim, sum(cfg.num_classes))
# modules below are not needed during fine-tuning

View File

@ -154,9 +154,9 @@ class AsrModel(nn.Module):
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
targets=targets,
input_lengths=encoder_out_lens,
target_lengths=target_lengths,
targets=targets.cpu(),
input_lengths=encoder_out_lens.cpu(),
target_lengths=target_lengths.cpu(),
reduction="sum",
)
return ctc_loss

7
egs/librispeech/SSL/zipformer/pretrain.py Normal file → Executable file
View File

@ -41,7 +41,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
import argparse
import copy
import logging
import sys
import warnings
from pathlib import Path
from shutil import copyfile
@ -595,7 +594,7 @@ def get_parser():
parser.add_argument(
"--max-keep-size",
type=int,
default=sys.maxsize,
default=320000,
help="exclude sample longer than this.",
)
@ -1219,6 +1218,8 @@ def run(rank, world_size, args):
num_classes=params.num_classes,
do_normalize=params.do_normalize,
sampler_state_dict=sampler_state_dict,
world_size=world_size,
rank=rank,
)
valid_cuts = librispeech.dev_clean_cuts()
@ -1234,6 +1235,8 @@ def run(rank, world_size, args):
pad_audio=False,
num_classes=params.num_classes,
do_normalize=params.do_normalize,
world_size=world_size,
rank=rank,
)
if params.sanity_check and not params.print_diagnostics:

View File

@ -0,0 +1 @@
../zipformer/asr_datamodule.py

View File

@ -0,0 +1 @@
../zipformer/beam_search.py

View File

@ -0,0 +1,823 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Liyong Guo,
# Quandong Wang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) ctc-decoding
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method ctc-decoding
(2) 1best
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--hlg-scale 0.6 \
--decoding-method 1best
(3) nbest
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--hlg-scale 0.6 \
--decoding-method nbest
(4) nbest-rescoring
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--hlg-scale 0.6 \
--nbest-scale 1.0 \
--lm-dir data/lm \
--decoding-method nbest-rescoring
(5) whole-lattice-rescoring
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--hlg-scale 0.6 \
--nbest-scale 1.0 \
--lm-dir data/lm \
--decoding-method whole-lattice-rescoring
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from finetune_ctc import add_model_arguments, get_model, get_params
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
    """Build the command-line parser for the CTC decoding script.

    The parser registers:
      - checkpoint selection options: --epoch, --iter, --avg,
        --use-averaged-model
      - directories: --exp-dir, --lang-dir, --lm-dir
      - decoding options: --context-size, --decoding-method, --num-paths,
        --nbest-scale, --hlg-scale
      - the model options added by `add_model_arguments`.

    Returns:
      The configured `argparse.ArgumentParser`. Defaults are shown in the
      help output because `ArgumentDefaultsHelpFormatter` is used.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_char",
        help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    # The help text below describes what this script does for each method.
    # In particular, ctc-decoding maps token IDs through the token table of
    # the lexicon loaded from --lang-dir; it does not use a BPE model.
    parser.add_argument(
        "--decoding-method",
        type=str,
        default="ctc-decoding",
        help="""Decoding method.
Supported values are:
- (1) ctc-decoding. Use CTC decoding with the CTC topology H.
Token IDs are mapped to text with the token table of the
lexicon in --lang-dir. It needs neither HLG nor an n-gram LM.
- (2) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (3) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (4) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result.
- (5) whole-lattice-rescoring. Rescore the decoding lattice with an
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
is the decoding result.
- (6) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring method.
""",
    )

    parser.add_argument(
        "--num-paths",
        type=int,
        default=100,
        help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, and nbest-oracle
""",
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=1.0,
        help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, and nbest-oracle
A smaller value results in more unique paths.
""",
    )

    parser.add_argument(
        "--hlg-scale",
        type=float,
        default=0.6,
        help="""The scale to be applied to `hlg.scores`.
""",
    )

    parser.add_argument(
        "--lm-dir",
        type=str,
        default="data/lm",
        help="""The n-gram LM dir.
It should contain either G_4_gram.pt or G_4_gram.fst.txt
""",
    )

    add_model_arguments(parser)

    return parser
def get_decoding_params() -> AttributeDict:
    """Return the fixed hyper-parameters used while decoding.

    In this file, the beam and active-state limits are forwarded to
    `get_lattice`, and `use_double_scores` is forwarded to the best-path
    helpers. `frame_shift_ms` is stored but not read in this file.

    Returns:
      An `AttributeDict` holding the decoding defaults.
    """
    defaults = dict(
        frame_shift_ms=10,
        search_beam=20,
        output_beam=8,
        min_active_states=30,
        max_active_states=10000,
        use_double_scores=True,
    )
    return AttributeDict(defaults)
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    lexicon: Lexicon,
    graph_compiler: CharCtcTrainingGraphCompiler,
    batch: dict,
    word_table: k2.SymbolTable,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

        - key: It indicates the setting used for decoding. In this function
               the key is `ctc-decoding` for CTC decoding, `no_rescore` for
               1best, a string starting with `no_rescore-nbest-scale-` for
               nbest and a string starting with `oracle_` for nbest-oracle.
               If LM rescoring is used, the keys are those returned by the
               rescoring helper; an example key is `lm_scale_0.7`.
        - value: It contains the decoding result. `len(value)` equals to
                 batch size. `value[i]` is the decoding result for the i-th
                 utterance in the given batch.

    Args:
      params:
        It's the return value of :func:`get_params`, updated with the
        decoding parameters and the command-line arguments.

        - params.decoding_method is "ctc-decoding", it decodes with H only.
        - params.decoding_method is "1best", it uses 1best decoding without LM rescoring.
        - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring.
        - params.decoding_method is "nbest-oracle", it picks the n-best path
          closest to the reference text.
        - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring.
        - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM
          rescoring.

      model:
        The neural model. It must provide `forward_encoder` and `ctc_output`.
      HLG:
        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
      H:
        The ctc topo. Used only when params.decoding_method is ctc-decoding.
        Exactly one of `HLG` and `H` must be None.
      lexicon:
        Its `token_table` converts token IDs to symbols for ctc-decoding.
      graph_compiler:
        Not used in this function.
      batch:
        A dict from which this function reads `batch["audio"]`,
        `batch["padding_mask"]` and, for nbest-oracle only,
        `batch["supervisions"]["text"]`.
      word_table:
        The word symbol table.
      G:
        An LM. It is not None when params.decoding_method is "nbest-rescoring"
        or "whole-lattice-rescoring". In general, the G in HLG
        is a 3-gram LM, while this G is a 4-gram LM.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict. Note: If it decodes to nothing, then return None.
    """
    # Run on the device that holds the decoding graph.
    if HLG is not None:
        device = HLG.device
    else:
        device = H.device

    audio = batch["audio"].to(device)
    padding_mask = batch["padding_mask"].to(device)
    encoder_out, encoder_out_lens = model.forward_encoder(audio, padding_mask)
    ctc_output = model.ctc_output(encoder_out)

    num_frames = encoder_out_lens.cpu()

    # One row per utterance: (index in the batch, start frame 0, number of
    # encoder output frames).
    supervision_segments = torch.stack(
        (
            torch.arange(audio.shape[0], dtype=torch.int32),
            torch.zeros_like(num_frames, dtype=torch.int32),
            num_frames,
        ),
        1,
    ).to(torch.int32)

    # Exactly one of H and HLG is expected to be given.
    if H is None:
        assert HLG is not None
        decoding_graph = HLG
    else:
        assert HLG is None
        decoding_graph = H

    lattice = get_lattice(
        nnet_output=ctc_output,
        decoding_graph=decoding_graph,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
    )

    if params.decoding_method == "ctc-decoding":
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        # Note: `best_path.aux_labels` contains token IDs, not word IDs
        # since we are using H, not HLG here.
        #
        # token_ids is a list-of-list of IDs
        token_ids = get_texts(best_path)

        # Join the token symbols of each utterance and turn the "|" symbol
        # into a space.
        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
        hyps = [
            "".join(lexicon.token_table[idx].replace("|", " ") for idx in token_id)
            for token_id in token_ids
        ]

        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
        hyps = [s.split() for s in hyps]
        key = "ctc-decoding"
        return {key: hyps}

    if params.decoding_method == "nbest-oracle":
        # Note: You can also pass rescored lattices to it.
        # We choose the HLG decoded lattice for speed reasons
        # as HLG decoding is faster and the oracle WER
        # is only slightly worse than that of rescored lattices.
        best_path = nbest_oracle(
            lattice=lattice,
            num_paths=params.num_paths,
            ref_texts=batch["supervisions"]["text"],
            word_table=word_table,
            nbest_scale=params.nbest_scale,
            oov="<UNK>",
        )
        hyps = get_texts(best_path)
        hyps = [[word_table[i] for i in ids] for ids in hyps]
        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
        return {key: hyps}

    if params.decoding_method in ["1best", "nbest"]:
        if params.decoding_method == "1best":
            best_path = one_best_decoding(
                lattice=lattice, use_double_scores=params.use_double_scores
            )
            key = "no_rescore"
        else:
            best_path = nbest_decoding(
                lattice=lattice,
                num_paths=params.num_paths,
                use_double_scores=params.use_double_scores,
                nbest_scale=params.nbest_scale,
            )
            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa

        hyps = get_texts(best_path)
        hyps = [[word_table[i] for i in ids] for ids in hyps]
        return {key: hyps}

    # Only the two LM rescoring methods remain from here on.
    assert params.decoding_method in [
        "nbest-rescoring",
        "whole-lattice-rescoring",
    ]

    # LM scales to try: 0.1, 0.2, ..., 2.0
    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]

    if params.decoding_method == "nbest-rescoring":
        best_path_dict = rescore_with_n_best_list(
            lattice=lattice,
            G=G,
            num_paths=params.num_paths,
            lm_scale_list=lm_scale_list,
            nbest_scale=params.nbest_scale,
        )
    elif params.decoding_method == "whole-lattice-rescoring":
        best_path_dict = rescore_with_whole_lattice(
            lattice=lattice,
            G_with_epsilon_loops=G,
            lm_scale_list=lm_scale_list,
        )
    else:
        # Not reachable because of the assertion above; kept as a guard.
        assert False, f"Unsupported decoding method: {params.decoding_method}"

    ans = dict()
    if best_path_dict is not None:
        for lm_scale_str, best_path in best_path_dict.items():
            hyps = get_texts(best_path)
            hyps = [[word_table[i] for i in ids] for ids in hyps]
            ans[lm_scale_str] = hyps
    else:
        # The rescoring helper produced no result for this batch.
        ans = None
    return ans
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    lexicon: Lexicon,
    graph_compiler: CharCtcTrainingGraphCompiler,
    word_table: k2.SymbolTable,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode. Each batch
        must provide `batch["supervisions"]["text"]` and `batch["cuts"]`.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      HLG:
        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
      H:
        The ctc topo. Used only when params.decoding_method is ctc-decoding.
      lexicon:
        Forwarded to :func:`decode_one_batch`.
      graph_compiler:
        Forwarded to :func:`decode_one_batch`.
      word_table:
        It is the word symbol table.
      G:
        An LM. It is not None when params.decoding_method is "nbest-rescoring"
        or "whole-lattice-rescoring". In general, the G in HLG
        is a 3-gram LM, while this G is a 4-gram LM.
    Returns:
      Return a dict, whose key is the decoding setting reported by
      :func:`decode_one_batch` (e.g. "no_rescore" if no LM rescoring
      is used, or "lm_scale_0.7" if LM rescoring is used).
      Its value is a list of tuples. Each tuple contains three elements:
      the cut ID, the reference transcript as a list of words and the
      predicted result as a list of words.
    Raises:
      RuntimeError:
        If the very first batch decodes to nothing, because the decoding
        settings (dict keys) are not known yet at that point.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        # The dataloader has no length; show a placeholder in the log.
        num_batches = "?"

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["cuts"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            lexicon=lexicon,
            graph_compiler=graph_compiler,
            batch=batch,
            word_table=word_table,
            G=G,
        )

        if hyps_dict is None:
            # decode_one_batch returns None when LM rescoring yields no
            # result. Record an empty hypothesis for every utterance so the
            # batch still counts towards the error statistics.
            if not results:
                raise RuntimeError(
                    "It should not decode to empty in the first batch!"
                )
            for name in list(results.keys()):
                results[name].extend(
                    (cut_id, ref_text.split(), [])
                    for cut_id, ref_text in zip(cut_ids, texts)
                )
        else:
            for name, hyps in hyps_dict.items():
                assert len(hyps) == len(texts)
                results[name].extend(
                    (cut_id, ref_text.split(), hyp_words)
                    for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts)
                )

        num_cuts += len(texts)

        if batch_idx % 100 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
    return results
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Write transcripts, error statistics and a WER summary for a test set.

    For every decoding setting in `results_dict` this writes a `recogs-*`
    transcript file and an `errs-*` error statistics file into
    `params.res_dir`. A single `wer-summary-*` file then lists the WER of
    all settings, sorted from best to worst, and the same table is logged.

    Args:
      params:
        Must provide `res_dir` (a path supporting the `/` operator) and
        `suffix` (a string appended to every file name).
      test_set_name:
        Name of the test set; used in file names and log messages.
      results_dict:
        Maps a decoding setting to its list of
        (cut_id, reference words, hypothesis words) tuples.
    """
    # With a single setting the file names stay as they were. With several
    # settings (e.g. one per LM scale when rescoring) the setting is added to
    # the name; otherwise each setting would overwrite the files written for
    # the previous one.
    multiple_settings = len(results_dict) > 1

    test_set_wers = dict()
    for key, results in results_dict.items():
        tag = f"{test_set_name}-{key}" if multiple_settings else test_set_name

        recog_path = params.res_dir / f"recogs-{tag}-{params.suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = params.res_dir / f"errs-{tag}-{params.suffix}.txt"
        with open(errs_filename, "w") as f:
            wer = write_error_stats(f, f"{test_set_name}-{key}", results)
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    # Best (lowest) WER first.
    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    # Only the first (best) row carries the "best for ..." note.
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)
@torch.no_grad()
def main():
    """Entry point: decode the LibriSpeech dev sets with a CTC model.

    Steps, in order:
      1. Parse the command line and merge it with the model and decoding
         parameters.
      2. Set up logging under `exp_dir/<decoding_method>`.
      3. Build the CTC topology H (ctc-decoding) or load HLG.pt (all other
         methods); for the two rescoring methods also load or compile the
         4-gram G.
      4. Create the model and load averaged checkpoint weights.
      5. Decode dev-clean and dev-other and save the results.

    The whole function runs with gradient tracking disabled.
    """
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)
    args.lm_dir = Path(args.lm_dir)

    params = get_params()
    # add decoding params
    params.update(get_decoding_params())
    # Command-line values are applied last, so they override the two above.
    params.update(vars(args))

    assert params.decoding_method in (
        "ctc-decoding",
        "1best",
        "nbest",
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "nbest-oracle",
    )
    params.res_dir = params.exp_dir / params.decoding_method

    # The suffix identifies the checkpoint selection in all output file names.
    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")
    logging.info(params)

    lexicon = Lexicon(params.lang_dir)
    max_token_id = max(lexicon.tokens)
    graph_compiler = CharCtcTrainingGraphCompiler(
        lexicon=lexicon,
        device=device,
    )

    params.blank_id = lexicon.token_table["<blk>"]
    params.vocab_size = max(lexicon.tokens) + 1

    if params.decoding_method == "ctc-decoding":
        HLG = None
        H = k2.ctc_topo(
            max_token=max_token_id,
            modified=False,
            device=device,
        )
    else:
        H = None
        HLG = k2.Fsa.from_dict(
            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
        )
        assert HLG.requires_grad is False

        HLG.scores *= params.hlg_scale
        if not hasattr(HLG, "lm_scores"):
            HLG.lm_scores = HLG.scores.clone()

    if params.decoding_method in (
        "nbest-rescoring",
        "whole-lattice-rescoring",
    ):
        # Build G from the text FST only when the cached G_4_gram.pt is
        # missing, then cache it for later runs.
        if not (params.lm_dir / "G_4_gram.pt").is_file():
            logging.info("Loading G_4_gram.fst.txt")
            logging.warning("It may take 8 minutes.")
            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
                first_word_disambig_id = lexicon.word_table["#0"]

                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
                # G.aux_labels is not needed in later computations, so
                # remove it here.
                del G.aux_labels
                # CAUTION: The following line is crucial.
                # Arcs entering the back-off state have label equal to #0.
                # We have to change it to 0 here.
                G.labels[G.labels >= first_word_disambig_id] = 0
                # See https://github.com/k2-fsa/k2/issues/874
                # for why we need to set G.properties to None
                G.__dict__["_properties"] = None
                G = k2.Fsa.from_fsas([G]).to(device)
                G = k2.arc_sort(G)
                # Save a dummy value so that it can be loaded in C++.
                # See https://github.com/pytorch/pytorch/issues/67902
                # for why we need to do this.
                G.dummy = 1

                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
        else:
            logging.info("Loading pre-compiled G_4_gram.pt")
            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
            G = k2.Fsa.from_dict(d)

        if params.decoding_method == "whole-lattice-rescoring":
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
            G = k2.arc_sort(G)
            G = G.to(device)

        # G.lm_scores is used to replace HLG.lm_scores during
        # LM rescoring.
        G.lm_scores = G.scores.clone()
    else:
        G = None

    logging.info("About to create model")
    model = get_model(params)

    if not params.use_averaged_model:
        # Plain averaging of the parameters of several saved checkpoints.
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                # Epochs count from 1, so skip non-positive indices.
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        # Use the averaged model stored in the checkpoints: only the two
        # end-point checkpoints of the range are passed to the helper.
        if params.iter > 0:
            # One extra checkpoint is needed because the start of the range
            # is excluded.
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    librispeech = LibriSpeechAsrDataModule(args)

    dev_clean_cuts = librispeech.dev_clean_cuts()
    dev_other_cuts = librispeech.dev_other_cuts()

    dev_clean_dl = librispeech.test_dataloaders(
        dev_clean_cuts,
        do_normalize=params.do_normalize,
    )
    dev_other_dl = librispeech.test_dataloaders(
        dev_other_cuts,
        do_normalize=params.do_normalize,
    )

    test_sets = ["dev-clean", "dev-other"]
    test_dl = [dev_clean_dl, dev_other_dl]

    # test_clean_cuts = librispeech.test_clean_cuts()
    # test_other_cuts = librispeech.test_other_cuts()

    # test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
    # test_other_dl = librispeech.test_dataloaders(test_other_cuts)

    # test_sets = ["test-clean", "test-other"]
    # test_dl = [test_clean_dl, test_other_dl]

    # NOTE: the loop variable `test_dl` rebinds the name of the list above.
    # This works because zip() receives the list before the first rebinding.
    for test_set, test_dl in zip(test_sets, test_dl):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            lexicon=lexicon,
            graph_compiler=graph_compiler,
            word_table=lexicon.word_table,
            G=G,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")


# Run the decoder only when this file is executed as a script.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1 @@
../zipformer/dataset.py

View File

@ -0,0 +1 @@
../zipformer/encoder_interface.py

View File

@ -1,12 +1,10 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,
# Zengwei Yao,
# Yifan Yang,
# Daniel Povey)
#
# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,
# Zengwei Yao,
# Yifan Yang,
# Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -27,19 +25,14 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
# For HuBERT model finetuning:
./hubert/finetune.py \
./zipformer_ctc/finetune_ctc.py \
--world-size 8 \
--num-epochs 200 \
--num-epochs 222 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir hubert/exp \
--exp-dir zipformer_ctc/exp \
--full-libri 0 \
--max-duration 1000
It supports finetuning with:
- transducer loss (default), with `--use-transducer True --use-ctc False`
- ctc loss (not recommended), with `--use-transducer False --use-ctc True`
- transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
--max-duration 600
"""
@ -58,9 +51,7 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from decoder import Decoder
from hubert_ce import HubertModel
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
@ -72,6 +63,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from icefall import diagnostics
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
@ -81,6 +73,7 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
@ -415,37 +408,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="use separate projection for each target",
)
parser.add_argument(
"--decoder-dim",
type=int,
default=512,
help="Embedding dimension in the decoder model.",
)
parser.add_argument(
"--joiner-dim",
type=int,
default=512,
help="""Dimension used in the joiner model.
Outputs from the encoder and decoder model are projected
to this dimension before adding.
""",
)
parser.add_argument(
"--use-transducer",
type=str2bool,
default=True,
help="If True, use Transducer head.",
)
parser.add_argument(
"--use-ctc",
type=str2bool,
default=False,
help="If True, use CTC head.",
)
def get_parser():
parser = argparse.ArgumentParser(
@ -509,6 +471,16 @@ def get_parser():
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--pretrained-dir",
type=str,
@ -516,13 +488,6 @@ def get_parser():
It specifies the directory where the pretrained checkpoint is saved.""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--base-lr", type=float, default=0.001, help="The base learning rate."
)
@ -551,53 +516,6 @@ def get_parser():
"schedules inside the model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--prune-range",
type=int,
default=5,
help="The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss",
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.25,
help="The scale to smooth the loss with lm "
"(output of prediction network) part.",
)
parser.add_argument(
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)" "part.",
)
parser.add_argument(
"--simple-loss-scale",
type=float,
default=0.5,
help="To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss.",
)
parser.add_argument(
"--ctc-loss-scale",
type=float,
default=0.2,
help="Scale for CTC loss.",
)
parser.add_argument(
"--seed",
type=int,
@ -720,8 +638,6 @@ def get_params() -> AttributeDict:
- valid_interval: Run validation if batch_idx % valid_interval is 0
- warm_step: The warmup period that dictates the decay of the
scale on "simple" (un-pruned) loss.
"""
params = AttributeDict(
{
@ -734,8 +650,6 @@ def get_params() -> AttributeDict:
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800
# parameters for pruned RNN-T loss
"warm_step": 2000,
"env_info": get_env_info(),
}
)
@ -758,51 +672,12 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
decoder_dim=params.decoder_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
)
return joiner
def get_model(params: AttributeDict) -> nn.Module:
assert params.use_transducer or params.use_ctc, (
f"At least one of them should be True, "
f"but got params.use_transducer={params.use_transducer}, "
f"params.use_ctc={params.use_ctc}"
)
encoder = get_encoder_model(params)
if params.use_transducer:
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
else:
decoder = None
joiner = None
model = AsrModel(
encoder=encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
decoder_dim=params.decoder_dim,
vocab_size=params.vocab_size,
use_transducer=params.use_transducer,
use_ctc=params.use_ctc,
)
return model
@ -926,7 +801,7 @@ def save_checkpoint(
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
sp: spm.SentencePieceProcessor,
graph_compiler: CharCtcTrainingGraphCompiler,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
@ -953,44 +828,19 @@ def compute_loss(
padding_mask = batch["padding_mask"].to(device)
batch_idx_train = params.batch_idx_train
warm_step = params.warm_step
texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y)
y = graph_compiler.texts_to_ids(texts, sep="|")
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss, num_frames = model(
ctc_loss, num_frames = model(
x=audio,
padding_mask=padding_mask,
y=y,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
loss = 0.0
if params.use_transducer:
s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
if params.use_ctc:
loss += params.ctc_loss_scale * ctc_loss
assert loss.requires_grad == is_training
assert ctc_loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
@ -998,20 +848,15 @@ def compute_loss(
info["frames"] = num_frames.sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
if params.use_transducer:
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.use_ctc:
info["ctc_loss"] = ctc_loss.detach().cpu().item()
info["ctc_loss"] = ctc_loss.detach().cpu().item()
return loss, info
return ctc_loss, info
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
sp: spm.SentencePieceProcessor,
graph_compiler: CharCtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
@ -1024,7 +869,7 @@ def compute_validation_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
graph_compiler=graph_compiler,
batch=batch,
is_training=False,
)
@ -1047,7 +892,7 @@ def train_one_epoch(
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
sp: spm.SentencePieceProcessor,
graph_compiler: CharCtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
@ -1120,7 +965,7 @@ def train_one_epoch(
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
graph_compiler=graph_compiler,
batch=batch,
is_training=True,
)
@ -1143,7 +988,7 @@ def train_one_epoch(
except: # noqa
save_bad_model()
display_and_save_batch(batch, params=params, sp=sp)
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
raise
if params.print_diagnostics and batch_idx == 5:
@ -1232,7 +1077,7 @@ def train_one_epoch(
valid_info = compute_validation_loss(
params=params,
model=model,
sp=sp,
graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
)
@ -1246,7 +1091,7 @@ def train_one_epoch(
tb_writer, "train/valid_", params.batch_idx_train
)
if batch_idx % params.accum_grad != params.accum_grad - 1:
if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
optimizer.zero_grad()
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
@ -1287,15 +1132,14 @@ def run(rank, world_size, args):
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
lexicon = Lexicon(params.lang_dir)
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
if not params.use_transducer:
params.ctc_loss_scale = 1.0
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
@ -1388,6 +1232,8 @@ def run(rank, world_size, args):
train_cuts,
do_normalize=params.do_normalize,
sampler_state_dict=sampler_state_dict,
world_size=world_size,
rank=rank,
)
valid_cuts = librispeech.dev_clean_cuts()
@ -1396,6 +1242,8 @@ def run(rank, world_size, args):
valid_dl = librispeech.valid_dataloaders(
valid_cuts,
do_normalize=params.do_normalize,
world_size=world_size,
rank=rank,
)
if params.sanity_check and not params.print_diagnostics:
@ -1403,7 +1251,7 @@ def run(rank, world_size, args):
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
graph_compiler=graph_compiler,
params=params,
)
@ -1428,7 +1276,7 @@ def run(rank, world_size, args):
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
sp=sp,
graph_compiler=graph_compiler,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
@ -1462,7 +1310,7 @@ def run(rank, world_size, args):
def display_and_save_batch(
batch: dict,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
graph_compiler: CharCtcTrainingGraphCompiler,
) -> None:
"""Display the batch statistics and save the batch into disk.
@ -1472,8 +1320,6 @@ def display_and_save_batch(
for the content in it.
params:
Parameters for training. See :func:`get_params`.
sp:
The BPE model.
"""
from lhotse.utils import uuid4
@ -1484,7 +1330,7 @@ def display_and_save_batch(
audio = batch["audio"]
logging.info(f"audio shape: {audio.shape}")
y = sp.encode(batch["supervisions"]["text"], out_type=int)
y = graph_compiler.texts_to_ids(batch["supervisions"]["text"], sep="|")
num_tokens = sum(len(i) for i in y)
logging.info(f"num tokens: {num_tokens}")
@ -1493,7 +1339,7 @@ def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
graph_compiler: CharCtcTrainingGraphCompiler,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
@ -1509,7 +1355,7 @@ def scan_pessimistic_batches_for_oom(
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
graph_compiler=graph_compiler,
batch=batch,
is_training=True,
)
@ -1524,7 +1370,7 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
display_and_save_batch(batch, params=params, sp=sp)
display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
raise
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"

View File

@ -0,0 +1 @@
../zipformer/hubert_ce.py

View File

@ -0,0 +1,153 @@
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Zengwei Yao,
# Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import k2
import torch
import torch.nn as nn
from scaling import ScaledLinear
from icefall.utils import add_sos
class AsrModel(nn.Module):
    def __init__(
        self,
        encoder,
        encoder_dim: int = 768,
        vocab_size: int = 500,
    ):
        """CTC ASR model.

        - Connectionist temporal classification: labelling unsegmented
          sequence data with recurrent neural networks
          (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)

        Args:
          encoder:
            It is the transcription network in the paper. It must provide
            `extract_features(source, padding_mask, mask)`, which is called
            with `source` of shape (N, T) and returns a pair
            `(encoder_out, padding_mask)`; `encoder_out` is fed to the CTC
            head, so its last dimension must be `encoder_dim`.
          encoder_dim:
            Input dimension of the linear layer in the CTC head.
          vocab_size:
            Output dimension of the CTC head, i.e. the number of classes
            over which the log-softmax is taken.
        """
        super().__init__()

        self.encoder = encoder

        # Modules for CTC head
        self.ctc_output = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(encoder_dim, vocab_size),
            nn.LogSoftmax(dim=-1),
        )

    def forward_encoder(
        self,
        x: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute encoder outputs.

        Args:
          x:
            A 2-D tensor of shape (N, T).
          padding_mask:
            Optional boolean tensor with the same shape as `x`, where True
            marks padded positions. If None, an all-False mask is used,
            i.e. nothing is treated as padding.

        Returns:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
        """
        if padding_mask is None:
            padding_mask = torch.zeros_like(x, dtype=torch.bool)

        # The `mask` flag simply follows the encoder's train/eval state.
        encoder_out, padding_mask = self.encoder.extract_features(
            source=x,
            padding_mask=padding_mask,
            mask=self.encoder.training,
        )

        # Valid length of each utterance = number of non-padded positions in
        # the mask returned by the encoder.
        encoder_out_lens = torch.sum(~padding_mask, dim=1)
        assert torch.all(encoder_out_lens > 0), encoder_out_lens

        return encoder_out, encoder_out_lens

    def forward_ctc(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.

        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          targets:
            Target Tensor of shape (sum(target_lengths)). The targets are
            assumed to be un-padded and concatenated within 1 dimension.
          target_lengths:
            Number of target labels of each utterance, of shape (N,).

        Returns:
          A scalar tensor: the CTC loss summed over the batch
          (`reduction="sum"`).
        """
        # Compute CTC log-prob
        ctc_output = self.ctc_output(encoder_out)  # (N, T, C)

        ctc_loss = torch.nn.functional.ctc_loss(
            log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
            targets=targets.cpu(),
            input_lengths=encoder_out_lens.cpu(),
            target_lengths=target_lengths.cpu(),
            reduction="sum",
        )
        return ctc_loss

    def forward(
        self,
        x: torch.Tensor,
        y: "k2.RaggedTensor",
        padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 2-D tensor of shape (N, T).
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of
            each utterance.
          padding_mask:
            Optional boolean tensor with the same shape as `x`; see
            :func:`forward_encoder`.

        Returns:
          Return a tuple `(ctc_loss, encoder_out_lens)`: the summed CTC loss
          and the encoder output lengths of shape (N,).
        """
        assert x.ndim == 2, x.shape
        assert y.num_axes == 2, y.num_axes
        assert x.size(0) == y.dim0, (x.shape, y.dim0)

        # Compute encoder outputs
        encoder_out, encoder_out_lens = self.forward_encoder(x, padding_mask)

        # Per-utterance label counts are the differences between consecutive
        # row splits of the ragged tensor.
        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]

        # Compute CTC loss
        targets = y.values
        ctc_loss = self.forward_ctc(
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            targets=targets,
            target_lengths=y_lens,
        )

        return ctc_loss, encoder_out_lens

View File

@ -0,0 +1 @@
../zipformer/optim.py

View File

@ -0,0 +1 @@
../zipformer/scaling.py

View File

@ -0,0 +1 @@
../zipformer/utils.py

View File

@ -0,0 +1 @@
../zipformer/wav2vec2_module.py

View File

@ -0,0 +1 @@
../zipformer/zipformer.py

View File

@ -70,12 +70,15 @@ class CharCtcTrainingGraphCompiler(object):
Returns:
Return a list-of-list of token IDs.
"""
assert sep in ("", "/"), sep
assert sep in ("", "/", "|"), sep
ids: List[List[int]] = []
whitespace = re.compile(r"([ \t])")
for text in texts:
if sep == "":
text = re.sub(whitespace, "", text)
elif sep == "|":
text = re.sub(r"\s+", " ", text)
text = re.sub(" ", "|", text)
else:
text = text.split(sep)
sub_ids = [