add librilight ssl recipe

update

Update ssl_datamodule.py

Update pretrain.py

Update pretrain.sh

Update pretrain.sh

Update hubert_ce.py

Update pretrain.py
This commit is contained in:
Your Name 2024-08-09 18:54:10 +00:00 committed by Yifan Yeung
parent 3b257dd5ae
commit 8e296b7047
38 changed files with 3145 additions and 137 deletions

View File

@ -0,0 +1,88 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from collections import Counter
from pathlib import Path
import torch
from lhotse import CutSet
from tqdm import tqdm
# Torch's multithreaded behavior needs to be disabled, otherwise
# it wastes a lot of CPU and slows things down.
# Do this at module level, outside of main(), so that it takes effect
# even when main is not invoked (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
    """Parse the command-line options of the codebook analysis script.

    Returns:
      The ``argparse.Namespace`` holding ``cuts_path`` (manifest to analyze)
      and ``num_clusters`` (size of the k-means codebook).
    """
    arg_parser = argparse.ArgumentParser()

    # (flag, value type, default) for every supported option.
    option_table = (
        ("--cuts-path", str, "data/kmeans/librispeech_cuts_dev-clean.jsonl.gz"),
        ("--num-clusters", int, 500),
    )
    for flag, value_type, fallback in option_table:
        arg_parser.add_argument(flag, type=value_type, default=fallback)

    return arg_parser.parse_args()
def analyze_codebook(args):
    """Log how evenly a k-means codebook is used by the labels in a manifest.

    Every cut in ``args.cuts_path`` must store its labels as a
    whitespace-separated string of integers in ``cut.custom["kmeans"]``.
    The function logs the fraction of the ``args.num_clusters`` codebook
    entries that occur at least once, and the entropy (in bits) of the
    label distribution.

    Args:
      args:
        Namespace with ``cuts_path`` (str) and ``num_clusters`` (int).

    Raises:
      AssertionError: If ``args.cuts_path`` is not an existing file.
    """
    cuts_path = Path(args.cuts_path)
    assert cuts_path.is_file(), f"{cuts_path} does not exist"

    logging.info(f"Loading {cuts_path}")
    cut_set = CutSet.from_file(cuts_path)

    cluster_counts = Counter()
    logging.info("Analyzing codebook")
    for cut in tqdm(cut_set):
        cluster_counts.update(map(int, cut.custom["kmeans"].split()))

    # Only labels inside [0, num_clusters) belong to the codebook.
    # Indexing a Counter with a missing key yields 0 without inserting it.
    counts = torch.tensor(
        [cluster_counts[i] for i in range(args.num_clusters)], dtype=torch.int64
    )
    in_range_count = int(counts.sum().item())
    if in_range_count == 0:
        # Without this guard the statistics below would be 0/0 (NaN).
        logging.warning(
            f"No labels in [0, {args.num_clusters}) found in {cuts_path}; "
            "nothing to analyze"
        )
        return

    stray_count = sum(cluster_counts.values()) - in_range_count
    if stray_count > 0:
        # Such labels would otherwise inflate the utilization rate and
        # leave the probabilities below not summing to one.
        logging.warning(
            f"Ignoring {stray_count} labels outside [0, {args.num_clusters})"
        )

    used = counts > 0
    utilized_clusters = int(used.sum().item())

    # Unused entries contribute exactly 0 to the entropy, so drop them instead
    # of clamping their probability to a small epsilon. log2 gives bits.
    probs = counts[used].double() / in_range_count
    codebook_entropy = -(probs * probs.log2()).sum().item()

    logging.info(
        f"Codebook utilization rate: {utilized_clusters / args.num_clusters:%}"
    )
    logging.info(f"Codebook entropy: {codebook_entropy}")
if __name__ == "__main__":
    # Timestamped log lines that also show the emitting file and line number.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    args = get_args()
    # Record the effective configuration before doing any work.
    logging.info(vars(args))
    analyze_codebook(args)

View File

@ -0,0 +1,251 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from typing import Optional
import fairseq
import joblib
import numpy as np
import torch
from lhotse import CutSet, SupervisionSegment
from lhotse.utils import fastcopy
from silero_vad import get_speech_timestamps, load_silero_vad
from tqdm import tqdm
# Torch's multithreaded behavior needs to be disabled, otherwise
# it wastes a lot of CPU and slows things down.
# Do this at module level, outside of main(), so that it takes effect
# even when main is not invoked (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
class ApplyKmeans(object):
    """Assign each feature vector to its nearest k-means centroid.

    The centroids are read from a joblib-serialized model that exposes a
    ``cluster_centers_`` array.  Both NumPy arrays and torch tensors are
    accepted as input; the result is always a NumPy array of centroid indices.
    """

    def __init__(self, km_path):
        # NOTE: joblib.load unpickles the file, so only load trusted models.
        self.km_model = joblib.load(km_path)

        # Keep the centroids transposed so that ``x @ C`` yields one column
        # per centroid, and cache their squared norms.
        centroids_t = self.km_model.cluster_centers_.transpose()
        self.C_np = centroids_t
        self.Cnorm_np = (centroids_t**2).sum(0, keepdims=True)

        # Torch copies of the same quantities, moved to the GPU when available.
        self.C = torch.from_numpy(self.C_np)
        self.Cnorm = torch.from_numpy(self.Cnorm_np)
        if torch.cuda.is_available():
            self.C = self.C.cuda()
            self.Cnorm = self.Cnorm.cuda()

    def __call__(self, x):
        """Return, for every row of ``x``, the index of the closest centroid.

        Distances are squared Euclidean distances, expanded as
        ``|x|^2 - 2 * x.C + |C|^2``.
        """
        if not isinstance(x, torch.Tensor):
            row_sq_norm = (x**2).sum(1, keepdims=True)
            cross_term = np.matmul(x, self.C_np)
            sq_dist = row_sq_norm - 2 * cross_term + self.Cnorm_np
            return np.argmin(sq_dist, axis=1)

        row_sq_norm = x.pow(2).sum(1, keepdim=True)
        cross_term = torch.matmul(x, self.C)
        sq_dist = row_sq_norm - 2 * cross_term + self.Cnorm
        return sq_dist.argmin(dim=1).cpu().numpy()
def get_args():
    """Parse the command-line options of the k-means label extraction script.

    Returns:
      The ``argparse.Namespace`` with ``subset``, ``model_path``,
      ``kmeans_model_path``, ``start`` and ``stop``.
    """
    arg_parser = argparse.ArgumentParser()

    # Keyword arguments of add_argument(), keyed by flag, in declaration order.
    option_specs = {
        "--subset": dict(type=str, default="small"),
        "--model-path": dict(type=str, default="download/hubert_base_ls960.pt"),
        "--kmeans-model-path": dict(
            type=str,
            default="download/hubert_base_ls960_L9_km500.model",
        ),
        "--start": dict(
            type=int,
            default=0,
            help="Process pieces starting from this number (inclusive).",
        ),
        "--stop": dict(
            type=int,
            default=-1,
            help="Stop processing pieces until this number (exclusive).",
        ),
    }
    for flag, spec in option_specs.items():
        arg_parser.add_argument(flag, **spec)

    return arg_parser.parse_args()
def extract_and_save_one_cuts(
    raw_cuts_path, cuts_path, model, vad_model, apply_kmeans, do_normalize, device
):
    """Label every cut of one manifest with k-means ids and save the result.

    Audio longer than 64 seconds (at 16 kHz) is first split into the speech
    spans returned by ``get_speech_timestamps``; shorter audio is kept whole.
    Each resulting segment becomes its own cut named ``<cut.id>-<seq>`` whose
    ``custom["kmeans"]`` holds the space-separated cluster ids.

    Args:
      raw_cuts_path:
        Manifest to read the unlabelled cuts from.
      cuts_path:
        Where the labelled manifest is written.
      model:
        Model whose ``extract_features`` provides the layer-9 features.
      vad_model:
        Model handed to ``get_speech_timestamps`` for long recordings.
      apply_kmeans:
        Callable mapping a feature matrix to an array of cluster ids.
      do_normalize:
        Whether to layer-normalize the waveform before feature extraction.
      device:
        Torch device the waveform is moved to.
    """
    logging.info(f"Loading {raw_cuts_path}")
    cut_set = CutSet.from_file(raw_cuts_path)

    logging.info("Extracting kmeans")
    labelled_cuts = []
    for cut in tqdm(cut_set):
        assert cut.sampling_rate == 16000, f"{cut.sampling_rate}"
        waveform = cut.load_audio()

        # Build (offset in samples, audio chunk) pairs for this cut.
        if waveform.shape[-1] <= 64 * 16000:
            segments = [(0, waveform)]
        else:
            speech_spans = get_speech_timestamps(waveform, vad_model)
            segments = [
                (span["start"], waveform[:, span["start"] : span["end"]])
                for span in speech_spans
            ]
            logging.info(f"Trim audio {cut.id} into {len(segments)} segments")

        for seq, (offset, chunk) in enumerate(segments):
            samples = torch.from_numpy(chunk).float().to(device)

            with torch.no_grad():
                if do_normalize:
                    samples = torch.nn.functional.layer_norm(samples, samples.shape)

                feature, _ = model.extract_features(
                    source=samples,
                    padding_mask=None,
                    mask=False,
                    output_layer=9,
                )
                feature = feature.squeeze(0)

            cluster_ids = apply_kmeans(feature).tolist()
            kmeans = " ".join(str(cluster_id) for cluster_id in cluster_ids)

            segment_id = f"{cut.id}-{seq}"
            segment_duration = chunk.shape[-1] / 16000

            # The supervision is re-anchored to the start of the new, shorter cut.
            supervision = fastcopy(
                cut.supervisions[0],
                id=segment_id,
                start=0.0,
                duration=segment_duration,
            )
            labelled_cuts.append(
                fastcopy(
                    cut,
                    id=segment_id,
                    start=cut.start + offset / 16000,
                    duration=segment_duration,
                    supervisions=[supervision],
                    custom={"kmeans": kmeans},
                )
            )

    logging.info(f"Saving to {cuts_path}")
    CutSet(labelled_cuts).to_file(cuts_path)
def extract_kmeans(args):
    """Extract k-means labels for one Libri-Light subset.

    For ``small`` a single manifest in ``data/kmeans`` is processed.  For
    ``medium`` and ``large`` the split pieces ``args.start`` (inclusive) to
    ``args.stop`` (exclusive) in ``data/kmeans/<subset>_split`` are processed.
    A piece is skipped when its output already exists or its raw manifest is
    missing.

    Args:
      args:
        Namespace with ``subset``, ``model_path``, ``kmeans_model_path``,
        ``start`` and ``stop``.

    Raises:
      AssertionError: If the subset is unknown, the output directory does not
        exist, or (for split subsets) ``stop`` is not larger than ``start``.
    """
    assert args.subset in ("small", "medium", "large"), f"{args.subset}"

    is_small = args.subset == "small"
    output_dir = Path(
        "data/kmeans" if is_small else f"data/kmeans/{args.subset}_split"
    )
    assert output_dir.exists(), f"{output_dir} does not exist!"

    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")
    logging.info(f"device: {device}")

    prefix = "librilight"

    vad_model = load_silero_vad()
    apply_kmeans = ApplyKmeans(args.kmeans_model_path)
    models, _, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [args.model_path]
    )
    model = models[0].eval().to(device)
    do_normalize = task.cfg.normalize

    def label_unless_skipped(raw_cuts_path, cuts_path):
        # Skip finished or unavailable pieces, otherwise label and save.
        if cuts_path.is_file():
            logging.info(f"{cuts_path} exists - skipping")
            return
        if not raw_cuts_path.is_file():
            logging.info(f"{raw_cuts_path} does not exist - skipping it")
            return
        extract_and_save_one_cuts(
            raw_cuts_path,
            cuts_path,
            model,
            vad_model,
            apply_kmeans,
            do_normalize,
            device,
        )

    if is_small:
        label_unless_skipped(
            output_dir / f"{prefix}_cuts_{args.subset}_raw.jsonl.gz",
            output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz",
        )
        return

    num_digits = 8  # num_digits is fixed by lhotse split-lazy
    start, stop = args.start, args.stop
    assert stop > start, "stop must be larger than start!"

    for piece in range(start, stop):
        idx = f"{piece}".zfill(num_digits)
        logging.info(f"Processing {idx}/{stop - 1}")
        label_unless_skipped(
            output_dir / f"{prefix}_cuts_{args.subset}_raw.{idx}.jsonl.gz",
            output_dir / f"{prefix}_cuts_{args.subset}.{idx}.jsonl.gz",
        )
if __name__ == "__main__":
    # Timestamped log lines that also show the emitting file and line number.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    args = get_args()
    # Record the effective configuration before doing any work.
    logging.info(vars(args))
    extract_kmeans(args)

View File

@ -0,0 +1,107 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from pathlib import Path
from typing import Optional
import torch
from lhotse import CutSet
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import str2bool
# Torch's multithreaded behavior needs to be disabled, otherwise
# it wastes a lot of CPU and slows things down.
# Do this at module level, outside of main(), so that it takes effect
# even when main is not invoked (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
    """Parse the command-line options of the Libri-Light preprocessing script.

    Returns:
      The ``argparse.Namespace`` with ``dataset``: a space-separated string of
      subset names, or ``None`` when the option is not given.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        # This script builds raw cut manifests; it does not compute fbank.
        help="""Space-separated subsets to preprocess, e.g. "small medium".
        If not given, the small, medium and large subsets are all used.""",
    )

    return parser.parse_args()
def preprocess_librilight(
    dataset: Optional[str] = None,
):
    """Turn the Libri-Light manifests into raw, supervision-trimmed cut sets.

    The recording and supervision manifests are read from ``data/manifests``.
    For every requested subset a cut set trimmed to its supervisions is
    written to ``data/kmeans/librilight_cuts_<subset>_raw.jsonl.gz``.  Subsets
    whose output file already exists are skipped.

    Args:
      dataset:
        Space-separated subset names.  ``None`` selects small, medium and
        large.

    Raises:
      AssertionError: If the manifests cannot be read, or not all requested
        subsets are present.
    """
    src_dir = Path("data/manifests")
    output_dir = Path("data/kmeans")

    if dataset is not None:
        dataset_parts = dataset.split(" ", -1)
    else:
        dataset_parts = ("small", "medium", "large")

    prefix = "librilight"
    suffix = "jsonl.gz"
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
        output_dir=src_dir,
        prefix=prefix,
        suffix=suffix,
    )
    assert manifests is not None

    # Every requested subset must have been found among the cached manifests.
    assert len(manifests) == len(dataset_parts), (
        len(manifests),
        len(dataset_parts),
        list(manifests.keys()),
        dataset_parts,
    )

    for partition, manifest in manifests.items():
        raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}"
        if raw_cuts_path.is_file():
            logging.info(f"{partition} already exists - skipping.")
            continue

        logging.info(f"Processing {partition}")
        raw_cuts = CutSet.from_manifests(
            recordings=manifest["recordings"],
            supervisions=manifest["supervisions"],
        ).trim_to_supervisions(keep_overlapping=False, min_duration=None)

        logging.info(f"Saving to {raw_cuts_path}")
        raw_cuts.to_file(raw_cuts_path)
if __name__ == "__main__":
    # Timestamped log lines that also show the emitting file and line number.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    args = get_args()
    # Record the effective configuration before doing any work.
    logging.info(vars(args))
    preprocess_librilight(dataset=args.dataset)

87
egs/librilight/SSL/prepare.sh Executable file
View File

@ -0,0 +1,87 @@
#!/usr/bin/env bash
# Work around the segmentation fault reported in
# https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
# Abort on errors (-e), on unset variables (-u) and on failures inside
# pipelines (-o pipefail).
set -eou pipefail
# Number of parallel jobs passed to `lhotse prepare` (-j).
nj=15
# Run stage 0 through stage 5 by default.
# NOTE(review): only stages 1-4 are defined in this script.
stage=0
stop_stage=5
# We assume dl_dir (download dir) contains the following
# directories and files.
# NOTE(review): no stage in this script downloads them; stage 1 assumes the
# corpus is already present, so download it manually beforehand.
#
# - $dl_dir/LibriLight
#     - small
#     - medium
#     - large
#
# You can download them from
#   - https://dl.fbaipublicfiles.com/librilight/data/small.tar
#   - https://dl.fbaipublicfiles.com/librilight/data/medium.tar
#   - https://dl.fbaipublicfiles.com/librilight/data/large.tar
dl_dir=$PWD/download
# Presumably lets callers override the variables above from the command line
# (e.g. --stage 2 --stop-stage 3) -- confirm against shared/parse_options.sh.
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
  # Print the arguments prefixed with a timestamp and the caller's file name,
  # line number and function name. (This function is from espnet.)
  local caller_file
  caller_file=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${caller_file}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "Running prepare.sh"
log "dl_dir: $dl_dir"

# Stage 1: build the lhotse manifests for Libri-Light in data/manifests.
# The .librilight.done marker makes the stage a no-op on later runs.
if [ "$stage" -le 1 ] && [ "$stop_stage" -ge 1 ]; then
  log "Stage 1: Prepare Libri-Light manifest"
  # We assume that you have downloaded the Libri-Light corpus
  # to $dl_dir/LibriLight
  mkdir -p data/manifests
  if [ ! -e data/manifests/.librilight.done ]; then
    # dl_dir defaults to $PWD/download, so quote it: an unquoted path that
    # contains spaces would be split into several arguments.
    lhotse prepare librilight -j "$nj" "$dl_dir/LibriLight" data/manifests
    touch data/manifests/.librilight.done
  fi
fi
# Stage 2: convert the manifests into raw cut sets under data/kmeans.
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Preprocess Libri-Light manifest"
  mkdir -p data/kmeans
  if [ ! -f data/kmeans/.preprocess_complete ]; then
    python3 ./local/preprocess_librilight.py
    # The marker must be created at the path tested above. It used to be
    # touched under data/fbank, so this stage never registered as complete,
    # and the touch failed under `set -e` when data/fbank did not exist.
    touch data/kmeans/.preprocess_complete
  fi
fi
# Stage 3: split the raw medium and large cut manifests into pieces of
# num_per_split cuts each, one output directory per subset.
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Split medium and large subset into pieces"
  num_per_split=200000
  for subset in medium large; do
    split_dir=data/kmeans/${subset}_split
    # The marker file makes the split a no-op once it has finished.
    if [ ! -f $split_dir/.split_completed ]; then
      lhotse split-lazy ./data/kmeans/librilight_cuts_${subset}_raw.jsonl.gz $split_dir $num_per_split
      touch $split_dir/.split_completed
    fi
  done
fi
# Stage 4: runs ./local/compute_fbank_librispeech.py once, guarded by the
# data/fbank/.librispeech.done marker.
# NOTE(review): the log message announces SSL target extraction for
# librilight, but the command and marker refer to LibriSpeech fbank features.
# This looks like a leftover from another recipe -- confirm which script
# this stage is meant to run.
if [[ $stage -le 4 && $stop_stage -ge 4 ]]; then
  log "Stage 4: Extract SSL target for librilight"
  mkdir -p data/fbank
  if [[ ! -e data/fbank/.librispeech.done ]]; then
    ./local/compute_fbank_librispeech.py
    touch data/fbank/.librispeech.done
  fi
fi

22
egs/librilight/SSL/pretrain.sh Executable file
View File

@ -0,0 +1,22 @@
#!/usr/bin/env bash
# Launch zipformer/pretrain.py with this recipe's pre-training settings.
# Run it from the recipe directory (egs/librilight/SSL) so that the relative
# paths below resolve.
#
# The file is installed as an executable, so it needs a shebang line to be
# run by bash no matter how it is invoked.

# Make the repository root (three directories up) importable. The quotes keep
# a working directory that contains spaces intact.
export PYTHONPATH="$(pwd)/../../.."

./zipformer/pretrain.py \
  --world-size 8 \
  --num-epochs 30 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp_pretrain \
  --max-duration 650 \
  --quadratic-duration 512 \
  --accum-grad 1 \
  --do-normalize 1 \
  --mask-prob 0.8 \
  --extractor-mode "layer_norm" \
  --dropout-input 0.0 \
  --dropout-features 0.0 \
  --feature-grad-mult 1.0 \
  --num-encoder-layers 2,2,3,4,3,2 \
  --feedforward-dim 512,768,1024,1536,1024,768 \
  --encoder-dim 192,256,448,768,448,192 \
  --encoder-unmasked-dim 192,192,256,256,256,192 \
  --base-lr 0.045

1
egs/librilight/SSL/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared/

2
egs/librilight/SSL/zipformer/decode.py Normal file → Executable file
View File

@ -1015,8 +1015,6 @@ def main():
test_sets = ["dev-clean", "dev-other", "test-clean", "test-other"]
test_dl = [dev_clean_dl, dev_other_dl, test_clean_dl, test_other_dl]
# test_sets = ["dev-clean", "dev-other"]
# test_dl = [dev_clean_dl, dev_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(

19
egs/librilight/SSL/zipformer/finetune.py Normal file → Executable file
View File

@ -1,11 +1,10 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,
# Zengwei Yao,
# Yifan Yang,
# Daniel Povey)
#
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,
# Zengwei Yao,
# Yifan Yang,
# Daniel Povey)
# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
@ -1246,7 +1245,7 @@ def train_one_epoch(
tb_writer, "train/valid_", params.batch_idx_train
)
if batch_idx % params.accum_grad != params.accum_grad - 1:
if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
optimizer.zero_grad()
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
@ -1388,6 +1387,8 @@ def run(rank, world_size, args):
train_cuts,
do_normalize=params.do_normalize,
sampler_state_dict=sampler_state_dict,
world_size=world_size,
rank=rank,
)
valid_cuts = librispeech.dev_clean_cuts()
@ -1396,6 +1397,8 @@ def run(rank, world_size, args):
valid_dl = librispeech.valid_dataloaders(
valid_cuts,
do_normalize=params.do_normalize,
world_size=world_size,
rank=rank,
)
if params.sanity_check and not params.print_diagnostics:

71
egs/librilight/SSL/zipformer/pretrain.py Normal file → Executable file
View File

@ -1,11 +1,10 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,
# Zengwei Yao,
# Yifan Yang,
# Daniel Povey)
#
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,
# Zengwei Yao,
# Yifan Yang,
# Daniel Povey)
# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
@ -32,7 +31,8 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
--num-epochs 400 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--exp-dir hubert/exp \
--full-libri 1 \
--max-duration 87.5 \
--accum-grad 4
"""
@ -46,6 +46,7 @@ from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import lhotse
import optim
import torch
import torch.multiprocessing as mp
@ -398,13 +399,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
and the value should be the multiple of 4, for faster computation""",
)
parser.add_argument(
"--untie-final-proj",
type=bool,
default=False,
help="use separate projection for each target",
)
def get_parser():
parser = argparse.ArgumentParser(
@ -483,7 +477,7 @@ def get_parser():
parser.add_argument(
"--lr-epochs",
type=float,
default=10.5,
default=0.2,
help="""Number of epochs that affects how rapidly the learning rate decreases.
""",
)
@ -541,7 +535,7 @@ def get_parser():
parser.add_argument(
"--save-every-n",
type=int,
default=100000,
default=10000,
help="""Save checkpoint after processing this number of batches"
periodically. We save checkpoint to exp-dir/ whenever
params.batch_idx_train % save_every_n == 0. The checkpoint filename
@ -554,7 +548,7 @@ def get_parser():
parser.add_argument(
"--keep-last-k",
type=int,
default=30,
default=100000,
help="""Only keep this number of checkpoints on disk.
For instance, if it is 3, there are only 3 checkpoints
in the exp-dir with filenames `checkpoint-xxx.pt`.
@ -591,17 +585,24 @@ def get_parser():
)
parser.add_argument(
"--max-sample-size",
type=float,
default=250000,
help="max sample size",
"--max-keep-size",
type=int,
default=1024000,
help="exclude sample longer than this.",
)
parser.add_argument(
"--min-sample-size",
"--min-keep-size",
type=float,
default=32000,
help="min sample size",
help="exclude sample longer less than this.",
)
parser.add_argument(
"--max-sample-size",
type=float,
default=1024000,
help="max sample size to crop to for batching.",
)
add_model_arguments(parser)
@ -960,10 +961,10 @@ def train_one_epoch(
else:
continue
except: # noqa
except Exception as e: # noqa
save_bad_model()
display_and_save_batch(batch, params=params)
raise
raise e
if params.print_diagnostics and batch_idx == 5:
return
@ -1064,7 +1065,7 @@ def train_one_epoch(
tb_writer, "train/valid_", params.batch_idx_train
)
if batch_idx % params.accum_grad != params.accum_grad - 1:
if sub_batch_idx % params.accum_grad != params.accum_grad - 1:
optimizer.zero_grad()
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
@ -1165,7 +1166,7 @@ def run(rank, world_size, args):
librilight = LibriLightDataModule(args)
train_cuts = librilight.train_all_shuf_cuts()
train_cuts = librilight.all_shuf_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
@ -1177,11 +1178,11 @@ def run(rank, world_size, args):
# an utterance duration distribution for your dataset to select
# the threshold
if (
c.duration < params.min_sample_size / params.sample_rate
or c.duration > params.max_sample_size / params.sample_rate
c.duration < params.min_keep_size / params.sample_rate
or c.duration > params.max_keep_size / params.sample_rate
):
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
return False
@ -1198,6 +1199,7 @@ def run(rank, world_size, args):
train_dl = librilight.train_dataloaders(
train_cuts,
max_sample_size=params.max_sample_size,
sample_rate=params.sample_rate,
label_rate=params.label_rate,
random_crop=params.random_crop,
@ -1205,6 +1207,8 @@ def run(rank, world_size, args):
num_classes=params.num_classes,
do_normalize=params.do_normalize,
sampler_state_dict=sampler_state_dict,
world_size=world_size,
rank=rank,
)
valid_cuts = librilight.dev_clean_cuts()
@ -1213,12 +1217,15 @@ def run(rank, world_size, args):
valid_dl = librilight.valid_dataloaders(
valid_cuts,
max_sample_size=params.max_sample_size,
sample_rate=params.sample_rate,
label_rate=params.label_rate,
random_crop=params.random_crop,
pad_audio=False,
num_classes=params.num_classes,
do_normalize=params.do_normalize,
world_size=world_size,
rank=rank,
)
if params.sanity_check and not params.print_diagnostics:
@ -1339,7 +1346,7 @@ def scan_pessimistic_batches_for_oom(
f"(={crit_values[criterion]}) ..."
)
display_and_save_batch(batch, params=params)
raise
raise e
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)

View File

@ -1,5 +1,4 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang)
# Copyright 2024 Xiaomi Corporation (Author: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -25,8 +24,9 @@ from pathlib import Path
from typing import Any, Dict, Optional
import torch
import lhotse
from dataset import HubertDataset
from lhotse import CutSet, combine, load_manifest_lazy
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
@ -46,7 +46,7 @@ class LibriLightDataModule:
"""
DataModule for SSL experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
but there can be multiple test dataloaders (e.g. LibriLight test-clean
and test-other).
It contains all the common data pipeline modules used in SSL
@ -63,7 +63,7 @@ class LibriLightDataModule:
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR SSL related options",
title="SSL data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies.",
@ -92,10 +92,29 @@ class LibriLightDataModule:
group.add_argument(
"--num-buckets",
type=int,
default=30,
default=1000,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--num-cuts-for-bins-estimate",
type=float,
default=1000000,
help="We will draw this many cuts to estimate the duration"
"bins for creating similar-duration buckets. Larger number"
"means a better estimate to the data distribution, possibly"
"at a longer init cost."
)
group.add_argument(
"--quadratic-duration",
type=float,
default=None,
help="When set, it adds an extra penalty that's quadratic"
"in size w.r.t. a cuts duration. This helps get a more"
"even GPU utilization across different input lengths when"
"models have quadratic input complexity. Set between 15"
"and 40 for transformers.",
)
group.add_argument(
"--shuffle",
type=str2bool,
@ -112,7 +131,7 @@ class LibriLightDataModule:
group.add_argument(
"--num-workers",
type=int,
default=2,
default=8,
help="The number of training dataloader workers that "
"collect the batches.",
)
@ -126,12 +145,13 @@ class LibriLightDataModule:
"--random-crop",
type=str2bool,
default=True,
help="audio sample rate",
help="always crop from the beginning if false",
)
def train_dataloaders(
self,
cuts_train: CutSet,
max_sample_size: Optional[int] = None,
sample_rate: float = 16000,
label_rate: float = 50,
random_crop: bool = True,
@ -139,6 +159,8 @@ class LibriLightDataModule:
num_classes: list = [504],
do_normalize: bool = True,
sampler_state_dict: Optional[Dict[str, Any]] = None,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
"""
Args:
@ -149,6 +171,7 @@ class LibriLightDataModule:
"""
logging.info("About to create train dataset")
train = HubertDataset(
max_sample_size=max_sample_size,
sample_rate=sample_rate,
label_rate=label_rate,
random_crop=random_crop,
@ -162,9 +185,14 @@ class LibriLightDataModule:
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
quadratic_duration=self.args.quadratic_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
num_cuts_for_bins_estimate=self.args.num_cuts_for_bins_estimate,
drop_last=self.args.drop_last,
world_size=world_size,
rank=rank,
)
else:
logging.info("Using SimpleCutSampler.")
@ -172,6 +200,8 @@ class LibriLightDataModule:
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
world_size=world_size,
rank=rank,
)
logging.info("About to create train dataloader")
@ -198,15 +228,19 @@ class LibriLightDataModule:
def valid_dataloaders(
self,
cuts_valid: CutSet,
max_sample_size: Optional[int] = None,
sample_rate: float = 16000,
label_rate: float = 50,
random_crop: bool = True,
pad_audio: bool = False,
num_classes: list = [504],
do_normalize: bool = True,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
logging.info("About to create dev dataset")
validate = HubertDataset(
max_sample_size=max_sample_size,
sample_rate=sample_rate,
label_rate=label_rate,
random_crop=random_crop,
@ -217,7 +251,10 @@ class LibriLightDataModule:
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
quadratic_duration=self.args.quadratic_duration,
shuffle=False,
world_size=world_size,
rank=rank,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
@ -230,81 +267,11 @@ class LibriLightDataModule:
return valid_dl
def test_dataloaders(
self,
cuts: CutSet,
sample_rate: float = 16000,
label_rate: float = 50,
random_crop: bool = True,
pad_audio: bool = False,
num_classes: list = [504],
do_normalize: bool = True,
) -> DataLoader:
logging.debug("About to create test dataset")
test = HubertDataset(
sample_rate=sample_rate,
label_rate=label_rate,
random_crop=random_crop,
pad_audio=pad_audio,
num_classes=num_classes,
do_normalize=do_normalize,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def small_cuts(self) -> CutSet:
logging.info("About to get small cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librilight_cuts_small.jsonl.gz"
)
@lru_cache()
def medium_cuts(self) -> CutSet:
logging.info("About to get medium cuts")
filenames = glob.glob(
f"{self.args.manifest_dir}/medium_splits/librilight_cuts_medium.*.jsonl.gz"
)
pattern = re.compile(r"librilight_cuts_medium.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
def all_shuf_cuts(self) -> CutSet:
logging.info(
f"Loading LibriLight medium {len(sorted_filenames)} splits in lazy mode"
"About to get the shuffled librilight small, medium and large cuts"
)
return combine(load_manifest_lazy(p) for p in sorted_filenames)
@lru_cache()
def large_cuts(self) -> CutSet:
logging.info("About to get large cuts")
filenames = glob.glob(
f"{self.args.manifest_dir}/large_splits/librilight_cuts_large.*.jsonl.gz"
)
pattern = re.compile(r"librilight_cuts_large.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
logging.info(
f"Loading LibriLight large {len(sorted_filenames)} splits in lazy mode"
)
return combine(load_manifest_lazy(p) for p in sorted_filenames)
@lru_cache()
def train_all_shuf_cuts(self) -> CutSet:
logging.info("About to get the shuffled small, medium and large cuts")
small_cuts = self.small_cuts()
medium_cuts = self.medium_cuts()
large_cuts = self.large_cuts()
@ -313,22 +280,52 @@ class LibriLightDataModule:
medium_cuts,
large_cuts,
weights=[
122867, # len(small_cuts)
1104071, # len(medium_cuts)
11012085, # len(large_cuts)
229051, # len(small_cuts)
2022949, # len(medium_cuts)
19883414, # len(large_cuts)
],
)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
logging.info("About to get librispeech dev-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
)
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
def small_cuts(self) -> CutSet:
logging.info("About to get librilight small cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
self.args.manifest_dir / "librilight_cuts_small.jsonl.gz"
)
@lru_cache()
def medium_cuts(self) -> CutSet:
logging.info("About to get librilight medium cuts")
filenames = glob.glob(
str(self.args.manifest_dir / "medium_split" / "librilight_cuts_medium.*.jsonl.gz")
)
pattern = re.compile(r"librilight_cuts_medium.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
logging.info(f"Loading Libri-Light medium {len(sorted_filenames)} splits in lazy mode")
return lhotse.combine(
lhotse.load_manifest_lazy(p) for p in sorted_filenames
)
@lru_cache()
def large_cuts(self) -> CutSet:
logging.info("About to get librilight large cuts")
filenames = glob.glob(
str(self.args.manifest_dir / "large_split" / "librilight_cuts_large.*.jsonl.gz")
)
pattern = re.compile(r"librilight_cuts_large.([0-9]+).jsonl.gz")
idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
sorted_filenames = [f[1] for f in idx_filenames]
logging.info(f"Loading Libri-Light large {len(sorted_filenames)} splits in lazy mode")
return lhotse.combine(
lhotse.load_manifest_lazy(p) for p in sorted_filenames
)

View File

@ -132,6 +132,8 @@ class LibriSpeechAsrDataModule:
cuts_train: CutSet,
do_normalize: bool,
sampler_state_dict: Optional[Dict[str, Any]] = None,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
"""
Args:
@ -150,7 +152,10 @@ class LibriSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
drop_last=self.args.drop_last,
world_size=world_size,
rank=rank,
)
else:
logging.info("Using SimpleCutSampler.")
@ -158,6 +163,8 @@ class LibriSpeechAsrDataModule:
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
world_size=world_size,
rank=rank,
)
logging.info("About to create train dataloader")
@ -181,13 +188,21 @@ class LibriSpeechAsrDataModule:
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet, do_normalize: bool) -> DataLoader:
def valid_dataloaders(
self,
cuts_valid: CutSet,
do_normalize: bool,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
logging.info("About to create dev dataset")
validate = HubertAsrDataset(do_normalize=do_normalize)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
world_size=world_size,
rank=rank,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(

0
egs/librispeech/SSL/hubert/decode.py Normal file → Executable file
View File

0
egs/librispeech/SSL/hubert/decode_ce.py Normal file → Executable file
View File

4
egs/librispeech/SSL/hubert/finetune.py Normal file → Executable file
View File

@ -1090,6 +1090,8 @@ def run(rank, world_size, args):
train_cuts,
do_normalize=params.do_normalize,
sampler_state_dict=sampler_state_dict,
world_size=world_size,
rank=rank,
)
valid_cuts = librispeech.dev_clean_cuts()
@ -1098,6 +1100,8 @@ def run(rank, world_size, args):
valid_dl = librispeech.valid_dataloaders(
valid_cuts,
do_normalize=params.do_normalize,
world_size=world_size,
rank=rank,
)
if params.sanity_check and not params.print_diagnostics:

4
egs/librispeech/SSL/hubert/finetune_ce.py Normal file → Executable file
View File

@ -1090,6 +1090,8 @@ def run(rank, world_size, args):
train_cuts,
do_normalize=params.do_normalize,
sampler_state_dict=sampler_state_dict,
world_size=world_size,
rank=rank,
)
valid_cuts = librispeech.dev_clean_cuts()
@ -1098,6 +1100,8 @@ def run(rank, world_size, args):
valid_dl = librispeech.valid_dataloaders(
valid_cuts,
do_normalize=params.do_normalize,
world_size=world_size,
rank=rank,
)
if params.sanity_check and not params.print_diagnostics:

0
egs/librispeech/SSL/hubert/pretrain.py Normal file → Executable file
View File

0
egs/librispeech/SSL/hubert/pretrain_ce.py Normal file → Executable file
View File

View File

@ -144,6 +144,8 @@ class LibriSpeechDataModule:
num_classes: list = [504],
do_normalize: bool = True,
sampler_state_dict: Optional[Dict[str, Any]] = None,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
"""
Args:
@ -170,7 +172,10 @@ class LibriSpeechDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
drop_last=self.args.drop_last,
world_size=world_size,
rank=rank,
)
else:
logging.info("Using SimpleCutSampler.")
@ -178,6 +183,8 @@ class LibriSpeechDataModule:
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
world_size=world_size,
rank=rank,
)
logging.info("About to create train dataloader")
@ -211,6 +218,8 @@ class LibriSpeechDataModule:
pad_audio: bool = False,
num_classes: list = [504],
do_normalize: bool = True,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
logging.info("About to create dev dataset")
validate = HubertDataset(
@ -226,6 +235,8 @@ class LibriSpeechDataModule:
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
world_size=world_size,
rank=rank,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(

19
egs/librispeech/SSL/pretrain.sh Executable file
View File

@ -0,0 +1,19 @@
# Launch ./zipformer/pretrain.py with the pre-training options of this recipe
# (8 workers via --world-size, 300 epochs, fp16 enabled, max duration 600).
# Each comma-separated encoder option below carries six values -- presumably
# one per encoder stack; confirm against zipformer/pretrain.py.
./zipformer/pretrain.py \
--world-size 8 \
--num-epochs 300 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp_pretrain \
--full-libri 1 \
--max-duration 600 \
--accum-grad 1 \
--do-normalize 0 \
--mask-prob 0.8 \
--dropout-input 0.1 \
--dropout-features 0.1 \
--feature-grad-mult 0.1 \
--num-encoder-layers 2,2,3,4,3,2 \
--feedforward-dim 512,768,1024,1536,1024,768 \
--encoder-dim 192,256,448,768,448,192 \
--encoder-unmasked-dim 192,192,256,256,256,192 \
--base-lr 0.045

0
egs/librispeech/SSL/zipformer/decode.py Normal file → Executable file
View File

4
egs/librispeech/SSL/zipformer/finetune.py Normal file → Executable file
View File

@ -1387,6 +1387,8 @@ def run(rank, world_size, args):
train_cuts,
do_normalize=params.do_normalize,
sampler_state_dict=sampler_state_dict,
world_size=world_size,
rank=rank,
)
valid_cuts = librispeech.dev_clean_cuts()
@ -1395,6 +1397,8 @@ def run(rank, world_size, args):
valid_dl = librispeech.valid_dataloaders(
valid_cuts,
do_normalize=params.do_normalize,
world_size=world_size,
rank=rank,
)
if params.sanity_check and not params.print_diagnostics:

View File

@ -296,7 +296,6 @@ class HubertModel(nn.Module):
self.layer_norm = LayerNorm(self.embed)
self.untie_final_proj = cfg.untie_final_proj
self.final_proj = nn.Linear(encoder_output_dim, sum(cfg.num_classes))
# modules below are not needed during fine-tuning

View File

@ -154,9 +154,9 @@ class AsrModel(nn.Module):
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
targets=targets,
input_lengths=encoder_out_lens,
target_lengths=target_lengths,
targets=targets.cpu(),
input_lengths=encoder_out_lens.cpu(),
target_lengths=target_lengths.cpu(),
reduction="sum",
)
return ctc_loss

7
egs/librispeech/SSL/zipformer/pretrain.py Normal file → Executable file
View File

@ -41,7 +41,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
import argparse
import copy
import logging
import sys
import warnings
from pathlib import Path
from shutil import copyfile
@ -594,7 +593,7 @@ def get_parser():
parser.add_argument(
"--max-keep-size",
type=int,
default=sys.maxsize,
default=320000,
help="exclude sample longer than this.",
)
@ -1218,6 +1217,8 @@ def run(rank, world_size, args):
num_classes=params.num_classes,
do_normalize=params.do_normalize,
sampler_state_dict=sampler_state_dict,
world_size=world_size,
rank=rank,
)
valid_cuts = librispeech.dev_clean_cuts()
@ -1233,6 +1234,8 @@ def run(rank, world_size, args):
pad_audio=False,
num_classes=params.num_classes,
do_normalize=params.do_normalize,
world_size=world_size,
rank=rank,
)
if params.sanity_check and not params.print_diagnostics:

View File

@ -0,0 +1 @@
../zipformer/asr_datamodule.py

View File

@ -0,0 +1 @@
../zipformer/beam_search.py

View File

@ -0,0 +1,823 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Liyong Guo,
# Quandong Wang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) ctc-decoding
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method ctc-decoding
(2) 1best
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--hlg-scale 0.6 \
--decoding-method 1best
(3) nbest
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--hlg-scale 0.6 \
--decoding-method nbest
(4) nbest-rescoring
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--hlg-scale 0.6 \
--nbest-scale 1.0 \
--lm-dir data/lm \
--decoding-method nbest-rescoring
(5) whole-lattice-rescoring
./zipformer/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--hlg-scale 0.6 \
--nbest-scale 1.0 \
--lm-dir data/lm \
--decoding-method whole-lattice-rescoring
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from finetune_ctc import add_model_arguments, get_model, get_params
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
# Natural logarithm of 1e-10.
# NOTE(review): nothing in the visible part of this script reads LOG_EPS --
# presumably kept for parity with sibling decode scripts; confirm before removing.
LOG_EPS = math.log(1e-10)
def get_parser():
    """Build the command-line parser for this CTC decoding script.

    The parser covers checkpoint selection (``--epoch``, ``--iter``, ``--avg``,
    ``--use-averaged-model``), locations (``--exp-dir``, ``--lang-dir``,
    ``--lm-dir``) and decoding options, and is finally extended with the
    model options registered by ``add_model_arguments``.

    Returns:
      The configured ``argparse.ArgumentParser``. Default values are shown in
      ``--help`` output through ``ArgumentDefaultsHelpFormatter``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    # The previous text said only --epoch was supported, but main() also
    # handles --iter together with the averaged model.
    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. If True, it would decode with "
        "the averaged model over the range from `epoch-avg` (excluded) to "
        "`epoch`, or over the matching checkpoints when --iter is used. "
        "Actually only the two checkpoints at both ends of that range "
        "are loaded for averaging.",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_char",
        help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    # ctc-decoding in this script maps token IDs through the lexicon's token
    # table (see decode_one_batch); no sentence piece model is involved.
    parser.add_argument(
        "--decoding-method",
        type=str,
        default="ctc-decoding",
        help="""Decoding method.
Supported values are:
- (1) ctc-decoding. Use CTC decoding with the CTC topology only.
Token IDs are turned into text with the token table of the
lexicon in --lang-dir, where "|" marks a word boundary.
It needs neither HLG.pt nor an n-gram LM.
- (2) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (3) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (4) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result.
- (5) whole-lattice-rescoring. Rescore the decoding lattice with an
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
is the decoding result.
- (6) nbest-oracle. Its WER is the lower bound of any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring method.
""",
    )

    parser.add_argument(
        "--num-paths",
        type=int,
        default=100,
        help="""Number of paths for n-best based decoding method.
Used only when --decoding-method is one of the following values:
nbest, nbest-rescoring, and nbest-oracle
""",
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=1.0,
        help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when --decoding-method is one of the following values:
nbest, nbest-rescoring, and nbest-oracle
A smaller value results in more unique paths.
""",
    )

    parser.add_argument(
        "--hlg-scale",
        type=float,
        default=0.6,
        help="""The scale to be applied to `hlg.scores`.
""",
    )

    parser.add_argument(
        "--lm-dir",
        type=str,
        default="data/lm",
        help="""The n-gram LM dir.
It should contain either G_4_gram.pt or G_4_gram.fst.txt
""",
    )

    add_model_arguments(parser)

    return parser
def get_decoding_params() -> AttributeDict:
    """Return the default settings used while decoding.

    The values cover the frame shift, the search/output beams, the limits on
    active states and whether double precision scores are used.
    """
    defaults = {
        "frame_shift_ms": 10,
        "search_beam": 20,
        "output_beam": 8,
        "min_active_states": 30,
        "max_active_states": 10000,
        "use_double_scores": True,
    }
    return AttributeDict(defaults)
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    lexicon: Lexicon,
    graph_compiler: CharCtcTrainingGraphCompiler,
    batch: dict,
    word_table: k2.SymbolTable,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode a single batch and return the hypotheses keyed by setting.

    The returned dict maps a string describing the decoding setting to the
    hypotheses of the batch:

      - key: "ctc-decoding" for plain CTC decoding, "no_rescore" (optionally
        followed by the n-best settings) when no LM rescoring is applied, an
        "oracle_..." key for nbest-oracle, or the keys produced by the
        rescoring helpers when an LM is used.
      - value: a list with one entry per utterance of the batch; each entry
        is the list of decoded words of that utterance.

    Args:
      params:
        Decoding configuration. `params.decoding_method` selects the branch
        taken below; the beam and active-state limits are forwarded to
        `get_lattice`.
      model:
        The neural model. Its `forward_encoder` and `ctc_output` are called.
      HLG:
        The decoding graph. Must be None when `H` is given.
      H:
        The CTC topology. Must be None when `HLG` is given.
      lexicon:
        Its `token_table` converts token IDs to symbols for ctc-decoding.
      graph_compiler:
        Not used by this function.
      batch:
        A dict providing "audio" and "padding_mask"; nbest-oracle additionally
        reads batch["supervisions"]["text"].
      word_table:
        The word symbol table used to convert word IDs to words.
      G:
        The LM passed to the rescoring helpers for "nbest-rescoring" and
        "whole-lattice-rescoring".
    Returns:
      The dict described above, or None when LM rescoring yields no result.
    """
    # Everything runs on the device holding the decoding graph.
    device = HLG.device if HLG is not None else H.device

    audio = batch["audio"].to(device)
    padding_mask = batch["padding_mask"].to(device)

    encoder_out, encoder_out_lens = model.forward_encoder(audio, padding_mask)
    ctc_output = model.ctc_output(encoder_out)

    # One row per utterance: (index in batch, start frame 0, number of frames).
    num_frames = encoder_out_lens.cpu()
    segment_columns = (
        torch.arange(audio.shape[0], dtype=torch.int32),
        torch.zeros_like(num_frames, dtype=torch.int32),
        num_frames,
    )
    supervision_segments = torch.stack(segment_columns, 1).to(torch.int32)

    # Exactly one of H / HLG is expected to be provided.
    if H is None:
        assert HLG is not None
        decoding_graph = HLG
    else:
        assert HLG is None
        decoding_graph = H

    lattice = get_lattice(
        nnet_output=ctc_output,
        decoding_graph=decoding_graph,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
    )

    def words_of(path):
        # Convert the word IDs on a best path into a list of words per utterance.
        return [[word_table[i] for i in ids] for ids in get_texts(path)]

    method = params.decoding_method

    if method == "ctc-decoding":
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        # `best_path.aux_labels` holds token IDs here (H is used, not HLG),
        # so go through the token table; "|" stands for a space.
        token_ids = get_texts(best_path)
        hyps = [
            "".join(
                lexicon.token_table[idx].replace("|", " ") for idx in token_id
            ).split()
            for token_id in token_ids
        ]
        return {"ctc-decoding": hyps}

    if method == "nbest-oracle":
        # The HLG-decoded lattice is used directly (rescored lattices would
        # also work); this keeps the oracle computation fast.
        best_path = nbest_oracle(
            lattice=lattice,
            num_paths=params.num_paths,
            ref_texts=batch["supervisions"]["text"],
            word_table=word_table,
            nbest_scale=params.nbest_scale,
            oov="<UNK>",
        )
        hyps = words_of(best_path)
        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
        return {key: hyps}

    if method == "1best":
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        return {"no_rescore": words_of(best_path)}

    if method == "nbest":
        best_path = nbest_decoding(
            lattice=lattice,
            num_paths=params.num_paths,
            use_double_scores=params.use_double_scores,
            nbest_scale=params.nbest_scale,
        )
        key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
        return {key: words_of(best_path)}

    assert method in [
        "nbest-rescoring",
        "whole-lattice-rescoring",
    ]

    # LM scales tried during rescoring: 0.1, 0.2, ..., 2.0
    lm_scale_list = [
        0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7,
        0.8, 0.9, 1.0, 1.1, 1.2, 1.3,
        1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0,
    ]

    if method == "nbest-rescoring":
        best_path_dict = rescore_with_n_best_list(
            lattice=lattice,
            G=G,
            num_paths=params.num_paths,
            lm_scale_list=lm_scale_list,
            nbest_scale=params.nbest_scale,
        )
    elif method == "whole-lattice-rescoring":
        best_path_dict = rescore_with_whole_lattice(
            lattice=lattice,
            G_with_epsilon_loops=G,
            lm_scale_list=lm_scale_list,
        )
    else:
        assert False, f"Unsupported decoding method: {params.decoding_method}"

    if best_path_dict is None:
        return None

    return {
        lm_scale_str: words_of(best_path)
        for lm_scale_str, best_path in best_path_dict.items()
    }
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    lexicon: Lexicon,
    graph_compiler: CharCtcTrainingGraphCompiler,
    word_table: k2.SymbolTable,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode every batch of a dataloader and collect the results.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode. Each batch must
        provide batch["supervisions"]["text"] and batch["cuts"].
      params:
        The decoding configuration, forwarded to :func:`decode_one_batch`.
      model:
        The neural model.
      HLG:
        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
      H:
        The ctc topo. Used only when params.decoding_method is ctc-decoding.
      lexicon:
        Forwarded to :func:`decode_one_batch`.
      graph_compiler:
        Forwarded to :func:`decode_one_batch`.
      word_table:
        It is the word symbol table.
      G:
        An LM. It is not None when params.decoding_method is "nbest-rescoring"
        or "whole-lattice-rescoring".
    Returns:
      Return a dict whose keys are the setting names produced by
      :func:`decode_one_batch` (e.g. "no_rescore" or one key per LM scale).
      Each value is a list of tuples `(cut_id, ref_words, hyp_words)`.
      Utterances of a batch for which nothing could be decoded are recorded
      with an empty hypothesis under every setting.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        # Some dataloaders do not define a length.
        num_batches = "?"

    results = defaultdict(list)
    # Entries of batches for which decode_one_batch() returned None. They are
    # attached to every setting at the end, once all setting names are known.
    undecoded = []

    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["cuts"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            lexicon=lexicon,
            graph_compiler=graph_compiler,
            batch=batch,
            word_table=word_table,
            G=G,
        )

        if hyps_dict is None:
            # decode_one_batch() returns None when LM rescoring produced no
            # result; count these utterances as empty hypotheses instead of
            # failing on `None.items()`.
            logging.warning(
                f"batch {batch_idx} decoded to nothing; "
                f"recording {len(texts)} empty hypotheses"
            )
            undecoded.extend(
                (cut_id, ref_text.split(), [])
                for cut_id, ref_text in zip(cut_ids, texts)
            )
        else:
            for name, hyps in hyps_dict.items():
                assert len(hyps) == len(texts)
                results[name].extend(
                    (cut_id, ref_text.split(), hyp_words)
                    for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts)
                )

        num_cuts += len(texts)

        if batch_idx % 100 == 0:
            batch_str = f"{batch_idx}/{num_batches}"
            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")

    if undecoded:
        if results:
            for name in results:
                results[name].extend(undecoded)
        else:
            logging.warning(
                "No batch could be decoded; "
                f"{len(undecoded)} utterances are missing from the results"
            )

    return results
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Write transcripts, error statistics and a WER summary for one test set.

    Args:
      params:
        Provides `res_dir` (output directory) and `suffix` (file name suffix).
      test_set_name:
        Name of the decoded set; it is part of every output file name.
      results_dict:
        Maps a decoding setting to its list of
        `(cut_id, ref_words, hyp_words)` tuples.

    For every setting a `recogs-*.txt` and an `errs-*.txt` file are written.
    When `results_dict` holds more than one setting (e.g. one per LM scale)
    the setting name is added to those file names so that settings no longer
    overwrite each other; with a single setting the file names are unchanged.
    A `wer-summary-*.txt` file lists all settings sorted by WER.
    """
    per_setting_files = len(results_dict) > 1

    test_set_wers = dict()
    for key, results in results_dict.items():
        tag = f"{test_set_name}-{key}" if per_setting_files else test_set_name

        recog_path = params.res_dir / f"recogs-{tag}-{params.suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = params.res_dir / f"errs-{tag}-{params.suffix}.txt"
        with open(errs_filename, "w") as f:
            wer = write_error_stats(f, f"{test_set_name}-{key}", results)
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    # Lowest WER first.
    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    # Only the first (best) line is annotated.
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)
@torch.no_grad()
def main():
    """Decode the LibriSpeech dev sets with a CTC model and save the results.

    Steps, in order:
      1. Parse the command line and merge it with the defaults from
         `get_params()` and `get_decoding_params()`.
      2. Build the CTC topology `H` (ctc-decoding) or load `HLG.pt` (all other
         methods); for the rescoring methods also load or create
         `G_4_gram.pt`.
      3. Create the model and load weights from one checkpoint or from an
         average of checkpoints, depending on `--iter`, `--epoch`, `--avg`
         and `--use-averaged-model`.
      4. Decode dev-clean and dev-other and write the results with
         `save_results`.

    The whole function runs with gradient tracking disabled.
    """
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)
    args.lm_dir = Path(args.lm_dir)

    params = get_params()
    # add decoding params
    params.update(get_decoding_params())
    # Applied last, so command-line values take precedence over both defaults.
    params.update(vars(args))

    assert params.decoding_method in (
        "ctc-decoding",
        "1best",
        "nbest",
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "nbest-oracle",
    )
    # Results of each decoding method go to their own sub-directory.
    params.res_dir = params.exp_dir / params.decoding_method

    # The suffix identifies the checkpoint selection in all output file names.
    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")
    logging.info(params)

    lexicon = Lexicon(params.lang_dir)
    max_token_id = max(lexicon.tokens)
    graph_compiler = CharCtcTrainingGraphCompiler(
        lexicon=lexicon,
        device=device,
    )

    params.blank_id = lexicon.token_table["<blk>"]
    params.vocab_size = max(lexicon.tokens) + 1

    # Exactly one of H / HLG is set; decode_one_batch() relies on that.
    if params.decoding_method == "ctc-decoding":
        HLG = None
        H = k2.ctc_topo(
            max_token=max_token_id,
            modified=False,
            device=device,
        )
    else:
        H = None
        HLG = k2.Fsa.from_dict(
            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
        )
        assert HLG.requires_grad is False

        HLG.scores *= params.hlg_scale
        if not hasattr(HLG, "lm_scores"):
            HLG.lm_scores = HLG.scores.clone()

    if params.decoding_method in (
        "nbest-rescoring",
        "whole-lattice-rescoring",
    ):
        # Build G_4_gram.pt from the text FST on first use; later runs load
        # the saved file directly.
        if not (params.lm_dir / "G_4_gram.pt").is_file():
            logging.info("Loading G_4_gram.fst.txt")
            logging.warning("It may take 8 minutes.")
            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
                first_word_disambig_id = lexicon.word_table["#0"]

                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
                # G.aux_labels is not needed in later computations, so
                # remove it here.
                del G.aux_labels
                # CAUTION: The following line is crucial.
                # Arcs entering the back-off state have label equal to #0.
                # We have to change it to 0 here.
                G.labels[G.labels >= first_word_disambig_id] = 0
                # See https://github.com/k2-fsa/k2/issues/874
                # for why we need to set G.properties to None
                G.__dict__["_properties"] = None
                G = k2.Fsa.from_fsas([G]).to(device)
                G = k2.arc_sort(G)
                # Save a dummy value so that it can be loaded in C++.
                # See https://github.com/pytorch/pytorch/issues/67902
                # for why we need to do this.
                G.dummy = 1

                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
        else:
            logging.info("Loading pre-compiled G_4_gram.pt")
            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
            G = k2.Fsa.from_dict(d)

        if params.decoding_method == "whole-lattice-rescoring":
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
            G = k2.arc_sort(G)
            G = G.to(device)

        # G.lm_scores is used to replace HLG.lm_scores during
        # LM rescoring.
        G.lm_scores = G.scores.clone()
    else:
        G = None

    logging.info("About to create model")
    model = get_model(params)

    if not params.use_averaged_model:
        # Plain averaging of the selected checkpoints.
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                # Epochs count from 1, so skip non-positive indices.
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        # Averaging based on two checkpoints only: the start one (excluded
        # from the range) and the end one.
        if params.iter > 0:
            # One extra checkpoint is fetched because the start is excluded.
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            # NOTE(review): presumably find_checkpoints returns newest first,
            # which makes filenames[0] the end of the range -- confirm.
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    librispeech = LibriSpeechAsrDataModule(args)

    dev_clean_cuts = librispeech.dev_clean_cuts()
    dev_other_cuts = librispeech.dev_other_cuts()

    dev_clean_dl = librispeech.test_dataloaders(
        dev_clean_cuts,
        do_normalize=params.do_normalize,
    )
    dev_other_dl = librispeech.test_dataloaders(
        dev_other_cuts,
        do_normalize=params.do_normalize,
    )

    test_sets = ["dev-clean", "dev-other"]
    test_dl = [dev_clean_dl, dev_other_dl]

    # test_clean_cuts = librispeech.test_clean_cuts()
    # test_other_cuts = librispeech.test_other_cuts()

    # test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
    # test_other_dl = librispeech.test_dataloaders(test_other_cuts)

    # test_sets = ["test-clean", "test-other"]
    # test_dl = [test_clean_dl, test_other_dl]

    # The loop variable reuses the name `test_dl`; zip() has already been
    # given the list before the name is rebound to a single dataloader.
    for test_set, test_dl in zip(test_sets, test_dl):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            lexicon=lexicon,
            graph_compiler=graph_compiler,
            word_table=lexicon.word_table,
            G=G,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")
# Run the decoding pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1 @@
../zipformer/dataset.py

View File

@ -0,0 +1 @@
../zipformer/encoder_interface.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../zipformer/hubert_ce.py

View File

@ -0,0 +1,153 @@
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Zengwei Yao,
# Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import k2
import torch
import torch.nn as nn
from scaling import ScaledLinear
from icefall.utils import add_sos
class AsrModel(nn.Module):
    """An encoder followed by a CTC output head."""

    def __init__(
        self,
        encoder,
        encoder_dim: int = 768,
        vocab_size: int = 500,
    ):
        """CTC ASR model.

        - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)

        Args:
          encoder:
            The network producing frame-level features. It must provide
            `extract_features(source=..., padding_mask=..., mask=...)`
            returning `(features, padding_mask)`, and a `training` flag.
            The CTC head expects `features` of shape (N, T, encoder_dim).
          encoder_dim:
            Size of the last dimension of the encoder output.
          vocab_size:
            Number of output classes of the CTC head.
        """
        super().__init__()

        self.encoder = encoder

        # Modules for CTC head: dropout, projection to the vocabulary and
        # log-softmax, so the output can be fed to the CTC loss directly.
        self.ctc_output = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(encoder_dim, vocab_size),
            nn.LogSoftmax(dim=-1),
        )

    def forward_encoder(
        self,
        x: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute encoder outputs.

        Args:
          x:
            A 2-D tensor of shape (N, T).
          padding_mask:
            Boolean tensor that is True at padded positions. When omitted,
            an all-False mask shaped like `x` is used.

        Returns:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,), i.e. the number of
            unpadded positions in the mask returned by the encoder.
        """
        if padding_mask is None:
            padding_mask = torch.zeros_like(x, dtype=torch.bool)

        # Masking is applied only while the encoder is in training mode.
        encoder_out, padding_mask = self.encoder.extract_features(
            source=x,
            padding_mask=padding_mask,
            mask=self.encoder.training,
        )
        encoder_out_lens = torch.sum(~padding_mask, dim=1)
        assert torch.all(encoder_out_lens > 0), encoder_out_lens

        return encoder_out, encoder_out_lens

    def forward_ctc(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.

        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          targets:
            Target Tensor of shape (sum(target_lengths)). The targets are assumed
            to be un-padded and concatenated within 1 dimension.
          target_lengths:
            Number of target labels of each utterance, of shape (N,).

        Returns:
          The CTC loss summed over the batch.
        """
        # Compute CTC log-prob
        ctc_output = self.ctc_output(encoder_out)  # (N, T, C)

        # Targets and all lengths are moved to the CPU before the call.
        ctc_loss = torch.nn.functional.ctc_loss(
            log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
            targets=targets.cpu(),
            input_lengths=encoder_out_lens.cpu(),
            target_lengths=target_lengths.cpu(),
            reduction="sum",
        )
        return ctc_loss

    def forward(
        self,
        x: torch.Tensor,
        y: "k2.RaggedTensor",
        padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder and compute the CTC loss for a batch.

        Args:
          x:
            A 2-D tensor of shape (N, T).
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          padding_mask:
            Optional boolean mask forwarded to :meth:`forward_encoder`.

        Returns:
          A pair `(ctc_loss, encoder_out_lens)`: the summed CTC loss and the
          encoder output lengths of shape (N,).
        """
        assert x.ndim == 2, x.shape
        assert y.num_axes == 2, y.num_axes

        assert x.size(0) == y.dim0, (x.shape, y.dim0)

        # Compute encoder outputs
        encoder_out, encoder_out_lens = self.forward_encoder(x, padding_mask)

        # Number of labels per utterance, from consecutive row splits.
        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]

        # Compute CTC loss
        targets = y.values
        ctc_loss = self.forward_ctc(
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            targets=targets,
            target_lengths=y_lens,
        )

        return ctc_loss, encoder_out_lens

View File

@ -0,0 +1 @@
../zipformer/optim.py

View File

@ -0,0 +1 @@
../zipformer/scaling.py

View File

@ -0,0 +1 @@
../zipformer/utils.py

View File

@ -0,0 +1 @@
../zipformer/wav2vec2_module.py

View File

@ -0,0 +1 @@
../zipformer/zipformer.py

View File

@ -70,12 +70,15 @@ class CharCtcTrainingGraphCompiler(object):
Returns:
Return a list-of-list of token IDs.
"""
assert sep in ("", "/"), sep
assert sep in ("", "/", "|"), sep
ids: List[List[int]] = []
whitespace = re.compile(r"([ \t])")
for text in texts:
if sep == "":
text = re.sub(whitespace, "", text)
elif sep == "|":
text = re.sub(r"\s+", " ", text)
text = re.sub(" ", "|", text)
else:
text = text.split(sep)
sub_ids = [