From 1b2e99d374cbbc527bf8c9239d616497249ccb1d Mon Sep 17 00:00:00 2001
From: lishaojie <95117087+manbaaaa@users.noreply.github.com>
Date: Thu, 9 Nov 2023 22:07:28 +0800
Subject: [PATCH 001/123] add the pruned_transducer_stateless7_streaming recipe
for commonvoice (#1018)
* add the pruned_transducer_stateless7_streaming recipe for commonvoice
* fix the symlinks
* Update RESULTS.md
---
egs/commonvoice/ASR/RESULTS.md | 25 +
egs/commonvoice/ASR/local/compile_hlg.py | 1 +
egs/commonvoice/ASR/local/compile_lg.py | 1 +
.../compute_fbank_commonvoice_dev_test.py | 4 +-
.../ASR/local/preprocess_commonvoice.py | 10 +-
egs/commonvoice/ASR/prepare.sh | 64 +-
.../README.md | 9 +
.../beam_search.py | 1 +
.../commonvoice_fr.py | 422 ++++++
.../decode.py | 810 ++++++++++
.../decode_stream.py | 1 +
.../decoder.py | 1 +
.../encoder_interface.py | 1 +
.../export-for-ncnn-zh.py | 1 +
.../export-for-ncnn.py | 1 +
.../export-onnx.py | 1 +
.../export.py | 1 +
.../finetune.py | 1342 +++++++++++++++++
.../generate_model_from_checkpoint.py | 281 ++++
.../jit_pretrained.py | 1 +
.../jit_trace_export.py | 1 +
.../jit_trace_pretrained.py | 1 +
.../joiner.py | 1 +
.../model.py | 1 +
.../onnx_check.py | 1 +
.../onnx_model_wrapper.py | 1 +
.../onnx_pretrained.py | 1 +
.../optim.py | 1 +
.../pretrained.py | 1 +
.../scaling.py | 1 +
.../scaling_converter.py | 1 +
.../streaming-ncnn-decode.py | 1 +
.../streaming_beam_search.py | 1 +
.../streaming_decode.py | 612 ++++++++
.../test_model.py | 150 ++
.../train.py | 1256 +++++++++++++++
.../train2.py | 1257 +++++++++++++++
.../zipformer.py | 1 +
.../zipformer2.py | 1 +
icefall/shared/convert-k2-to-openfst.py | 103 +-
icefall/shared/ngram_entropy_pruning.py | 631 +-------
icefall/shared/parse_options.sh | 98 +-
42 files changed, 6260 insertions(+), 840 deletions(-)
create mode 120000 egs/commonvoice/ASR/local/compile_hlg.py
create mode 120000 egs/commonvoice/ASR/local/compile_lg.py
create mode 100644 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/beam_search.py
create mode 100644 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py
create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decoder.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py
create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py
create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/joiner.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/model.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/optim.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py
create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py
create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
create mode 100755 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
mode change 100755 => 120000 icefall/shared/convert-k2-to-openfst.py
mode change 100755 => 120000 icefall/shared/ngram_entropy_pruning.py
mode change 100755 => 120000 icefall/shared/parse_options.sh
diff --git a/egs/commonvoice/ASR/RESULTS.md b/egs/commonvoice/ASR/RESULTS.md
index 751625371..2c158d91d 100644
--- a/egs/commonvoice/ASR/RESULTS.md
+++ b/egs/commonvoice/ASR/RESULTS.md
@@ -57,3 +57,28 @@ Pretrained model is available at
The tensorboard log for training is available at
+
+
+### Commonvoice (fr) BPE training results (Pruned Stateless Transducer 7_streaming)
+
+#### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)
+
+See #1018 for more details.
+
+Number of model parameters: 70369391, i.e., 70.37 M
+
+The best WERs for Common Voice French 12.0 (cv-corpus-12.0-2022-12-07/fr) are listed below:
+
+| decoding method      | Test |
+|----------------------|------|
+| greedy search        | 9.95 |
+| modified beam search | 9.57 |
+| fast beam search     | 9.67 |
+
+Note: The best result is obtained by training on the full LibriSpeech and GigaSpeech datasets and then fine-tuning on the full Common Voice dataset.
+
+Detailed experimental results and the pretrained model are available at
+
+
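+An example decoding command (the `--epoch`/`--avg` values and paths below are placeholders; adjust them to your own checkpoints and data directory):
+
+```bash
+./pruned_transducer_stateless7_streaming/decode.py \
+  --epoch 30 \
+  --avg 9 \
+  --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+  --bpe-model data/fr/lang_bpe_500/bpe.model \
+  --max-duration 600 \
+  --decode-chunk-len 32 \
+  --decoding-method modified_beam_search \
+  --beam-size 4
+```
+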
diff --git a/egs/commonvoice/ASR/local/compile_hlg.py b/egs/commonvoice/ASR/local/compile_hlg.py
new file mode 120000
index 000000000..471aa7fb4
--- /dev/null
+++ b/egs/commonvoice/ASR/local/compile_hlg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_hlg.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/local/compile_lg.py b/egs/commonvoice/ASR/local/compile_lg.py
new file mode 120000
index 000000000..462d6d3fb
--- /dev/null
+++ b/egs/commonvoice/ASR/local/compile_lg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_lg.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py
index c8f9b6ccb..a0b4d224c 100755
--- a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py
+++ b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py
@@ -56,8 +56,8 @@ def get_args():
def compute_fbank_commonvoice_dev_test(language: str):
src_dir = Path(f"data/{language}/manifests")
output_dir = Path(f"data/{language}/fbank")
- num_workers = 42
- batch_duration = 600
+ num_workers = 16
+ batch_duration = 200
subsets = ("dev", "test")
diff --git a/egs/commonvoice/ASR/local/preprocess_commonvoice.py b/egs/commonvoice/ASR/local/preprocess_commonvoice.py
index e60459765..5f6aa3ec0 100755
--- a/egs/commonvoice/ASR/local/preprocess_commonvoice.py
+++ b/egs/commonvoice/ASR/local/preprocess_commonvoice.py
@@ -43,9 +43,13 @@ def get_args():
return parser.parse_args()
-def normalize_text(utt: str) -> str:
+def normalize_text(utt: str, language: str) -> str:
utt = re.sub(r"[{0}]+".format("-"), " ", utt)
- return re.sub(r"[^a-zA-Z\s']", "", utt).upper()
+ utt = re.sub("’", "'", utt)
+ if language == "en":
+ return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
+ if language == "fr":
+ return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper()
def preprocess_commonvoice(
@@ -94,7 +98,7 @@ def preprocess_commonvoice(
for sup in m["supervisions"]:
text = str(sup.text)
orig_text = text
- sup.text = normalize_text(sup.text)
+ sup.text = normalize_text(sup.text, language)
text = str(sup.text)
if len(orig_text) != len(text):
logging.info(
diff --git a/egs/commonvoice/ASR/prepare.sh b/egs/commonvoice/ASR/prepare.sh
index 3946908c6..edac0e8e6 100755
--- a/egs/commonvoice/ASR/prepare.sh
+++ b/egs/commonvoice/ASR/prepare.sh
@@ -36,8 +36,8 @@ num_splits=1000
# - speech
dl_dir=$PWD/download
-release=cv-corpus-13.0-2023-03-09
-lang=en
+release=cv-corpus-12.0-2022-12-07
+lang=fr
. shared/parse_options.sh || exit 1
@@ -146,7 +146,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
if [ ! -e data/${lang}/fbank/.cv-${lang}_train.done ]; then
./local/compute_fbank_commonvoice_splits.py \
--num-workers $nj \
- --batch-duration 600 \
+ --batch-duration 200 \
--start 0 \
--num-splits $num_splits \
--language $lang
@@ -189,7 +189,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
sed -i 's/\t/ /g' $lang_dir/transcript_words.txt
sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt
fi
-
+
if [ ! -f $lang_dir/words.txt ]; then
cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \
| sort -u | sed '/^$/d' > $lang_dir/words.txt
@@ -216,14 +216,14 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
}' > $lang_dir/words || exit 1;
mv $lang_dir/words $lang_dir/words.txt
fi
-
+
if [ ! -f $lang_dir/bpe.model ]; then
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/transcript_words.txt
fi
-
+
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir
@@ -250,3 +250,55 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
fi
done
fi
+
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+ log "Stage 10: Prepare G"
+  # We assume you have installed kaldilm; if not, please install
+  # it using: pip install kaldilm
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/${lang}/lang_bpe_${vocab_size}
+ mkdir -p $lang_dir/lm
+  # 3-gram is used for building HLG; 4-gram is used for LM rescoring
+ for ngram in 3 4; do
+ if [ ! -f $lang_dir/lm/${ngram}gram.arpa ]; then
+ ./shared/make_kn_lm.py \
+ -ngram-order ${ngram} \
+ -text $lang_dir/transcript_words.txt \
+ -lm $lang_dir/lm/${ngram}gram.arpa
+ fi
+
+ if [ ! -f $lang_dir/lm/${ngram}gram.fst.txt ]; then
+ python3 -m kaldilm \
+ --read-symbol-table="$lang_dir/words.txt" \
+ --disambig-symbol='#0' \
+ --max-order=${ngram} \
+ $lang_dir/lm/${ngram}gram.arpa > $lang_dir/lm/G_${ngram}_gram.fst.txt
+ fi
+ done
+ done
+fi
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+ log "Stage 11: Compile HLG"
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/${lang}/lang_bpe_${vocab_size}
+ ./local/compile_hlg.py --lang-dir $lang_dir
+
+  # Note: If ./local/compile_hlg.py throws OOM,
+ # please switch to the following command
+ #
+ # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
+ done
+fi
+
+# Compile LG for RNN-T fast_beam_search decoding
+if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
+ log "Stage 12: Compile LG"
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/${lang}/lang_bpe_${vocab_size}
+ ./local/compile_lg.py --lang-dir $lang_dir
+ done
+fi
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md
new file mode 100644
index 000000000..991875aaa
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md
@@ -0,0 +1,9 @@
+This recipe implements a streaming Zipformer-Transducer model.
+
+See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer_transducer.html for detailed tutorials.
+
+[./zipformer.py](./zipformer.py) and [./train.py](./train.py)
+are basically the same as
+[./zipformer2.py](./zipformer2.py) and [./train2.py](./train2.py).
+The only purpose of [./zipformer2.py](./zipformer2.py) and [./train2.py](./train2.py)
+is to support exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn).
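+
+To export the model for sherpa-ncnn, the command is roughly as follows (flag names and values are assumed to match the librispeech recipe these scripts symlink to; see the tutorial linked above for the authoritative steps):
+
+```bash
+./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
+  --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+  --bpe-model data/fr/lang_bpe_500/bpe.model \
+  --epoch 30 \
+  --avg 1 \
+  --use-averaged-model 0 \
+  --decode-chunk-len 32
+```
+
+If you changed the model configuration during training (e.g. `--num-encoder-layers`, `--encoder-dims`), pass the same values here as well.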
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/beam_search.py
new file mode 120000
index 000000000..d7349b0a3
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py
new file mode 100644
index 000000000..cafa4111d
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py
@@ -0,0 +1,422 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SingleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
+ AudioSamples,
+ OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class CommonVoiceAsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+
+ This class should be derived for specific corpora used in ASR tasks.
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+
+ group.add_argument(
+ "--language",
+ type=str,
+ default="fr",
+ help="""Language of Common Voice""",
+ )
+ group.add_argument(
+ "--cv-manifest-dir",
+ type=Path,
+ default=Path("data/fr/fbank"),
+ help="Path to directory with CommonVoice train/dev/test cuts.",
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+ help="When enabled, select noise from MUSAN and mix it"
+ "with training dataset. ",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ logging.info("About to get Musan cuts")
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+ transforms.append(
+ CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # In different Lhotse's versions, the default of num_frame_masks is
+ # different.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=num_frame_masks,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ train = K2SpeechRecognitionDataset(
+ input_strategy=eval(self.args.input_strategy)(),
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SingleCutSampler.")
+ train_sampler = SingleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info("About to get train cuts")
+ return load_manifest_lazy(
+ self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_train.jsonl.gz"
+ )
+
+ @lru_cache()
+ def dev_cuts(self) -> CutSet:
+ logging.info("About to get dev cuts")
+ return load_manifest_lazy(
+ self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_dev.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_cuts(self) -> CutSet:
+ logging.info("About to get test cuts")
+ return load_manifest_lazy(
+ self.args.cv_manifest_dir / f"cv-{self.args.language}_cuts_test.jsonl.gz"
+ )
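+
+
+# A minimal usage sketch of this data module (illustrative only; the real
+# entry points are train.py / decode.py in this directory, which follow the
+# same pattern):
+#
+#   parser = argparse.ArgumentParser()
+#   CommonVoiceAsrDataModule.add_arguments(parser)
+#   args = parser.parse_args()
+#   commonvoice = CommonVoiceAsrDataModule(args)
+#   train_dl = commonvoice.train_dataloaders(commonvoice.train_cuts())
+#   valid_dl = commonvoice.valid_dataloaders(commonvoice.dev_cuts())
+#   test_dl = commonvoice.test_dataloaders(commonvoice.test_cuts())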
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py
new file mode 100755
index 000000000..30f7c1e77
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode.py
@@ -0,0 +1,810 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless7_streaming/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --max-duration 600 \
+ --decode-chunk-len 32 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./pruned_transducer_stateless7_streaming/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --max-duration 600 \
+ --decode-chunk-len 32 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./pruned_transducer_stateless7_streaming/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --max-duration 600 \
+ --decode-chunk-len 32 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./pruned_transducer_stateless7_streaming/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --max-duration 600 \
+ --decode-chunk-len 32 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./pruned_transducer_stateless7_streaming/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --max-duration 600 \
+ --decode-chunk-len 32 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./pruned_transducer_stateless7_streaming/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --max-duration 600 \
+ --decode-chunk-len 32 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./pruned_transducer_stateless7_streaming/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --max-duration 600 \
+ --decode-chunk-len 32 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from commonvoice_fr import CommonVoiceAsrDataModule
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=9,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7_streaming/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding_method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+ only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
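+    # Tail padding: append 30 frames of LOG_EPS to the end of each utterance
+    # so that the streaming encoder has enough right context for the final
+    # chunk; this is assumed to help reduce deletions at utterance ends.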
+ feature_lens += 30
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, 30),
+ value=LOG_EPS,
+ )
+ encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+ only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut id, the reference transcript, and the
+      predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ word_table=word_table,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+
+ # errs_info = (
+ # params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ # )
+ errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}.txt"
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ CommonVoiceAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+ assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
+ model.encoder.decode_chunk_size,
+ params.decode_chunk_len,
+ )
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ commonvoice = CommonVoiceAsrDataModule(args)
+
+ test_cuts = commonvoice.test_cuts()
+
+ test_dl = commonvoice.test_dataloaders(test_cuts)
+
+ test_sets = "test-cv"
+
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_sets,
+ results_dict=results_dict,
+ )
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py
new file mode 120000
index 000000000..ca8fed319
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decode_stream.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decoder.py
new file mode 120000
index 000000000..33944d0d2
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/decoder.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/decoder.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py
new file mode 120000
index 000000000..cb673b3eb
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
new file mode 120000
index 000000000..72e43c297
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
new file mode 120000
index 000000000..3b36924ef
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py
new file mode 120000
index 000000000..57a0cd0a0
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export-onnx.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py
new file mode 120000
index 000000000..2acafdc61
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py
new file mode 100755
index 000000000..3a10c5d81
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py
@@ -0,0 +1,1342 @@
+#!/usr/bin/env python3
+# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,)
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7_streaming/finetune.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless7_streaming/exp \
+  --max-duration 300
+
+# For mixed precision training:
+
+./pruned_transducer_stateless7_streaming/finetune.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless7_streaming/exp \
+  --max-duration 550
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from commonvoice_fr import CommonVoiceAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ filter_uneven_sized_batch,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for module in model.modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+
+
+def add_finetune_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument("--do-finetune", type=str2bool, default=False)
+
+ parser.add_argument(
+ "--init-modules",
+ type=str,
+ default=None,
+ help="""
+ Modules to be initialized. It matches all parameters starting with
+ a specific key. The keys are given with Comma seperated. If None,
+ all modules will be initialised. For example, if you only want to
+ initialise all parameters staring with "encoder", use "encoder";
+ if you want to initialise parameters starting with encoder or decoder,
+ use "encoder,joiner".
+ """,
+ )
+
+ parser.add_argument(
+ "--finetune-ckpt",
+ type=str,
+ default=None,
+ help="Fine-tuning from which checkpoint (a path to a .pt file)",
+ )
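+
+    # A hypothetical invocation combining these options (the checkpoint path
+    # is a placeholder, not a file shipped with this recipe):
+    #
+    #   ./pruned_transducer_stateless7_streaming/finetune.py \
+    #     --do-finetune True \
+    #     --init-modules "encoder" \
+    #     --finetune-ckpt /path/to/pretrained.pt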
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,4,3,2,4",
+ help="Number of zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dims",
+ type=str,
+ default="1024,1024,2048,2048,1024",
+ help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--nhead",
+ type=str,
+ default="8,8,8,8,8",
+ help="Number of attention heads in the zipformer encoder layers.",
+ )
+
+ parser.add_argument(
+ "--encoder-dims",
+ type=str,
+ default="384,384,384,384,384",
+ help="""Embedding dimension in the 2 blocks of zipformer encoder
+ layers, comma separated
+ """,
+ )
+
+ parser.add_argument(
+ "--attention-dims",
+ type=str,
+ default="192,192,192,192,192",
+ help="""Attention dimension in the 2 blocks of zipformer encoder layers,\
+ comma separated; not the same as embedding dimension.
+ """,
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dims",
+ type=str,
+ default="256,256,256,256,256",
+ help="""Unmasked dimensions in the encoders, relates to augmentation
+ during training. Must be <= each of encoder_dims. Empirically, less
+ than 256 seems to make performance worse.
+ """,
+ )
+
+ parser.add_argument(
+ "--zipformer-downsampling-factors",
+ type=str,
+ default="1,2,4,8,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernels",
+ type=str,
+ default="31,31,31,31,31",
+ help="Sizes of kernels in convolution modules",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="""Path to the BPE model.
+ This should be the bpe model of the original model
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.005, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=100000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. During fine-tuning, we set this very large so that the
+        learning rate slowly decays with the number of batches. You may tune
+ its value by yourself.
+ """,
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=100,
+ help="""Number of epochs that affects how rapidly the learning rate
+ decreases. During fine-tuning, we set this very large so that the
+        learning rate slowly decays with the number of epochs. You may tune
+ its value by yourself.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network) part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=2000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
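+    # A rough sketch of the running average described above; the actual update
+    # lives in icefall.checkpoint.update_averaged_model, this comment only
+    # illustrates the formula in the help text:
+    #   weight = average_period / batch_idx_train
+    #   model_avg = weight * model + (1 - weight) * model_avg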
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+ add_finetune_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "frame_shift_ms": 10.0,
+ "allowed_excess_duration_ratio": 0.1,
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Zipformer and Transformer
+ def to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+ encoder = Zipformer(
+ num_features=params.feature_dim,
+ output_downsampling_factor=2,
+ zipformer_downsampling_factors=to_int_tuple(
+ params.zipformer_downsampling_factors
+ ),
+ encoder_dims=to_int_tuple(params.encoder_dims),
+ attention_dim=to_int_tuple(params.attention_dims),
+ encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+ nhead=to_int_tuple(params.nhead),
+ feedforward_dim=to_int_tuple(params.feedforward_dims),
+ cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+ num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ if "cur_batch_idx" in saved_params:
+ params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+ return saved_params
+
+
+def load_model_params(
+ ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True
+):
+ """Load model params from checkpoint
+
+    Args:
+      ckpt (str): Path to the checkpoint
+      model (nn.Module): model to be loaded
+      init_modules (List[str]): If given, only parameters whose names start
+        with one of these prefixes are loaded; otherwise the whole model
+        is loaded from the checkpoint.
+      strict (bool): Passed to `model.load_state_dict()`.
+    """
+ logging.info(f"Loading checkpoint from {ckpt}")
+ checkpoint = torch.load(ckpt, map_location="cpu")
+
+ # if module list is empty, load the whole model from ckpt
+ if not init_modules:
+ if next(iter(checkpoint["model"])).startswith("module."):
+ logging.info("Loading checkpoint saved by DDP")
+
+ dst_state_dict = model.state_dict()
+ src_state_dict = checkpoint["model"]
+ for key in dst_state_dict.keys():
+ src_key = "{}.{}".format("module", key)
+ dst_state_dict[key] = src_state_dict.pop(src_key)
+ assert len(src_state_dict) == 0
+ model.load_state_dict(dst_state_dict, strict=strict)
+ else:
+ model.load_state_dict(checkpoint["model"], strict=strict)
+ else:
+ src_state_dict = checkpoint["model"]
+ dst_state_dict = model.state_dict()
+ for module in init_modules:
+ logging.info(f"Loading parameters starting with prefix {module}")
+ src_keys = [k for k in src_state_dict.keys() if k.startswith(module)]
+ dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)]
+ assert set(src_keys) == set(dst_keys) # two sets should match exactly
+ for key in src_keys:
+ dst_state_dict[key] = src_state_dict.pop(key)
+
+ model.load_state_dict(dst_state_dict, strict=strict)
+
+ return None
+
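+# Example usage (the checkpoint path is a placeholder): initialize only the
+# encoder from a pre-trained model, keeping the randomly initialized decoder
+# and joiner:
+#   load_model_params("exp/pretrained.pt", model, init_modules=["encoder"])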
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute transducer loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ # For the uneven-sized batch, the total duration after padding would possibly
+ # cause OOM. Hence, for each batch, which is sorted descendingly by length,
+ # we simply drop the last few shortest samples, so that the retained total frames
+ # (after padding) would not exceed `allowed_max_frames`:
+ # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
+ # where `max_frames = max_duration * 1000 // frame_shift_ms`.
+ # We set allowed_excess_duration_ratio=0.1.
+ max_frames = params.max_duration * 1000 // params.frame_shift_ms
+ allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
+ batch = filter_uneven_sized_batch(batch, allowed_max_frames)
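+    # For example, with --max-duration 300 (seconds) and frame_shift_ms=10,
+    # max_frames is 30000 and allowed_max_frames is 33000; the exact numbers
+    # depend on the data-module options you pass on the command line.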
+
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
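+        # For example, with warm_step=2000 and simple_loss_scale=0.5, at
+        # batch_idx_train=1000 the simple loss is scaled by 0.75 and the
+        # pruned loss by 0.55; once batch_idx_train >= warm_step the scales
+        # become 0.5 and 1.0 respectively.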
+
+ loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ cur_batch_idx = params.get("cur_batch_idx", 0)
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx < cur_batch_idx:
+ continue
+ cur_batch_idx = batch_idx
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ set_batch_count(model, params.batch_idx_train)
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ params.cur_batch_idx = batch_idx
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ del params.cur_batch_idx
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have
+ # different behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+ if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ # load model parameters for model fine-tuning
+ if params.do_finetune:
+ modules = params.init_modules.split(",") if params.init_modules else None
+ checkpoints = load_model_params(
+ ckpt=params.finetune_ckpt, model=model, init_modules=modules
+ )
+ else:
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ parameters_names = []
+ parameters_names.append(
+ [name_param_pair[0] for name_param_pair in model.named_parameters()]
+ )
+ optimizer = ScaledAdam(
+ model.parameters(),
+ lr=params.base_lr,
+ clipping_scale=2.0,
+ parameters_names=parameters_names,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 2**22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ commonvoice = CommonVoiceAsrDataModule(args)
+
+ train_cuts = commonvoice.train_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
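+        # e.g. a 2-second cut has roughly 200 feature frames (10 ms frame
+        # shift), giving T = ((200 - 7) // 2 + 1) // 2 = 48 frames after
+        # subsampling.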
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = commonvoice.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = commonvoice.dev_cuts()
+ valid_dl = commonvoice.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ CommonVoiceAsrDataModule.add_arguments(
+ parser
+ ) # you may replace this with your own dataset
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py
new file mode 100755
index 000000000..3fd14aa47
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py
@@ -0,0 +1,281 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) use the averaged model with checkpoint exp_dir/epoch-xxx.pt
+./pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py \
+    --epoch 28 \
+    --avg 15 \
+    --use-averaged-model True \
+    --exp-dir ./pruned_transducer_stateless7_streaming/exp
+
+It will generate a file `epoch-28-avg-15-use-averaged-model.pt` in the given `exp_dir`.
+You can later load it by `torch.load("epoch-28-avg-15-use-averaged-model.pt")`.
+
+(2) use the averaged model with checkpoint exp_dir/checkpoint-iter.pt
+./pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py \
+    --iter 22000 \
+    --avg 5 \
+    --use-averaged-model True \
+    --exp-dir ./pruned_transducer_stateless7_streaming/exp
+
+It will generate a file `iter-22000-avg-5-use-averaged-model.pt` in the given `exp_dir`.
+You can later load it by `torch.load("iter-22000-avg-5-use-averaged-model.pt")`.
+
+(3) use the original model with checkpoint exp_dir/epoch-xxx.pt
+./pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py \
+    --epoch 28 \
+    --avg 15 \
+    --use-averaged-model False \
+    --exp-dir ./pruned_transducer_stateless7_streaming/exp
+
+It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
+You can later load it by `torch.load("epoch-28-avg-15.pt")`.
+
+(4) use the original model with checkpoint exp_dir/checkpoint-iter.pt
+./pruned_transducer_stateless7_streaming/generate_model_from_checkpoint.py \
+    --iter 22000 \
+    --avg 5 \
+    --use-averaged-model False \
+    --exp-dir ./pruned_transducer_stateless7_streaming/exp
+
+It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
+You can later load it by `torch.load("iter-22000-avg-5.pt")`.
+"""
+
+
+import argparse
+from pathlib import Path
+from typing import Dict, List
+
+import sentencepiece as spm
+import torch
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=9,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model."
+ "If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ print("Script started")
+
+ device = torch.device("cpu")
+ print(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ print("About to create model")
+ model = get_transducer_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ print(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
+ torch.save({"model": model.state_dict()}, filename)
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
+ torch.save({"model": model.state_dict()}, filename)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ print(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
+ torch.save({"model": model.state_dict()}, filename)
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ print(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ filename = (
+ params.exp_dir
+ / f"iter-{params.iter}-avg-{params.avg}-use-averaged-model.pt"
+ )
+ torch.save({"model": model.state_dict()}, filename)
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ print(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ filename = (
+ params.exp_dir
+ / f"epoch-{params.epoch}-avg-{params.avg}-use-averaged-model.pt"
+ )
+ torch.save({"model": model.state_dict()}, filename)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ print(f"Number of model parameters: {num_param}")
+
+ print("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py
new file mode 120000
index 000000000..5d9c6ba00
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py
new file mode 120000
index 000000000..457131699
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py
new file mode 120000
index 000000000..2b8fa3cbb
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/joiner.py
new file mode 120000
index 000000000..ecfb6dd8a
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/joiner.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/joiner.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/model.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/model.py
new file mode 120000
index 000000000..e17d4f734
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/model.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/model.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py
new file mode 120000
index 000000000..28bf7bb82
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_check.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py
new file mode 120000
index 000000000..c8548d459
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
new file mode 120000
index 000000000..ae4d9bb04
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/optim.py
new file mode 120000
index 000000000..81ac4a89a
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/optim.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py
new file mode 120000
index 000000000..9510b8fde
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling.py
new file mode 120000
index 000000000..2428b74b9
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py
new file mode 120000
index 000000000..b8b8ba432
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py
new file mode 120000
index 000000000..92c3904af
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py
new file mode 120000
index 000000000..2adf271c1
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
new file mode 100755
index 000000000..dbe65d0a7
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
@@ -0,0 +1,612 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+./pruned_transducer_stateless7_streaming/streaming_decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --decode-chunk-len 32 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+    --decoding-method greedy_search \
+ --num-decode-streams 2000
+"""
+
+import argparse
+import logging
+import math
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from commonvoice_fr import CommonVoiceAsrDataModule
+from decode_stream import DecodeStream
+from kaldifeat import Fbank, FbankOptions
+from lhotse import CutSet
+from streaming_beam_search import (
+ fast_beam_search_one_best,
+ greedy_search,
+ modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+from zipformer import stack_states, unstack_states
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="""It specifies the checkpoint to use for decoding.
+    Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless2/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Supported decoding methods are:
+ greedy_search
+ modified_beam_search
+ fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "--num_active_paths",
+ type=int,
+ default=4,
+ help="""An interger indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=32,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--num-decode-streams",
+ type=int,
+ default=2000,
+ help="The number of streams that can be decoded parallel.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_chunk(
+ params: AttributeDict,
+ model: nn.Module,
+ decode_streams: List[DecodeStream],
+) -> List[int]:
+ """Decode one chunk frames of features for each decode_streams and
+ return the indexes of finished streams in a List.
+
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ decode_streams:
+        A list of DecodeStream, each belonging to an utterance.
+    Returns:
+        Return a list containing the indexes of the finished streams.
+ """
+ device = model.device
+
+ features = []
+ feature_lens = []
+ states = []
+ processed_lens = []
+
+ for stream in decode_streams:
+ feat, feat_len = stream.get_feature_frames(params.decode_chunk_len)
+ features.append(feat)
+ feature_lens.append(feat_len)
+ states.append(stream.states)
+ processed_lens.append(stream.done_frames)
+
+ feature_lens = torch.tensor(feature_lens, device=device)
+ features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
+
+ # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling
+ # factor in encoders is 8.
+ # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8.
+ tail_length = 23
+ if features.size(1) < tail_length:
+ pad_length = tail_length - features.size(1)
+ feature_lens += pad_length
+ features = torch.nn.functional.pad(
+ features,
+ (0, 0, 0, pad_length),
+ mode="constant",
+ value=LOG_EPS,
+ )
+
+ states = stack_states(states)
+ processed_lens = torch.tensor(processed_lens, device=device)
+
+ encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward(
+ x=features,
+ x_lens=feature_lens,
+ states=states,
+ )
+
+ encoder_out = model.joiner.encoder_proj(encoder_out)
+
+ if params.decoding_method == "greedy_search":
+ greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
+ elif params.decoding_method == "fast_beam_search":
+ processed_lens = processed_lens + encoder_out_lens
+ fast_beam_search_one_best(
+ model=model,
+ encoder_out=encoder_out,
+ processed_lens=processed_lens,
+ streams=decode_streams,
+ beam=params.beam,
+ max_states=params.max_states,
+ max_contexts=params.max_contexts,
+ )
+ elif params.decoding_method == "modified_beam_search":
+ modified_beam_search(
+ model=model,
+ streams=decode_streams,
+ encoder_out=encoder_out,
+ num_active_paths=params.num_active_paths,
+ )
+ else:
+ raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+
+ states = unstack_states(new_states)
+
+ finished_streams = []
+ for i in range(len(decode_streams)):
+ decode_streams[i].states = states[i]
+ decode_streams[i].done_frames += encoder_out_lens[i]
+ if decode_streams[i].done:
+ finished_streams.append(i)
+
+ return finished_streams
+
+
+def decode_dataset(
+ cuts: CutSet,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ cuts:
+ Lhotse Cutset containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding-method is fast_beam_search.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains two elements:
+ The first is the reference transcript, and the second is the
+ predicted result.
+ """
+ device = model.device
+
+ opts = FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = 16000
+ opts.mel_opts.num_bins = 80
+
+ log_interval = 50
+
+ decode_results = []
+ # Contain decode streams currently running.
+ decode_streams = []
+ idx = 0
+ for num, cut in enumerate(cuts):
+ # each utterance has a DecodeStream.
+ initial_states = model.encoder.get_init_state(device=device)
+ decode_stream = DecodeStream(
+ params=params,
+ cut_id=cut.id,
+ initial_states=initial_states,
+ decoding_graph=decoding_graph,
+ device=device,
+ )
+ audio: np.ndarray = cut.load_audio()
+        if audio.max() > 1 or audio.min() < -1:
+            # Scale out-of-range audio into [-1, 1] and report it via logging
+            # instead of printing raw arrays to stdout.
+            audio = audio / max(abs(audio.max()), abs(audio.min()))
+            idx += 1
+            logging.warning(
+                f"Cut {cut.id} has out-of-range samples; normalized to [-1, 1]. "
+                f"({idx} cuts normalized so far)"
+            )
+ # audio.shape: (1, num_samples)
+ assert len(audio.shape) == 2
+ assert audio.shape[0] == 1, "Should be single channel"
+ assert audio.dtype == np.float32, audio.dtype
+
+ # The trained model is using normalized samples
+        assert audio.max() <= 1, "Should be normalized to [-1, 1]"
+
+ samples = torch.from_numpy(audio).squeeze(0)
+
+ fbank = Fbank(opts)
+ feature = fbank(samples.to(device))
+ decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len)
+ decode_stream.ground_truth = cut.supervisions[0].text
+
+ decode_streams.append(decode_stream)
+
+ while len(decode_streams) >= params.num_decode_streams:
+ finished_streams = decode_one_chunk(
+ params=params, model=model, decode_streams=decode_streams
+ )
+ for i in sorted(finished_streams, reverse=True):
+ decode_results.append(
+ (
+ decode_streams[i].id,
+ decode_streams[i].ground_truth.split(),
+ sp.decode(decode_streams[i].decoding_result()).split(),
+ )
+ )
+ del decode_streams[i]
+
+ if num % log_interval == 0:
+ logging.info(f"Cuts processed until now is {num}.")
+
+ # decode final chunks of last sequences
+ while len(decode_streams):
+ finished_streams = decode_one_chunk(
+ params=params, model=model, decode_streams=decode_streams
+ )
+ for i in sorted(finished_streams, reverse=True):
+ decode_results.append(
+ (
+ decode_streams[i].id,
+ decode_streams[i].ground_truth.split(),
+ sp.decode(decode_streams[i].decoding_result()).split(),
+ )
+ )
+ del decode_streams[i]
+
+ if params.decoding_method == "greedy_search":
+ key = "greedy_search"
+ elif params.decoding_method == "fast_beam_search":
+ key = (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}"
+ )
+ elif params.decoding_method == "modified_beam_search":
+ key = f"num_active_paths_{params.num_active_paths}"
+ else:
+ raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+ return {key: decode_results}
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ CommonVoiceAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ params.res_dir = params.exp_dir / "streaming" / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ # for streaming
+ params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
+
+ # for fast_beam_search
+ if params.decoding_method == "fast_beam_search":
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if start >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ decoding_graph = None
+ if params.decoding_method == "fast_beam_search":
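+        # The trivial graph has a self-loop for every non-blank token, so
+        # fast_beam_search runs without any language-model constraint here.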
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ commonvoice = CommonVoiceAsrDataModule(args)
+ test_cuts = commonvoice.test_cuts()
+ test_sets = "test-cv"
+
+ results_dict = decode_dataset(
+ cuts=test_cuts,
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_sets,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py
new file mode 100755
index 000000000..5400df804
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/test_model.py
@@ -0,0 +1,150 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+To run this file, do:
+
+    cd icefall/egs/commonvoice/ASR
+ python ./pruned_transducer_stateless7_streaming/test_model.py
+"""
+
+import torch
+from scaling_converter import convert_scaled_to_non_scaled
+from train import get_params, get_transducer_model
+
+
+def test_model():
+ params = get_params()
+ params.vocab_size = 500
+ params.blank_id = 0
+ params.context_size = 2
+ params.num_encoder_layers = "2,4,3,2,4"
+ params.feedforward_dims = "1024,1024,2048,2048,1024"
+ params.nhead = "8,8,8,8,8"
+ params.encoder_dims = "384,384,384,384,384"
+ params.attention_dims = "192,192,192,192,192"
+ params.encoder_unmasked_dims = "256,256,256,256,256"
+ params.zipformer_downsampling_factors = "1,2,4,8,2"
+ params.cnn_module_kernels = "31,31,31,31,31"
+ params.decoder_dim = 512
+ params.joiner_dim = 512
+ params.num_left_chunks = 4
+ params.short_chunk_size = 50
+ params.decode_chunk_len = 32
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ print(f"Number of model parameters: {num_param}")
+
+ # Test jit script
+ convert_scaled_to_non_scaled(model, inplace=True)
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+    # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+ print("Using torch.jit.script")
+ model = torch.jit.script(model)
+
+
+def test_model_jit_trace():
+ params = get_params()
+ params.vocab_size = 500
+ params.blank_id = 0
+ params.context_size = 2
+ params.num_encoder_layers = "2,4,3,2,4"
+ params.feedforward_dims = "1024,1024,2048,2048,1024"
+ params.nhead = "8,8,8,8,8"
+ params.encoder_dims = "384,384,384,384,384"
+ params.attention_dims = "192,192,192,192,192"
+ params.encoder_unmasked_dims = "256,256,256,256,256"
+ params.zipformer_downsampling_factors = "1,2,4,8,2"
+ params.cnn_module_kernels = "31,31,31,31,31"
+ params.decoder_dim = 512
+ params.joiner_dim = 512
+ params.num_left_chunks = 4
+ params.short_chunk_size = 50
+ params.decode_chunk_len = 32
+ model = get_transducer_model(params)
+ model.eval()
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ print(f"Number of model parameters: {num_param}")
+
+ convert_scaled_to_non_scaled(model, inplace=True)
+
+ # Test encoder
+ def _test_encoder():
+ encoder = model.encoder
+ assert encoder.decode_chunk_size == params.decode_chunk_len // 2, (
+ encoder.decode_chunk_size,
+ params.decode_chunk_len,
+ )
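+        # The extra 7 frames account for the context consumed by the
+        # convolutional subsampling front-end.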
+ T = params.decode_chunk_len + 7
+
+ x = torch.zeros(1, T, 80, dtype=torch.float32)
+ x_lens = torch.full((1,), T, dtype=torch.int32)
+ states = encoder.get_init_state(device=x.device)
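+        # torch.jit.trace() traces forward(); point it at streaming_forward()
+        # so that the traced module follows the streaming code path.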
+ encoder.__class__.forward = encoder.__class__.streaming_forward
+ traced_encoder = torch.jit.trace(encoder, (x, x_lens, states))
+
+ states1 = encoder.get_init_state(device=x.device)
+ states2 = traced_encoder.get_init_state(device=x.device)
+ for i in range(5):
+ x = torch.randn(1, T, 80, dtype=torch.float32)
+ x_lens = torch.full((1,), T, dtype=torch.int32)
+ y1, _, states1 = encoder.streaming_forward(x, x_lens, states1)
+ y2, _, states2 = traced_encoder(x, x_lens, states2)
+ assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean())
+
+ # Test decoder
+ def _test_decoder():
+ decoder = model.decoder
+ y = torch.zeros(10, decoder.context_size, dtype=torch.int64)
+ need_pad = torch.tensor([False])
+
+ traced_decoder = torch.jit.trace(decoder, (y, need_pad))
+ d1 = decoder(y, need_pad)
+ d2 = traced_decoder(y, need_pad)
+ assert torch.equal(d1, d2), (d1 - d2).abs().mean()
+
+ # Test joiner
+ def _test_joiner():
+ joiner = model.joiner
+ encoder_out_dim = joiner.encoder_proj.weight.shape[1]
+ decoder_out_dim = joiner.decoder_proj.weight.shape[1]
+ encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
+ decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
+
+ traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out))
+ j1 = joiner(encoder_out, decoder_out)
+ j2 = traced_joiner(encoder_out, decoder_out)
+ assert torch.equal(j1, j2), (j1 - j2).abs().mean()
+
+ _test_encoder()
+ _test_decoder()
+ _test_joiner()
+
+
+def main():
+ test_model()
+ test_model_jit_trace()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
new file mode 100755
index 000000000..a9bc9c2a2
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -0,0 +1,1256 @@
+#!/usr/bin/env python3
+# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+#                                                       Mingshuang Luo,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7_streaming/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --exp-dir pruned_transducer_stateless7_streaming/exp \
+ --max-duration 300
+
+# For mixed precision training:
+
+./pruned_transducer_stateless7_streaming/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir pruned_transducer_stateless7_streaming/exp \
+ --max-duration 550
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from commonvoice_fr import CommonVoiceAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for module in model.modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,4,3,2,4",
+ help="Number of zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dims",
+ type=str,
+ default="1024,1024,2048,2048,1024",
+ help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--nhead",
+ type=str,
+ default="8,8,8,8,8",
+ help="Number of attention heads in the zipformer encoder layers.",
+ )
+
+ parser.add_argument(
+ "--encoder-dims",
+ type=str,
+ default="384,384,384,384,384",
+ help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+ )
+
+ parser.add_argument(
+ "--attention-dims",
+ type=str,
+ default="192,192,192,192,192",
+ help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+ not the same as embedding dimension.""",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dims",
+ type=str,
+ default="256,256,256,256,256",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
+        "worse.",
+ )
+
+ parser.add_argument(
+ "--zipformer-downsampling-factors",
+ type=str,
+ default="1,2,4,8,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernels",
+ type=str,
+ default="31,31,31,31,31",
+ help="Sizes of kernels in convolution modules",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--short-chunk-size",
+ type=int,
+ default=50,
+        help="""Chunk length for dynamic chunk training; the chunk size is either
+        the max sequence length of the current batch or uniformly sampled from
+        (1, short_chunk_size).
+        """,
+ )
+
+ parser.add_argument(
+ "--num-left-chunks",
+ type=int,
+ default=4,
+        help="How many left context chunks can be seen when calculating attention.",
+ )
+
+ parser.add_argument(
+ "--decode-chunk-len",
+ type=int,
+ default=32,
+ help="The chunk size for decoding (in frames before subsampling)",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7_streaming/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/fr/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.05, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+        help="The prune range for rnnt loss, i.e., how many symbols (context) "
+        "we use to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network) part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+        help="To get pruning ranges, we will calculate a simple version of the "
+        "loss (the joiner is just an addition); this simple loss is also used "
+        "for training (as a regularization term). We scale the simple loss "
+        "by this parameter before adding it to the final loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=2000,
+        help="""Save checkpoint after processing this number of batches
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+        - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Zipformer and Transformer
+ def to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+ encoder = Zipformer(
+ num_features=params.feature_dim,
+ output_downsampling_factor=2,
+ zipformer_downsampling_factors=to_int_tuple(
+ params.zipformer_downsampling_factors
+ ),
+ encoder_dims=to_int_tuple(params.encoder_dims),
+ attention_dim=to_int_tuple(params.attention_dims),
+ encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+ nhead=to_int_tuple(params.nhead),
+ feedforward_dim=to_int_tuple(params.feedforward_dims),
+ cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+ num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+ num_left_chunks=params.num_left_chunks,
+ short_chunk_size=params.short_chunk_size,
+ decode_chunk_size=params.decode_chunk_len // 2,
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ if "cur_batch_idx" in saved_params:
+ params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute transducer loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+
+ loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ cur_batch_idx = params.get("cur_batch_idx", 0)
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx < cur_batch_idx:
+ continue
+ cur_batch_idx = batch_idx
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
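+            # tot_loss behaves like an exponential moving average with a time
+            # constant of roughly params.reset_interval batches.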
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ set_batch_count(model, params.batch_idx_train)
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ params.cur_batch_idx = batch_idx
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ del params.cur_batch_idx
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+ if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
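+        # It is kept in float64 so that the running average accumulates with
+        # little rounding error.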
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ parameters_names = []
+ parameters_names.append(
+ [name_param_pair[0] for name_param_pair in model.named_parameters()]
+ )
+ optimizer = ScaledAdam(
+ model.parameters(),
+ lr=params.base_lr,
+ clipping_scale=2.0,
+ parameters_names=parameters_names,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 2**22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ commonvoice = CommonVoiceAsrDataModule(args)
+
+ train_cuts = commonvoice.train_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = commonvoice.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = commonvoice.dev_cuts()
+ valid_dl = commonvoice.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ CommonVoiceAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
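+# Limit PyTorch's intra-/inter-op threads to avoid CPU oversubscription across
+# DDP processes and dataloader workers.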
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py
new file mode 100755
index 000000000..c09c9537c
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py
@@ -0,0 +1,1257 @@
+#!/usr/bin/env python3
+# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+#                                                       Mingshuang Luo,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7_streaming/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --exp-dir pruned_transducer_stateless7_streaming/exp \
+ --max-duration 300
+
+# For mixed precision training:
+
+./pruned_transducer_stateless7_streaming/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir pruned_transducer_stateless7_streaming/exp \
+ --max-duration 550
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from commonvoice_fr import CommonVoiceAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer2 import Zipformer
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for module in model.modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,4,3,2,4",
+ help="Number of zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dims",
+ type=str,
+ default="1024,1024,2048,2048,1024",
+ help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--nhead",
+ type=str,
+ default="8,8,8,8,8",
+ help="Number of attention heads in the zipformer encoder layers.",
+ )
+
+ parser.add_argument(
+ "--encoder-dims",
+ type=str,
+ default="384,384,384,384,384",
+ help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+ )
+
+ parser.add_argument(
+ "--attention-dims",
+ type=str,
+ default="192,192,192,192,192",
+ help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+ not the same as embedding dimension.""",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dims",
+ type=str,
+ default="256,256,256,256,256",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
+        "worse.",
+ )
+
+ parser.add_argument(
+ "--zipformer-downsampling-factors",
+ type=str,
+ default="1,2,4,8,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernels",
+ type=str,
+ default="31,31,31,31,31",
+ help="Sizes of kernels in convolution modules",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--short-chunk-size",
+ type=int,
+ default=50,
+        help="""Chunk length for dynamic chunk training; the chunk size is either
+        the max sequence length of the current batch or uniformly sampled from
+        (1, short_chunk_size).
+        """,
+ )
+
+ parser.add_argument(
+ "--num-left-chunks",
+ type=int,
+ default=4,
+        help="How many left context chunks can be seen when calculating attention.",
+ )
+
+ parser.add_argument(
+ "--decode-chunk-len",
+ type=int,
+ default=32,
+ help="The chunk size for decoding (in frames before subsampling)",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7_streaming/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.05, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+        help="The prune range for rnnt loss, i.e., how many symbols (context) "
+        "we use to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network) part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+        help="To get pruning ranges, we will calculate a simple version of the "
+        "loss (the joiner is just an addition); this simple loss is also used "
+        "for training (as a regularization term). We scale the simple loss "
+        "by this parameter before adding it to the final loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=2000,
+        help="""Save checkpoint after processing this number of batches
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+                       contains the number of batches trained so far across
+                       epochs.
+
+    - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+    - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Zipformer and Transformer
+ def to_int_tuple(s: str):
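+        # e.g. "2,4,3,2,4" -> (2, 4, 3, 2, 4)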
+ return tuple(map(int, s.split(",")))
+
+ encoder = Zipformer(
+ num_features=params.feature_dim,
+ output_downsampling_factor=2,
+ zipformer_downsampling_factors=to_int_tuple(
+ params.zipformer_downsampling_factors
+ ),
+ encoder_dims=to_int_tuple(params.encoder_dims),
+ attention_dim=to_int_tuple(params.attention_dims),
+ encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+ nhead=to_int_tuple(params.nhead),
+ feedforward_dim=to_int_tuple(params.feedforward_dims),
+ cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+ num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+ num_left_chunks=params.num_left_chunks,
+ short_chunk_size=params.short_chunk_size,
+ decode_chunk_size=params.decode_chunk_len // 2,
+ is_pnnx=True,
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ if "cur_batch_idx" in saved_params:
+ params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute transducer loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ warmup: a floating point value which increases throughout training;
+ values >= 1.0 are fully warmed up and have all modules present.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
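+        # For illustration (using the defaults in this file,
+        # simple_loss_scale = 0.5 and warm_step = 2000): at batch 0 the
+        # (simple, pruned) weights are (1.0, 0.1); at batch 1000 they are
+        # (0.75, 0.55); from batch 2000 onwards they stay at (0.5, 1.0).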
+
+ loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+    The training loss, averaged over all frames, is saved in
+    `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+        The learning rate scheduler; we call its step_batch() after every batch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ cur_batch_idx = params.get("cur_batch_idx", 0)
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx < cur_batch_idx:
+ continue
+ cur_batch_idx = batch_idx
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: We use reduction="sum", so the loss is summed over the
+            # utterances in the batch without any normalization so far.
+ scaler.scale(loss).backward()
+ set_batch_count(model, params.batch_idx_train)
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ params.cur_batch_idx = batch_idx
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ del params.cur_batch_idx
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+ if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoints.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ parameters_names = []
+ parameters_names.append(
+ [name_param_pair[0] for name_param_pair in model.named_parameters()]
+ )
+ optimizer = ScaledAdam(
+ model.parameters(),
+ lr=params.base_lr,
+ clipping_scale=2.0,
+ parameters_names=parameters_names,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 2**22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ commonvoice = CommonVoiceAsrDataModule(args)
+
+ train_cuts = commonvoice.train_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
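+        # For illustration: a cut with c.num_frames = 1000 (about 10 seconds
+        # at a 10 ms frame shift) gives T = ((1000 - 7) // 2 + 1) // 2 = 248.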
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = commonvoice.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = commonvoice.dev_cuts()
+ valid_dl = commonvoice.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ CommonVoiceAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py
new file mode 120000
index 000000000..ec183baa7
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
new file mode 120000
index 000000000..12dbda888
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
\ No newline at end of file
diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py
deleted file mode 100755
index 29a2cd7f7..000000000
--- a/icefall/shared/convert-k2-to-openfst.py
+++ /dev/null
@@ -1,102 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-This script takes as input an FST in k2 format and converts it
-to an FST in OpenFST format.
-
-The generated FST is saved into a binary file and its type is
-StdVectorFst.
-
-Usage examples:
-(1) Convert an acceptor
-
- ./convert-k2-to-openfst.py in.pt binary.fst
-
-(2) Convert a transducer
-
- ./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst
-"""
-
-import argparse
-import logging
-from pathlib import Path
-
-import k2
-import kaldifst.utils
-import torch
-
-
-def get_args():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--olabels",
- type=str,
- default=None,
- help="""If not empty, the input FST is assumed to be a transducer
- and we use its attribute specified by "olabels" as the output labels.
- """,
- )
- parser.add_argument(
- "input_filename",
- type=str,
- help="Path to the input FST in k2 format",
- )
-
- parser.add_argument(
- "output_filename",
- type=str,
- help="Path to the output FST in OpenFst format",
- )
-
- return parser.parse_args()
-
-
-def main():
- args = get_args()
- logging.info(f"{vars(args)}")
-
- input_filename = args.input_filename
- output_filename = args.output_filename
- olabels = args.olabels
-
- if Path(output_filename).is_file():
- logging.info(f"{output_filename} already exists - skipping")
- return
-
- assert Path(input_filename).is_file(), f"{input_filename} does not exist"
- logging.info(f"Loading {input_filename}")
- k2_fst = k2.Fsa.from_dict(torch.load(input_filename))
- if olabels:
- assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}"
-
- p = Path(output_filename).parent
- if not p.is_dir():
- logging.info(f"Creating {p}")
- p.mkdir(parents=True)
-
- logging.info("Converting (May take some time if the input FST is large)")
- fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels)
- logging.info(f"Saving to {output_filename}")
- fst.write(output_filename)
-
-
-if __name__ == "__main__":
- formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
- logging.basicConfig(format=formatter, level=logging.INFO)
- main()
diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py
new file mode 120000
index 000000000..24efe5eae
--- /dev/null
+++ b/icefall/shared/convert-k2-to-openfst.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/shared/convert-k2-to-openfst.py
\ No newline at end of file
diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py
deleted file mode 100755
index b1ebee9ea..000000000
--- a/icefall/shared/ngram_entropy_pruning.py
+++ /dev/null
@@ -1,630 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-#
-# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Usage:
-./ngram_entropy_pruning.py \
- -threshold 1e-8 \
- -lm download/lm/4gram.arpa \
- -write-lm download/lm/4gram_pruned_1e8.arpa
-
-This file is from Kaldi `egs/wsj/s5/utils/lang/ngram_entropy_pruning.py`.
-This is an implementation of ``Entropy-based Pruning of Backoff Language Models''
-in the same way as SRILM.
-"""
-
-
-import argparse
-import gzip
-import logging
-import math
-import re
-from collections import OrderedDict, defaultdict
-from enum import Enum, unique
-from io import StringIO
-
-parser = argparse.ArgumentParser(
- description="""
- Prune an n-gram language model based on the relative entropy
- between the original and the pruned model, based on Andreas Stolcke's paper.
-    An n-gram entry is removed if the removal causes the (training set) perplexity
-    of the model to increase by less than the threshold, relative to the original model.
-
- The command takes an arpa file and a pruning threshold as input,
- and outputs a pruned arpa file.
- """
-)
-parser.add_argument("-threshold", type=float, default=1e-6, help="Order of n-gram")
-parser.add_argument("-lm", type=str, default=None, help="Path to the input arpa file")
-parser.add_argument(
- "-write-lm", type=str, default=None, help="Path to output arpa file after pruning"
-)
-parser.add_argument(
- "-minorder",
- type=int,
- default=1,
- help="The minorder parameter limits pruning to ngrams of that length and above.",
-)
-parser.add_argument(
- "-encoding", type=str, default="utf-8", help="Encoding of the arpa file"
-)
-parser.add_argument(
- "-verbose",
- type=int,
- default=2,
- choices=[0, 1, 2, 3, 4, 5],
- help="Verbose level, where 0 is most noisy; 5 is most silent",
-)
-args = parser.parse_args()
-
-default_encoding = args.encoding
-logging.basicConfig(
- format="%(asctime)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s",
- level=args.verbose * 10,
-)
-
-
-class Context(dict):
- """
- This class stores data for a context h.
- It behaves like a python dict object, except that it has several
- additional attributes.
- """
-
- def __init__(self):
- super().__init__()
- self.log_bo = None
-
-
-class Arpa:
- """
-    This is a class that implements the data structure of an ARPA LM.
- It (as well as some other classes) is modified based on the library
- by Stefan Fischer:
- https://github.com/sfischer13/python-arpa
- """
-
- UNK = ""
- SOS = ""
- EOS = ""
- FLOAT_NDIGITS = 7
- base = 10
-
- @staticmethod
- def _check_input(my_input):
- if not my_input:
- raise ValueError
- elif isinstance(my_input, tuple):
- return my_input
- elif isinstance(my_input, list):
- return tuple(my_input)
- elif isinstance(my_input, str):
- return tuple(my_input.strip().split(" "))
- else:
- raise ValueError
-
- @staticmethod
- def _check_word(input_word):
- if not isinstance(input_word, str):
- raise ValueError
- if " " in input_word:
- raise ValueError
-
- def _replace_unks(self, words):
- return tuple((w if w in self else self._unk) for w in words)
-
- def __init__(self, path=None, encoding=None, unk=None):
- self._counts = OrderedDict()
- self._ngrams = (
- OrderedDict()
- ) # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w)
- self._vocabulary = set()
- if unk is None:
- self._unk = self.UNK
-
- if path is not None:
- self.loadf(path, encoding)
-
- def __contains__(self, ngram):
- h = ngram[:-1] # h is a tuple
- w = ngram[-1] # w is a string/word
- return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]
-
- def contains_word(self, word):
- self._check_word(word)
- return word in self._vocabulary
-
- def add_count(self, order, count):
- self._counts[order] = count
- self._ngrams[order - 1] = defaultdict(Context)
-
- def update_counts(self):
- for order in range(1, self.order() + 1):
- count = sum([len(wlist) for _, wlist in self._ngrams[order - 1].items()])
- if count > 0:
- self._counts[order] = count
-
- def add_entry(self, ngram, p, bo=None, order=None):
- # Note: ngram is a tuple of strings, e.g. ("w1", "w2", "w3")
- h = ngram[:-1] # h is a tuple
- w = ngram[-1] # w is a string/word
-
- # Note that p and bo here are in fact in the log domain (self.base = 10)
- h_context = self._ngrams[len(h)][h]
- h_context[w] = p
- if bo is not None:
- self._ngrams[len(ngram)][ngram].log_bo = bo
-
- for word in ngram:
- self._vocabulary.add(word)
-
- def counts(self):
- return sorted(self._counts.items())
-
- def order(self):
- return max(self._counts.keys(), default=None)
-
- def vocabulary(self, sort=True):
- if sort:
- return sorted(self._vocabulary)
- else:
- return self._vocabulary
-
- def _entries(self, order):
- return (
- self._entry(h, w)
- for h, wlist in self._ngrams[order - 1].items()
- for w in wlist
- )
-
- def _entry(self, h, w):
- # return the entry for the ngram (h, w)
- ngram = h + (w,)
- log_p = self._ngrams[len(h)][h][w]
- log_bo = self._log_bo(ngram)
- if log_bo is not None:
- return (
- round(log_p, self.FLOAT_NDIGITS),
- ngram,
- round(log_bo, self.FLOAT_NDIGITS),
- )
- else:
- return round(log_p, self.FLOAT_NDIGITS), ngram
-
- def _log_bo(self, ngram):
- if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]:
- return self._ngrams[len(ngram)][ngram].log_bo
- else:
- return None
-
- def _log_p(self, ngram):
- h = ngram[:-1] # h is a tuple
- w = ngram[-1] # w is a string/word
- if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]:
- return self._ngrams[len(h)][h][w]
- else:
- return None
-
- def log_p_raw(self, ngram):
- log_p = self._log_p(ngram)
- if log_p is not None:
- return log_p
- else:
- if len(ngram) == 1:
- raise KeyError
- else:
- log_bo = self._log_bo(ngram[:-1])
- if log_bo is None:
- log_bo = 0
- return log_bo + self.log_p_raw(ngram[1:])
-
- def log_joint_prob(self, sequence):
- # Compute the joint prob of the sequence based on the chain rule
- # Note that sequence should be a tuple of strings
- #
- # Reference:
- # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527
-
- log_joint_p = 0
- seq = sequence
- while len(seq) > 0:
- log_joint_p += self.log_p_raw(seq)
- seq = seq[:-1]
-
- # If we're computing the marginal probability of the unigram
- # context we have to look up instead since the former
- # has prob = 0.
- if len(seq) == 1 and seq[0] == self.SOS:
- seq = (self.EOS,)
-
- return log_joint_p
-
- def set_new_context(self, h):
- old_context = self._ngrams[len(h)][h]
- self._ngrams[len(h)][h] = Context()
- return old_context
-
- def log_p(self, ngram):
- words = self._check_input(ngram)
- if self._unk:
- words = self._replace_unks(words)
- return self.log_p_raw(words)
-
- def log_s(self, sentence, sos=SOS, eos=EOS):
- words = self._check_input(sentence)
- if self._unk:
- words = self._replace_unks(words)
- if sos:
- words = (sos,) + words
- if eos:
- words = words + (eos,)
- result = sum(self.log_p_raw(words[:i]) for i in range(1, len(words) + 1))
- if sos:
- result = result - self.log_p_raw(words[:1])
- return result
-
- def p(self, ngram):
- return self.base ** self.log_p(ngram)
-
- def s(self, sentence):
- return self.base ** self.log_s(sentence)
-
- def write(self, fp):
- fp.write("\n\\data\\\n")
- for order, count in self.counts():
- fp.write("ngram {}={}\n".format(order, count))
- fp.write("\n")
- for order, _ in self.counts():
- fp.write("\\{}-grams:\n".format(order))
- for e in self._entries(order):
- prob = e[0]
- ngram = " ".join(e[1])
- if len(e) == 2:
- fp.write("{}\t{}\n".format(prob, ngram))
- elif len(e) == 3:
- backoff = e[2]
- fp.write("{}\t{}\t{}\n".format(prob, ngram, backoff))
- else:
- raise ValueError
- fp.write("\n")
- fp.write("\\end\\\n")
-
-
-class ArpaParser:
- """
-    This is a class that implements a parser for an ARPA file.
- """
-
- @unique
- class State(Enum):
- DATA = 1
- COUNT = 2
- HEADER = 3
- ENTRY = 4
-
- re_count = re.compile(r"^ngram (\d+)=(\d+)$")
- re_header = re.compile(r"^\\(\d+)-grams:$")
- re_entry = re.compile(
- "^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)"
- "\t"
- "(\\S+( \\S+)*)"
- "(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$"
- )
-
- def _parse(self, fp):
- self._result = []
- self._state = self.State.DATA
- self._tmp_model = None
- self._tmp_order = None
- for line in fp:
- line = line.strip()
- if self._state == self.State.DATA:
- self._data(line)
- elif self._state == self.State.COUNT:
- self._count(line)
- elif self._state == self.State.HEADER:
- self._header(line)
- elif self._state == self.State.ENTRY:
- self._entry(line)
- if self._state != self.State.DATA:
- raise Exception(line)
- return self._result
-
- def _data(self, line):
- if line == "\\data\\":
- self._state = self.State.COUNT
- self._tmp_model = Arpa()
- else:
- pass # skip comment line
-
- def _count(self, line):
- match = self.re_count.match(line)
- if match:
- order = match.group(1)
- count = match.group(2)
- self._tmp_model.add_count(int(order), int(count))
- elif not line:
- self._state = self.State.HEADER # there are no counts
- else:
- raise Exception(line)
-
- def _header(self, line):
- match = self.re_header.match(line)
- if match:
- self._state = self.State.ENTRY
- self._tmp_order = int(match.group(1))
- elif line == "\\end\\":
- self._result.append(self._tmp_model)
- self._state = self.State.DATA
- self._tmp_model = None
- self._tmp_order = None
- elif not line:
- pass # skip empty line
- else:
- raise Exception(line)
-
- def _entry(self, line):
- match = self.re_entry.match(line)
- if match:
- p = self._float_or_int(match.group(1))
- ngram = tuple(match.group(4).split(" "))
- bo_match = match.group(7)
- bo = self._float_or_int(bo_match) if bo_match else None
- self._tmp_model.add_entry(ngram, p, bo, self._tmp_order)
- elif not line:
- self._state = self.State.HEADER # last entry
- else:
- raise Exception(line)
-
- @staticmethod
- def _float_or_int(s):
- f = float(s)
- i = int(f)
- if str(i) == s: # don't drop trailing ".0"
- return i
- else:
- return f
-
- def load(self, fp):
- """Deserialize fp (a file-like object) to a Python object."""
- return self._parse(fp)
-
- def loadf(self, path, encoding=None):
- """Deserialize path (.arpa, .gz) to a Python object."""
- path = str(path)
- if path.endswith(".gz"):
- with gzip.open(path, mode="rt", encoding=encoding) as f:
- return self.load(f)
- else:
- with open(path, mode="rt", encoding=encoding) as f:
- return self.load(f)
-
- def loads(self, s):
- """Deserialize s (a str) to a Python object."""
- with StringIO(s) as f:
- return self.load(f)
-
- def dump(self, obj, fp):
- """Serialize obj to fp (a file-like object) in ARPA format."""
- obj.write(fp)
-
- def dumpf(self, obj, path, encoding=None):
- """Serialize obj to path in ARPA format (.arpa, .gz)."""
- path = str(path)
- if path.endswith(".gz"):
- with gzip.open(path, mode="wt", encoding=encoding) as f:
- return self.dump(obj, f)
- else:
- with open(path, mode="wt", encoding=encoding) as f:
- self.dump(obj, f)
-
- def dumps(self, obj):
- """Serialize obj to an ARPA formatted str."""
- with StringIO() as f:
- self.dump(obj, f)
- return f.getvalue()
-
-
-def add_log_p(prev_log_sum, log_p, base):
- return math.log(base**log_p + base**prev_log_sum, base)
-
-
-def compute_numerator_denominator(lm, h):
- log_sum_seen_h = -math.inf
- log_sum_seen_h_lower = -math.inf
- base = lm.base
- for w, log_p in lm._ngrams[len(h)][h].items():
- log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base)
-
- ngram = h + (w,)
- log_p_lower = lm.log_p_raw(ngram[1:])
- log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, base)
-
- numerator = 1.0 - base**log_sum_seen_h
- denominator = 1.0 - base**log_sum_seen_h_lower
- return numerator, denominator
-
-
-def prune(lm, threshold, minorder):
- # Reference:
- # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330
-
- for i in range(
- lm.order(), max(minorder - 1, 1), -1
- ): # i is the order of the ngram (h, w)
- logging.info("processing %d-grams ..." % i)
- count_pruned_ngrams = 0
-
- h_dict = lm._ngrams[i - 1]
- for h in list(h_dict.keys()):
- # old backoff weight, BOW(h)
- log_bow = lm._log_bo(h)
- if log_bow is None:
- log_bow = 0
-
- # Compute numerator and denominator of the backoff weight,
- # so that we can quickly compute the BOW adjustment due to
- # leaving out one prob.
- numerator, denominator = compute_numerator_denominator(lm, h)
-
- # assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5
-
- # Compute the marginal probability of the context, P(h)
- h_log_p = lm.log_joint_prob(h)
-
- all_pruned = True
- pruned_w_set = set()
-
- for w, log_p in h_dict[h].items():
- ngram = h + (w,)
-
- # lower-order estimate for ngramProb, P(w|h')
- backoff_prob = lm.log_p_raw(ngram[1:])
-
- # Compute BOW after removing ngram, BOW'(h)
- new_log_bow = math.log(
- numerator + lm.base**log_p, lm.base
- ) - math.log(denominator + lm.base**backoff_prob, lm.base)
-
- # Compute change in entropy due to removal of ngram
- delta_prob = backoff_prob + new_log_bow - log_p
- delta_entropy = -(lm.base**h_log_p) * (
- (lm.base**log_p) * delta_prob
- + numerator * (new_log_bow - log_bow)
- )
-
- # compute relative change in model (training set) perplexity
- perp_change = lm.base**delta_entropy - 1.0
-
- pruned = threshold > 0 and perp_change < threshold
-
- # Make sure we don't prune ngrams whose backoff nodes are needed
- if (
- pruned
- and len(ngram) in lm._ngrams
- and len(lm._ngrams[len(ngram)][ngram]) > 0
- ):
- pruned = False
-
- logging.debug(
- "CONTEXT "
- + str(h)
- + " WORD "
- + w
- + " CONTEXTPROB %f " % h_log_p
- + " OLDPROB %f " % log_p
- + " NEWPROB %f " % (backoff_prob + new_log_bow)
- + " DELTA-H %f " % delta_entropy
- + " DELTA-LOGP %f " % delta_prob
- + " PPL-CHANGE %f " % perp_change
- + " PRUNED "
- + str(pruned)
- )
-
- if pruned:
- pruned_w_set.add(w)
- count_pruned_ngrams += 1
- else:
- all_pruned = False
-
- # If we removed all ngrams for this context we can
- # remove the context itself, but only if the present
- # context is not a prefix to a longer one.
- if all_pruned and len(pruned_w_set) == len(h_dict[h]):
- del h_dict[
- h
- ] # this context h is no longer needed, as its ngram prob is stored at its own context h'
- elif len(pruned_w_set) > 0:
- # The pruning for this context h is actually done here
- old_context = lm.set_new_context(h)
-
- for w, p_w in old_context.items():
- if w not in pruned_w_set:
- lm.add_entry(
- h + (w,), p_w
- ) # the entry hw is stored at the context h
-
- # We need to recompute the back-off weight, but
- # this can only be done after completing the pruning
- # of the lower-order ngrams.
- # Reference:
- # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124
-
- logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i))
-
- # recompute backoff weights
- for i in range(
- max(minorder - 1, 1) + 1, lm.order() + 1
- ): # be careful of this order: from low- to high-order
- for h in lm._ngrams[i - 1]:
- numerator, denominator = compute_numerator_denominator(lm, h)
- new_log_bow = math.log(numerator, lm.base) - math.log(denominator, lm.base)
- lm._ngrams[len(h)][h].log_bo = new_log_bow
-
- # update counts
- lm.update_counts()
-
- return
-
-
-def check_h_is_valid(lm, h):
- sum_under_h = sum(
- [lm.base ** lm.log_p_raw(h + (w,)) for w in lm.vocabulary(sort=False)]
- )
- if abs(sum_under_h - 1.0) > 1e-6:
- logging.info("warning: %s %f" % (str(h), sum_under_h))
- return False
- else:
- return True
-
-
-def validate_lm(lm):
- # sanity check if the conditional probability sums to one under each context h
- for i in range(lm.order(), 0, -1): # i is the order of the ngram (h, w)
- logging.info("validating %d-grams ..." % i)
- h_dict = lm._ngrams[i - 1]
- for h in h_dict.keys():
- check_h_is_valid(lm, h)
-
-
-def compare_two_apras(path1, path2):
- pass
-
-
-if __name__ == "__main__":
- # load an arpa file
- logging.info("Loading the arpa file from %s" % args.lm)
- parser = ArpaParser()
- models = parser.loadf(args.lm, encoding=default_encoding)
- lm = models[0] # ARPA files may contain several models.
- logging.info("Stats before pruning:")
- for i, cnt in lm.counts():
- logging.info("ngram %d=%d" % (i, cnt))
-
- # prune it, the language model will be modified in-place
- logging.info("Start pruning the model with threshold=%.3E..." % args.threshold)
- prune(lm, args.threshold, args.minorder)
-
- # validate_lm(lm)
-
- # write the arpa language model to a file
- logging.info("Stats after pruning:")
- for i, cnt in lm.counts():
- logging.info("ngram %d=%d" % (i, cnt))
- logging.info("Saving the pruned arpa file to %s" % args.write_lm)
- parser.dumpf(lm, args.write_lm, encoding=default_encoding)
- logging.info("Done.")
diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py
new file mode 120000
index 000000000..0e14ac415
--- /dev/null
+++ b/icefall/shared/ngram_entropy_pruning.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/shared/ngram_entropy_pruning.py
\ No newline at end of file
diff --git a/icefall/shared/parse_options.sh b/icefall/shared/parse_options.sh
deleted file mode 100755
index 71fb9e5ea..000000000
--- a/icefall/shared/parse_options.sh
+++ /dev/null
@@ -1,97 +0,0 @@
-#!/usr/bin/env bash
-
-# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
-# Arnab Ghoshal, Karel Vesely
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
-# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
-# MERCHANTABLITY OR NON-INFRINGEMENT.
-# See the Apache 2 License for the specific language governing permissions and
-# limitations under the License.
-
-
-# Parse command-line options.
-# To be sourced by another script (as in ". parse_options.sh").
-# Option format is: --option-name arg
-# and shell variable "option_name" gets set to value "arg."
-# The exception is --help, which takes no arguments, but prints the
-# $help_message variable (if defined).
-
-
-###
-### The --config file options have lower priority to command line
-### options, so we need to import them first...
-###
-
-# Now import all the configs specified by command-line, in left-to-right order
-for ((argpos=1; argpos<$#; argpos++)); do
- if [ "${!argpos}" == "--config" ]; then
- argpos_plus1=$((argpos+1))
- config=${!argpos_plus1}
- [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
- . $config # source the config file.
- fi
-done
-
-
-###
-### Now we process the command line options
-###
-while true; do
- [ -z "${1:-}" ] && break; # break if there are no arguments
- case "$1" in
- # If the enclosing script is called with --help option, print the help
- # message and exit. Scripts should put help messages in $help_message
- --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
- else printf "$help_message\n" 1>&2 ; fi;
- exit 0 ;;
- --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
- exit 1 ;;
- # If the first command-line argument begins with "--" (e.g. --foo-bar),
- # then work out the variable name as $name, which will equal "foo_bar".
- --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
-      # Next we test whether the variable in question is undefined-- if so it's
- # an invalid option and we die. Note: $0 evaluates to the name of the
- # enclosing script.
- # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
- # is undefined. We then have to wrap this test inside "eval" because
- # foo_bar is itself inside a variable ($name).
- eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
-
- oldval="`eval echo \\$$name`";
- # Work out whether we seem to be expecting a Boolean argument.
- if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
- was_bool=true;
- else
- was_bool=false;
- fi
-
- # Set the variable to the right value-- the escaped quotes make it work if
- # the option had spaces, like --cmd "queue.pl -sync y"
- eval $name=\"$2\";
-
- # Check that Boolean-valued arguments are really Boolean.
- if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
- echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
- exit 1;
- fi
- shift 2;
- ;;
- *) break;
- esac
-done
-
-
-# Check for an empty argument to the --cmd option, which can easily occur as a
-# result of scripting errors.
-[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
-
-
-true; # so this script returns exit code 0.
diff --git a/icefall/shared/parse_options.sh b/icefall/shared/parse_options.sh
new file mode 120000
index 000000000..e4665e7de
--- /dev/null
+++ b/icefall/shared/parse_options.sh
@@ -0,0 +1 @@
+../../../librispeech/ASR/shared/parse_options.sh
\ No newline at end of file
From 6d275ddf9fdd67a32b79b93d70fedffe4b156d5c Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Fri, 10 Nov 2023 14:45:16 +0800
Subject: [PATCH 002/123] fixed broken softlinks (#1381)
* removed broken softlinks
* fixed dependencies
* fixed file permission
---
icefall/shared/convert-k2-to-openfst.py | 103 +++-
icefall/shared/ngram_entropy_pruning.py | 631 +++++++++++++++++++++++-
icefall/shared/parse_options.sh | 98 +++-
3 files changed, 829 insertions(+), 3 deletions(-)
mode change 120000 => 100755 icefall/shared/convert-k2-to-openfst.py
mode change 120000 => 100755 icefall/shared/ngram_entropy_pruning.py
mode change 120000 => 100755 icefall/shared/parse_options.sh
diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py
deleted file mode 120000
index 24efe5eae..000000000
--- a/icefall/shared/convert-k2-to-openfst.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/shared/convert-k2-to-openfst.py
\ No newline at end of file
diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py
new file mode 100755
index 000000000..29a2cd7f7
--- /dev/null
+++ b/icefall/shared/convert-k2-to-openfst.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script takes as input an FST in k2 format and converts it
+to an FST in OpenFST format.
+
+The generated FST is saved into a binary file and its type is
+StdVectorFst.
+
+Usage examples:
+(1) Convert an acceptor
+
+ ./convert-k2-to-openfst.py in.pt binary.fst
+
+(2) Convert a transducer
+
+ ./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import kaldifst.utils
+import torch
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--olabels",
+ type=str,
+ default=None,
+ help="""If not empty, the input FST is assumed to be a transducer
+ and we use its attribute specified by "olabels" as the output labels.
+ """,
+ )
+ parser.add_argument(
+ "input_filename",
+ type=str,
+ help="Path to the input FST in k2 format",
+ )
+
+ parser.add_argument(
+ "output_filename",
+ type=str,
+ help="Path to the output FST in OpenFst format",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+ logging.info(f"{vars(args)}")
+
+ input_filename = args.input_filename
+ output_filename = args.output_filename
+ olabels = args.olabels
+
+ if Path(output_filename).is_file():
+ logging.info(f"{output_filename} already exists - skipping")
+ return
+
+ assert Path(input_filename).is_file(), f"{input_filename} does not exist"
+ logging.info(f"Loading {input_filename}")
+ k2_fst = k2.Fsa.from_dict(torch.load(input_filename))
+ if olabels:
+ assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}"
+
+ p = Path(output_filename).parent
+ if not p.is_dir():
+ logging.info(f"Creating {p}")
+ p.mkdir(parents=True)
+
+ logging.info("Converting (May take some time if the input FST is large)")
+ fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels)
+ logging.info(f"Saving to {output_filename}")
+ fst.write(output_filename)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py
deleted file mode 120000
index 0e14ac415..000000000
--- a/icefall/shared/ngram_entropy_pruning.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/shared/ngram_entropy_pruning.py
\ No newline at end of file
diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py
new file mode 100755
index 000000000..b1ebee9ea
--- /dev/null
+++ b/icefall/shared/ngram_entropy_pruning.py
@@ -0,0 +1,630 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+./ngram_entropy_pruning.py \
+ -threshold 1e-8 \
+ -lm download/lm/4gram.arpa \
+ -write-lm download/lm/4gram_pruned_1e8.arpa
+
+This file is from Kaldi `egs/wsj/s5/utils/lang/ngram_entropy_pruning.py`.
+This is an implementation of ``Entropy-based Pruning of Backoff Language Models''
+in the same way as SRILM.
+"""
+
+
+import argparse
+import gzip
+import logging
+import math
+import re
+from collections import OrderedDict, defaultdict
+from enum import Enum, unique
+from io import StringIO
+
+parser = argparse.ArgumentParser(
+ description="""
+ Prune an n-gram language model based on the relative entropy
+ between the original and the pruned model, based on Andreas Stolcke's paper.
+    An n-gram entry is removed if the removal causes the (training set) perplexity
+    of the model to increase by less than the threshold, relative to the original model.
+
+ The command takes an arpa file and a pruning threshold as input,
+ and outputs a pruned arpa file.
+ """
+)
+parser.add_argument("-threshold", type=float, default=1e-6, help="Order of n-gram")
+parser.add_argument("-lm", type=str, default=None, help="Path to the input arpa file")
+parser.add_argument(
+ "-write-lm", type=str, default=None, help="Path to output arpa file after pruning"
+)
+parser.add_argument(
+ "-minorder",
+ type=int,
+ default=1,
+ help="The minorder parameter limits pruning to ngrams of that length and above.",
+)
+parser.add_argument(
+ "-encoding", type=str, default="utf-8", help="Encoding of the arpa file"
+)
+parser.add_argument(
+ "-verbose",
+ type=int,
+ default=2,
+ choices=[0, 1, 2, 3, 4, 5],
+ help="Verbose level, where 0 is most noisy; 5 is most silent",
+)
+args = parser.parse_args()
+
+default_encoding = args.encoding
+logging.basicConfig(
+ format="%(asctime)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s",
+ level=args.verbose * 10,
+)
+
+
+class Context(dict):
+ """
+ This class stores data for a context h.
+ It behaves like a python dict object, except that it has several
+ additional attributes.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.log_bo = None
+
+
+class Arpa:
+ """
+    This is a class that implements the data structure of an ARPA LM.
+ It (as well as some other classes) is modified based on the library
+ by Stefan Fischer:
+ https://github.com/sfischer13/python-arpa
+ """
+
+ UNK = ""
+ SOS = ""
+ EOS = ""
+ FLOAT_NDIGITS = 7
+ base = 10
+
+ @staticmethod
+ def _check_input(my_input):
+ if not my_input:
+ raise ValueError
+ elif isinstance(my_input, tuple):
+ return my_input
+ elif isinstance(my_input, list):
+ return tuple(my_input)
+ elif isinstance(my_input, str):
+ return tuple(my_input.strip().split(" "))
+ else:
+ raise ValueError
+
+ @staticmethod
+ def _check_word(input_word):
+ if not isinstance(input_word, str):
+ raise ValueError
+ if " " in input_word:
+ raise ValueError
+
+ def _replace_unks(self, words):
+ return tuple((w if w in self else self._unk) for w in words)
+
+ def __init__(self, path=None, encoding=None, unk=None):
+ self._counts = OrderedDict()
+ self._ngrams = (
+ OrderedDict()
+ ) # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w)
+ self._vocabulary = set()
+ if unk is None:
+ self._unk = self.UNK
+
+ if path is not None:
+ self.loadf(path, encoding)
+
+ def __contains__(self, ngram):
+ h = ngram[:-1] # h is a tuple
+ w = ngram[-1] # w is a string/word
+ return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]
+
+ def contains_word(self, word):
+ self._check_word(word)
+ return word in self._vocabulary
+
+ def add_count(self, order, count):
+ self._counts[order] = count
+ self._ngrams[order - 1] = defaultdict(Context)
+
+ def update_counts(self):
+ for order in range(1, self.order() + 1):
+ count = sum([len(wlist) for _, wlist in self._ngrams[order - 1].items()])
+ if count > 0:
+ self._counts[order] = count
+
+ def add_entry(self, ngram, p, bo=None, order=None):
+ # Note: ngram is a tuple of strings, e.g. ("w1", "w2", "w3")
+ h = ngram[:-1] # h is a tuple
+ w = ngram[-1] # w is a string/word
+
+ # Note that p and bo here are in fact in the log domain (self.base = 10)
+ h_context = self._ngrams[len(h)][h]
+ h_context[w] = p
+ if bo is not None:
+ self._ngrams[len(ngram)][ngram].log_bo = bo
+
+ for word in ngram:
+ self._vocabulary.add(word)
+
+ def counts(self):
+ return sorted(self._counts.items())
+
+ def order(self):
+ return max(self._counts.keys(), default=None)
+
+ def vocabulary(self, sort=True):
+ if sort:
+ return sorted(self._vocabulary)
+ else:
+ return self._vocabulary
+
+ def _entries(self, order):
+ return (
+ self._entry(h, w)
+ for h, wlist in self._ngrams[order - 1].items()
+ for w in wlist
+ )
+
+ def _entry(self, h, w):
+ # return the entry for the ngram (h, w)
+ ngram = h + (w,)
+ log_p = self._ngrams[len(h)][h][w]
+ log_bo = self._log_bo(ngram)
+ if log_bo is not None:
+ return (
+ round(log_p, self.FLOAT_NDIGITS),
+ ngram,
+ round(log_bo, self.FLOAT_NDIGITS),
+ )
+ else:
+ return round(log_p, self.FLOAT_NDIGITS), ngram
+
+ def _log_bo(self, ngram):
+ if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]:
+ return self._ngrams[len(ngram)][ngram].log_bo
+ else:
+ return None
+
+ def _log_p(self, ngram):
+ h = ngram[:-1] # h is a tuple
+ w = ngram[-1] # w is a string/word
+ if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]:
+ return self._ngrams[len(h)][h][w]
+ else:
+ return None
+
+ def log_p_raw(self, ngram):
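+        # Katz back-off: return the stored log p(w|h) if the ngram (h, w)
+        # exists; otherwise back off to the lower-order context h' (h without
+        # its first word):
+        #   log p(w|h) = log BOW(h) + log p(w|h'),
+        # with log BOW(h) = 0 when no back-off weight is stored for h.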
+ log_p = self._log_p(ngram)
+ if log_p is not None:
+ return log_p
+ else:
+ if len(ngram) == 1:
+ raise KeyError
+ else:
+ log_bo = self._log_bo(ngram[:-1])
+ if log_bo is None:
+ log_bo = 0
+ return log_bo + self.log_p_raw(ngram[1:])
+
+ def log_joint_prob(self, sequence):
+ # Compute the joint prob of the sequence based on the chain rule
+ # Note that sequence should be a tuple of strings
+ #
+ # Reference:
+ # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527
+
+ log_joint_p = 0
+ seq = sequence
+ while len(seq) > 0:
+ log_joint_p += self.log_p_raw(seq)
+ seq = seq[:-1]
+
+            # If we're computing the marginal probability of the unigram
+            # context, we have to look up </s> instead since the former
+            # has prob = 0.
+ if len(seq) == 1 and seq[0] == self.SOS:
+ seq = (self.EOS,)
+
+ return log_joint_p
+
+ def set_new_context(self, h):
+ old_context = self._ngrams[len(h)][h]
+ self._ngrams[len(h)][h] = Context()
+ return old_context
+
+ def log_p(self, ngram):
+ words = self._check_input(ngram)
+ if self._unk:
+ words = self._replace_unks(words)
+ return self.log_p_raw(words)
+
+ def log_s(self, sentence, sos=SOS, eos=EOS):
+ words = self._check_input(sentence)
+ if self._unk:
+ words = self._replace_unks(words)
+ if sos:
+ words = (sos,) + words
+ if eos:
+ words = words + (eos,)
+ result = sum(self.log_p_raw(words[:i]) for i in range(1, len(words) + 1))
+ if sos:
+ result = result - self.log_p_raw(words[:1])
+ return result
+
+ def p(self, ngram):
+ return self.base ** self.log_p(ngram)
+
+ def s(self, sentence):
+ return self.base ** self.log_s(sentence)
+
+ def write(self, fp):
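+        # Writes the model in the standard ARPA text format, roughly:
+        #
+        #   \data\
+        #   ngram 1=<count>
+        #   ngram 2=<count>
+        #
+        #   \1-grams:
+        #   <log10 prob>\t<ngram words>[\t<log10 backoff>]
+        #   ...
+        #
+        #   \end\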
+ fp.write("\n\\data\\\n")
+ for order, count in self.counts():
+ fp.write("ngram {}={}\n".format(order, count))
+ fp.write("\n")
+ for order, _ in self.counts():
+ fp.write("\\{}-grams:\n".format(order))
+ for e in self._entries(order):
+ prob = e[0]
+ ngram = " ".join(e[1])
+ if len(e) == 2:
+ fp.write("{}\t{}\n".format(prob, ngram))
+ elif len(e) == 3:
+ backoff = e[2]
+ fp.write("{}\t{}\t{}\n".format(prob, ngram, backoff))
+ else:
+ raise ValueError
+ fp.write("\n")
+ fp.write("\\end\\\n")
+
+
+class ArpaParser:
+ """
+    This is a class that implements a parser of an ARPA file.
+ """
+
+ @unique
+ class State(Enum):
+ DATA = 1
+ COUNT = 2
+ HEADER = 3
+ ENTRY = 4
+
+ re_count = re.compile(r"^ngram (\d+)=(\d+)$")
+ re_header = re.compile(r"^\\(\d+)-grams:$")
+ re_entry = re.compile(
+ "^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)"
+ "\t"
+ "(\\S+( \\S+)*)"
+ "(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$"
+ )
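+    # re_count matches "ngram <order>=<count>", re_header matches "\<order>-grams:",
+    # and re_entry matches "<log10 prob>\t<ngram words>[\t<log10 backoff>]".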
+
+ def _parse(self, fp):
+ self._result = []
+ self._state = self.State.DATA
+ self._tmp_model = None
+ self._tmp_order = None
+ for line in fp:
+ line = line.strip()
+ if self._state == self.State.DATA:
+ self._data(line)
+ elif self._state == self.State.COUNT:
+ self._count(line)
+ elif self._state == self.State.HEADER:
+ self._header(line)
+ elif self._state == self.State.ENTRY:
+ self._entry(line)
+ if self._state != self.State.DATA:
+ raise Exception(line)
+ return self._result
+
+ def _data(self, line):
+ if line == "\\data\\":
+ self._state = self.State.COUNT
+ self._tmp_model = Arpa()
+ else:
+ pass # skip comment line
+
+ def _count(self, line):
+ match = self.re_count.match(line)
+ if match:
+ order = match.group(1)
+ count = match.group(2)
+ self._tmp_model.add_count(int(order), int(count))
+ elif not line:
+ self._state = self.State.HEADER # there are no counts
+ else:
+ raise Exception(line)
+
+ def _header(self, line):
+ match = self.re_header.match(line)
+ if match:
+ self._state = self.State.ENTRY
+ self._tmp_order = int(match.group(1))
+ elif line == "\\end\\":
+ self._result.append(self._tmp_model)
+ self._state = self.State.DATA
+ self._tmp_model = None
+ self._tmp_order = None
+ elif not line:
+ pass # skip empty line
+ else:
+ raise Exception(line)
+
+ def _entry(self, line):
+ match = self.re_entry.match(line)
+ if match:
+ p = self._float_or_int(match.group(1))
+ ngram = tuple(match.group(4).split(" "))
+ bo_match = match.group(7)
+ bo = self._float_or_int(bo_match) if bo_match else None
+ self._tmp_model.add_entry(ngram, p, bo, self._tmp_order)
+ elif not line:
+ self._state = self.State.HEADER # last entry
+ else:
+ raise Exception(line)
+
+ @staticmethod
+ def _float_or_int(s):
+ f = float(s)
+ i = int(f)
+ if str(i) == s: # don't drop trailing ".0"
+ return i
+ else:
+ return f
+
+ def load(self, fp):
+ """Deserialize fp (a file-like object) to a Python object."""
+ return self._parse(fp)
+
+ def loadf(self, path, encoding=None):
+ """Deserialize path (.arpa, .gz) to a Python object."""
+ path = str(path)
+ if path.endswith(".gz"):
+ with gzip.open(path, mode="rt", encoding=encoding) as f:
+ return self.load(f)
+ else:
+ with open(path, mode="rt", encoding=encoding) as f:
+ return self.load(f)
+
+ def loads(self, s):
+ """Deserialize s (a str) to a Python object."""
+ with StringIO(s) as f:
+ return self.load(f)
+
+ def dump(self, obj, fp):
+ """Serialize obj to fp (a file-like object) in ARPA format."""
+ obj.write(fp)
+
+ def dumpf(self, obj, path, encoding=None):
+ """Serialize obj to path in ARPA format (.arpa, .gz)."""
+ path = str(path)
+ if path.endswith(".gz"):
+ with gzip.open(path, mode="wt", encoding=encoding) as f:
+ return self.dump(obj, f)
+ else:
+ with open(path, mode="wt", encoding=encoding) as f:
+ self.dump(obj, f)
+
+ def dumps(self, obj):
+ """Serialize obj to an ARPA formatted str."""
+ with StringIO() as f:
+ self.dump(obj, f)
+ return f.getvalue()
+
+
+def add_log_p(prev_log_sum, log_p, base):
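+    # Sum two probabilities that are given in the log domain:
+    # log_base(base**prev_log_sum + base**log_p).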
+ return math.log(base**log_p + base**prev_log_sum, base)
+
+
+def compute_numerator_denominator(lm, h):
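+    # Numerator / denominator of the back-off weight BOW(h):
+    #   numerator   = 1 - sum of p(w|h)  over all words w seen after context h
+    #   denominator = 1 - sum of p(w|h') over the same words, estimated from
+    #                 the lower-order context h'
+    # so that BOW(h) = numerator / denominator.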
+ log_sum_seen_h = -math.inf
+ log_sum_seen_h_lower = -math.inf
+ base = lm.base
+ for w, log_p in lm._ngrams[len(h)][h].items():
+ log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base)
+
+ ngram = h + (w,)
+ log_p_lower = lm.log_p_raw(ngram[1:])
+ log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, base)
+
+ numerator = 1.0 - base**log_sum_seen_h
+ denominator = 1.0 - base**log_sum_seen_h_lower
+ return numerator, denominator
+
+
+def prune(lm, threshold, minorder):
+ # Reference:
+ # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330
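+    #
+    # The loop below implements Stolcke's pruning criterion: for each ngram
+    # (h, w), with BOW'(h) the back-off weight recomputed as if (h, w) were
+    # removed,
+    #   delta_logp    = log[BOW'(h) * p(w|h')] - log p(w|h)
+    #   delta_entropy = -P(h) * [ p(w|h) * delta_logp
+    #                   + (1 - sum_seen p(.|h)) * (log BOW'(h) - log BOW(h)) ]
+    # and the ngram is pruned (when threshold > 0) if
+    # base**delta_entropy - 1 < threshold.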
+
+ for i in range(
+ lm.order(), max(minorder - 1, 1), -1
+ ): # i is the order of the ngram (h, w)
+ logging.info("processing %d-grams ..." % i)
+ count_pruned_ngrams = 0
+
+ h_dict = lm._ngrams[i - 1]
+ for h in list(h_dict.keys()):
+ # old backoff weight, BOW(h)
+ log_bow = lm._log_bo(h)
+ if log_bow is None:
+ log_bow = 0
+
+ # Compute numerator and denominator of the backoff weight,
+ # so that we can quickly compute the BOW adjustment due to
+ # leaving out one prob.
+ numerator, denominator = compute_numerator_denominator(lm, h)
+
+ # assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5
+
+ # Compute the marginal probability of the context, P(h)
+ h_log_p = lm.log_joint_prob(h)
+
+ all_pruned = True
+ pruned_w_set = set()
+
+ for w, log_p in h_dict[h].items():
+ ngram = h + (w,)
+
+ # lower-order estimate for ngramProb, P(w|h')
+ backoff_prob = lm.log_p_raw(ngram[1:])
+
+ # Compute BOW after removing ngram, BOW'(h)
+ new_log_bow = math.log(
+ numerator + lm.base**log_p, lm.base
+ ) - math.log(denominator + lm.base**backoff_prob, lm.base)
+
+ # Compute change in entropy due to removal of ngram
+ delta_prob = backoff_prob + new_log_bow - log_p
+ delta_entropy = -(lm.base**h_log_p) * (
+ (lm.base**log_p) * delta_prob
+ + numerator * (new_log_bow - log_bow)
+ )
+
+ # compute relative change in model (training set) perplexity
+ perp_change = lm.base**delta_entropy - 1.0
+
+ pruned = threshold > 0 and perp_change < threshold
+
+ # Make sure we don't prune ngrams whose backoff nodes are needed
+ if (
+ pruned
+ and len(ngram) in lm._ngrams
+ and len(lm._ngrams[len(ngram)][ngram]) > 0
+ ):
+ pruned = False
+
+ logging.debug(
+ "CONTEXT "
+ + str(h)
+ + " WORD "
+ + w
+ + " CONTEXTPROB %f " % h_log_p
+ + " OLDPROB %f " % log_p
+ + " NEWPROB %f " % (backoff_prob + new_log_bow)
+ + " DELTA-H %f " % delta_entropy
+ + " DELTA-LOGP %f " % delta_prob
+ + " PPL-CHANGE %f " % perp_change
+ + " PRUNED "
+ + str(pruned)
+ )
+
+ if pruned:
+ pruned_w_set.add(w)
+ count_pruned_ngrams += 1
+ else:
+ all_pruned = False
+
+ # If we removed all ngrams for this context we can
+ # remove the context itself, but only if the present
+ # context is not a prefix to a longer one.
+ if all_pruned and len(pruned_w_set) == len(h_dict[h]):
+ del h_dict[
+ h
+ ] # this context h is no longer needed, as its ngram prob is stored at its own context h'
+ elif len(pruned_w_set) > 0:
+ # The pruning for this context h is actually done here
+ old_context = lm.set_new_context(h)
+
+ for w, p_w in old_context.items():
+ if w not in pruned_w_set:
+ lm.add_entry(
+ h + (w,), p_w
+ ) # the entry hw is stored at the context h
+
+ # We need to recompute the back-off weight, but
+ # this can only be done after completing the pruning
+ # of the lower-order ngrams.
+ # Reference:
+ # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124
+
+ logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i))
+
+ # recompute backoff weights
+ for i in range(
+ max(minorder - 1, 1) + 1, lm.order() + 1
+ ): # be careful of this order: from low- to high-order
+ for h in lm._ngrams[i - 1]:
+ numerator, denominator = compute_numerator_denominator(lm, h)
+ new_log_bow = math.log(numerator, lm.base) - math.log(denominator, lm.base)
+ lm._ngrams[len(h)][h].log_bo = new_log_bow
+
+ # update counts
+ lm.update_counts()
+
+ return
+
+
+def check_h_is_valid(lm, h):
+ sum_under_h = sum(
+ [lm.base ** lm.log_p_raw(h + (w,)) for w in lm.vocabulary(sort=False)]
+ )
+ if abs(sum_under_h - 1.0) > 1e-6:
+ logging.info("warning: %s %f" % (str(h), sum_under_h))
+ return False
+ else:
+ return True
+
+
+def validate_lm(lm):
+ # sanity check if the conditional probability sums to one under each context h
+ for i in range(lm.order(), 0, -1): # i is the order of the ngram (h, w)
+ logging.info("validating %d-grams ..." % i)
+ h_dict = lm._ngrams[i - 1]
+ for h in h_dict.keys():
+ check_h_is_valid(lm, h)
+
+
+def compare_two_apras(path1, path2):
+ pass
+
+
+if __name__ == "__main__":
+ # load an arpa file
+ logging.info("Loading the arpa file from %s" % args.lm)
+ parser = ArpaParser()
+ models = parser.loadf(args.lm, encoding=default_encoding)
+ lm = models[0] # ARPA files may contain several models.
+ logging.info("Stats before pruning:")
+ for i, cnt in lm.counts():
+ logging.info("ngram %d=%d" % (i, cnt))
+
+ # prune it, the language model will be modified in-place
+ logging.info("Start pruning the model with threshold=%.3E..." % args.threshold)
+ prune(lm, args.threshold, args.minorder)
+
+ # validate_lm(lm)
+
+ # write the arpa language model to a file
+ logging.info("Stats after pruning:")
+ for i, cnt in lm.counts():
+ logging.info("ngram %d=%d" % (i, cnt))
+ logging.info("Saving the pruned arpa file to %s" % args.write_lm)
+ parser.dumpf(lm, args.write_lm, encoding=default_encoding)
+ logging.info("Done.")
diff --git a/icefall/shared/parse_options.sh b/icefall/shared/parse_options.sh
deleted file mode 120000
index e4665e7de..000000000
--- a/icefall/shared/parse_options.sh
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/shared/parse_options.sh
\ No newline at end of file
diff --git a/icefall/shared/parse_options.sh b/icefall/shared/parse_options.sh
new file mode 100755
index 000000000..71fb9e5ea
--- /dev/null
+++ b/icefall/shared/parse_options.sh
@@ -0,0 +1,97 @@
+#!/usr/bin/env bash
+
+# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
+# Arnab Ghoshal, Karel Vesely
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Parse command-line options.
+# To be sourced by another script (as in ". parse_options.sh").
+# Option format is: --option-name arg
+# and shell variable "option_name" gets set to value "arg."
+# The exception is --help, which takes no arguments, but prints the
+# $help_message variable (if defined).
+
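+# A minimal sketch of a calling script (the script name and variables here
+# are only illustrative):
+#
+#   #!/usr/bin/env bash
+#   stage=0          # default, can be overridden with --stage <n>
+#   use_gpu=false    # Boolean default; only "true"/"false" are accepted
+#   . shared/parse_options.sh || exit 1
+#   echo "stage=$stage use_gpu=$use_gpu"
+#
+# Running "./run.sh --stage 2 --use-gpu true" would print "stage=2 use_gpu=true".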
+
+###
+### The --config file options have lower priority to command line
+### options, so we need to import them first...
+###
+
+# Now import all the configs specified by command-line, in left-to-right order
+for ((argpos=1; argpos<$#; argpos++)); do
+ if [ "${!argpos}" == "--config" ]; then
+ argpos_plus1=$((argpos+1))
+ config=${!argpos_plus1}
+ [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
+ . $config # source the config file.
+ fi
+done
+
+
+###
+### Now we process the command line options
+###
+while true; do
+ [ -z "${1:-}" ] && break; # break if there are no arguments
+ case "$1" in
+ # If the enclosing script is called with --help option, print the help
+ # message and exit. Scripts should put help messages in $help_message
+ --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
+ else printf "$help_message\n" 1>&2 ; fi;
+ exit 0 ;;
+ --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
+ exit 1 ;;
+ # If the first command-line argument begins with "--" (e.g. --foo-bar),
+ # then work out the variable name as $name, which will equal "foo_bar".
+ --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
+      # Next we test whether the variable in question is undefined -- if so it's
+ # an invalid option and we die. Note: $0 evaluates to the name of the
+ # enclosing script.
+ # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
+ # is undefined. We then have to wrap this test inside "eval" because
+ # foo_bar is itself inside a variable ($name).
+ eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+
+ oldval="`eval echo \\$$name`";
+ # Work out whether we seem to be expecting a Boolean argument.
+ if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
+ was_bool=true;
+ else
+ was_bool=false;
+ fi
+
+ # Set the variable to the right value-- the escaped quotes make it work if
+ # the option had spaces, like --cmd "queue.pl -sync y"
+ eval $name=\"$2\";
+
+ # Check that Boolean-valued arguments are really Boolean.
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+ exit 1;
+ fi
+ shift 2;
+ ;;
+ *) break;
+ esac
+done
+
+
+# Check for an empty argument to the --cmd option, which can easily occur as a
+# result of scripting errors.
+[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
+
+
+true; # so this script returns exit code 0.
From 59c943878ff7f3d741a29d743b8560e342fa892d Mon Sep 17 00:00:00 2001
From: Karel Vesely
Date: Thu, 16 Nov 2023 07:38:31 +0100
Subject: [PATCH 003/123] add the `voxpopuli` recipe (#1374)
* add the `voxpopuli` recipe
- this is the data preparation
- there is no ASR training and no results
* update the PR#1374 (feedback from @csukuangfj)
- fixing .py headers and docstrings
- removing BUT specific parts of `prepare.sh`
- adding assert `num_jobs >= num_workers` to `compute_fbank.py`
- narrowing list of languages
(let's limit to ASR sets with transcripts for now)
- added links to `README.md`
- extending `text_from_manifest.py`
---
egs/voxpopuli/ASR/README.md | 38 +++
egs/voxpopuli/ASR/local/compute_fbank.py | 248 +++++++++++++++++
.../ASR/local/compute_fbank_musan.py | 1 +
.../ASR/local/display_manifest_statistics.py | 56 ++++
.../duration_from_supervision_manifest.py | 93 +++++++
egs/voxpopuli/ASR/local/filter_cuts.py | 1 +
egs/voxpopuli/ASR/local/prepare_lang_bpe.py | 1 +
.../ASR/local/preprocess_voxpopuli.py | 178 ++++++++++++
.../ASR/local/separate_punctuation.py | 130 +++++++++
egs/voxpopuli/ASR/local/text_from_manifest.py | 54 ++++
egs/voxpopuli/ASR/local/train_bpe_model.py | 1 +
.../ASR/local/uppercase_begin_of_sentence.py | 113 ++++++++
.../ASR/local/validate_bpe_lexicon.py | 1 +
.../ASR/local/validate_cutset_manifest.py | 123 +++++++++
egs/voxpopuli/ASR/prepare.sh | 257 ++++++++++++++++++
egs/voxpopuli/ASR/shared | 1 +
16 files changed, 1296 insertions(+)
create mode 100644 egs/voxpopuli/ASR/README.md
create mode 100755 egs/voxpopuli/ASR/local/compute_fbank.py
create mode 120000 egs/voxpopuli/ASR/local/compute_fbank_musan.py
create mode 100755 egs/voxpopuli/ASR/local/display_manifest_statistics.py
create mode 100755 egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py
create mode 120000 egs/voxpopuli/ASR/local/filter_cuts.py
create mode 120000 egs/voxpopuli/ASR/local/prepare_lang_bpe.py
create mode 100755 egs/voxpopuli/ASR/local/preprocess_voxpopuli.py
create mode 100755 egs/voxpopuli/ASR/local/separate_punctuation.py
create mode 100755 egs/voxpopuli/ASR/local/text_from_manifest.py
create mode 120000 egs/voxpopuli/ASR/local/train_bpe_model.py
create mode 100755 egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py
create mode 120000 egs/voxpopuli/ASR/local/validate_bpe_lexicon.py
create mode 100755 egs/voxpopuli/ASR/local/validate_cutset_manifest.py
create mode 100755 egs/voxpopuli/ASR/prepare.sh
create mode 120000 egs/voxpopuli/ASR/shared
diff --git a/egs/voxpopuli/ASR/README.md b/egs/voxpopuli/ASR/README.md
new file mode 100644
index 000000000..92aa26464
--- /dev/null
+++ b/egs/voxpopuli/ASR/README.md
@@ -0,0 +1,38 @@
+# Readme
+
+This recipe contains data preparation for the
+[VoxPopuli](https://github.com/facebookresearch/voxpopuli) dataset
+[(pdf)](https://aclanthology.org/2021.acl-long.80.pdf).
+At the moment, it contains no model training and no results.
+
+
+## Audio per language
+
+| language | Size | Hrs. untranscribed | Hrs. transcribed |
+|----------|--------|--------------------|------------------|
+| bg | 295G | 17.6K | - |
+| cs | 308G | 18.7K | 62 |
+| da | 233G | 13.6K | - |
+| de | 379G | 23.2K | 282 |
+| el | 305G | 17.7K | - |
+| en | 382G | 24.1K | 543 |
+| es | 362G | 21.4K | 166 |
+| et | 179G | 10.6K | 3 |
+| fi | 236G | 14.2K | 27 |
+| fr | 376G | 22.8K | 211 |
+| hr | 132G | 8.1K | 43 |
+| hu | 297G | 17.7K | 63 |
+| it | 361G | 21.9K | 91 |
+| lt | 243G | 14.4K | 2 |
+| lv | 217G | 13.1K | - |
+| mt | 147G | 9.1K | - |
+| nl | 322G | 19.0K | 53 |
+| pl | 348G | 21.2K | 111 |
+| pt | 300G | 17.5K | - |
+| ro | 296G | 17.9K | 89 |
+| sk | 201G | 12.1K | 35 |
+| sl | 190G | 11.3K | 10 |
+| sv | 272G | 16.3K | - |
+| | | | |
+| total | 6.3T | 384K | 1791 |
+
diff --git a/egs/voxpopuli/ASR/local/compute_fbank.py b/egs/voxpopuli/ASR/local/compute_fbank.py
new file mode 100755
index 000000000..b63e51f29
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/compute_fbank.py
@@ -0,0 +1,248 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# 2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of VoxPopuli dataset.
+
+Usage example:
+
+ python3 ./local/compute_fbank.py \
+ --src-dir data/fbank --output-dir data/fbank \
+ --num-jobs 100 --num-workers 25 \
+ --prefix "voxpopuli-${task}-${lang}" \
+ --dataset train \
+ --trim-to-supervisions True \
+ --speed-perturb True
+
+It looks for raw CutSet in the directory data/fbank
+located at: `{src_dir}/{prefix}_cuts_{dataset}_raw.jsonl.gz`.
+
+The generated fbank features are saved in `data/fbank/{prefix}-{dataset}_feats`
+and CutSet manifest stored in `data/fbank/{prefix}_cuts_{dataset}.jsonl.gz`.
+
+Typically, the number of workers is smaller than the number of jobs
+(see --num-jobs 100 --num-workers 25 in the example).
+The number of jobs must be at least the number of workers (this is checked).
+"""
+
+import argparse
+import logging
+import multiprocessing
+import os
+from concurrent.futures import ProcessPoolExecutor
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from filter_cuts import filter_cuts
+from lhotse import (
+ CutSet,
+ Fbank,
+ FbankConfig,
+ LilcomChunkyWriter,
+ is_caching_enabled,
+ set_caching_enabled,
+)
+
+from icefall.utils import str2bool
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking main() (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ help="""Path to the bpe.model. If not None, we will remove short and
+ long utterances before extracting features""",
+ )
+ parser.add_argument(
+ "--src-dir",
+ type=str,
+ help="""Folder with the input manifest files.""",
+ default="data/manifests",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ help="""Folder with the output manifests (cuts) and feature files.""",
+ default="data/fbank",
+ )
+
+ parser.add_argument(
+ "--prefix",
+ type=str,
+ help="""Prefix of the manifest files.""",
+ default="",
+ )
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ help="""Dataset parts to compute fbank (train,test,dev).""",
+ default=None,
+ )
+
+ parser.add_argument(
+ "--num-jobs",
+ type=int,
+ help="""Number of jobs (i.e. files with extracted features)""",
+ default=50,
+ )
+ parser.add_argument(
+ "--num-workers",
+ type=int,
+ help="""Number of parallel workers""",
+ default=10,
+ )
+ parser.add_argument(
+ "--speed-perturb",
+ type=str2bool,
+ default=False,
+ help="""Enable speed perturbation for the set.""",
+ )
+ parser.add_argument(
+ "--trim-to-supervisions",
+ type=str2bool,
+ default=False,
+ help="""Apply `trim-to-supervision` to cut set.""",
+ )
+
+ return parser.parse_args()
+
+
+def compute_fbank_features(args: argparse.Namespace):
+ set_caching_enabled(True) # lhotse
+
+ src_dir = Path(args.src_dir)
+ output_dir = Path(args.output_dir)
+ num_jobs = args.num_jobs
+ num_workers = min(args.num_workers, os.cpu_count())
+ num_mel_bins = 80
+
+ bpe_model = args.bpe_model
+ if bpe_model:
+ logging.info(f"Loading {bpe_model}")
+ sp = spm.SentencePieceProcessor()
+ sp.load(bpe_model)
+
+ prefix = args.prefix # "ELEF_TRAIN"
+ dataset = args.dataset
+ suffix = "jsonl.gz"
+
+ cuts_raw_filename = Path(f"{src_dir}/{prefix}_cuts_{dataset}_raw.{suffix}")
+ cuts_raw = CutSet.from_file(cuts_raw_filename)
+
+ extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+ cuts_filename = Path(f"{prefix}_cuts_{dataset}.{suffix}")
+ if (output_dir / cuts_filename).is_file():
+ logging.info(f"{output_dir/cuts_filename} already exists - skipping.")
+ return
+
+ logging.info(f"Processing {output_dir/cuts_filename}")
+ cut_set = cuts_raw
+
+ if bpe_model:
+ cut_set = filter_cuts(cut_set, sp)
+
+ if args.speed_perturb:
+ cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+
+ if args.trim_to_supervisions:
+ logging.info(f"About to `trim_to_supervisions()` {output_dir / cuts_filename}")
+ cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
+ else:
+ logging.info(
+ "Not doing `trim_to_supervisions()`, "
+ "to enable use --trim-to-supervision=True"
+ )
+
+ cut_set = cut_set.to_eager() # disallow lazy evaluation (sorting requires it)
+ cut_set = cut_set.sort_by_recording_id() # enhances AudioCache hit rate
+
+ # We typically use `num_jobs=100, num_workers=20`
+ # - this is helpful for large databases
+ # - both values are configurable externally
+ assert num_jobs >= num_workers, (num_jobs, num_workers)
+ executor = ProcessPoolExecutor(
+ max_workers=num_workers,
+ mp_context=multiprocessing.get_context("spawn"),
+ initializer=set_caching_enabled,
+ initargs=(is_caching_enabled(),),
+ )
+
+ logging.info(
+ f"executor {executor} : num_workers {num_workers}, num_jobs {num_jobs}"
+ )
+
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir / prefix}-{dataset}_feats",
+ num_jobs=num_jobs,
+ executor=executor,
+ storage_type=LilcomChunkyWriter,
+ )
+
+ # correct small deviations of duration, caused by speed-perturbation
+ for cut in cut_set:
+ assert len(cut.supervisions) == 1, (len(cut.supervisions), cut.id)
+ duration_difference = abs(cut.supervisions[0].duration - cut.duration)
+ tolerance = 0.02 # 20ms
+ if duration_difference == 0.0:
+ pass
+ elif duration_difference <= tolerance:
+ logging.info(
+ "small mismatch of the supervision duration "
+ f"(Δt = {duration_difference*1000}ms), "
+ f"correcting : cut.duration {cut.duration} -> "
+ f"supervision {cut.supervisions[0].duration}"
+ )
+ cut.supervisions[0].duration = cut.duration
+ else:
+ logging.error(
+ "mismatch of cut/supervision duration "
+ f"(Δt = {duration_difference*1000}ms) : "
+ f"cut.duration {cut.duration}, "
+ f"supervision {cut.supervisions[0].duration}"
+ )
+ raise ValueError(
+ "mismatch of cut/supervision duration "
+ f"(Δt = {duration_difference*1000}ms)"
+ )
+
+ # store the cutset
+ logging.info(f"storing CutSet to : `{output_dir / cuts_filename}`")
+ cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ args = get_args()
+ logging.info(vars(args))
+
+ compute_fbank_features(args)
diff --git a/egs/voxpopuli/ASR/local/compute_fbank_musan.py b/egs/voxpopuli/ASR/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/voxpopuli/ASR/local/display_manifest_statistics.py b/egs/voxpopuli/ASR/local/display_manifest_statistics.py
new file mode 100755
index 000000000..36c99e126
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# 2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file displays duration statistics of utterances in a manifest.
+You can use the displayed values to choose the minimum/maximum duration
+for removing short and long utterances during training.
+
+Usage example:
+ python3 ./local/display_manifest_statistics.py data/fbank/*_cuts*.jsonl.gz
+
+See the function `remove_short_and_long_utt()` in transducer/train.py
+for usage.
+
+"""
+
+import argparse
+
+from lhotse import load_manifest_lazy
+
+
+def get_args():
+ parser = argparse.ArgumentParser("Compute statistics for 'cuts' .jsonl.gz")
+
+ parser.add_argument(
+ "filename",
+ help="data/fbank/imported_cuts_bison-train_trim.jsonl.gz",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+
+ cuts = load_manifest_lazy(args.filename)
+ cuts.describe()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py b/egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py
new file mode 100755
index 000000000..957267fe8
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py
@@ -0,0 +1,93 @@
+#!/usr/bin/env python3
+# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script computes durations of datasets from
+the SupervisionSet manifests.
+
+Usage example:
+
+ python3 ./local/duration_from_supervision_manifest.py \
+        data/manifests/*_supervisions*.jsonl.gz
+"""
+
+import argparse
+import gzip
+import json
+import logging
+import re
+import sys
+
+
+def get_args():
+ parser = argparse.ArgumentParser(
+ "Read the raw text from the 'supervisions.jsonl.gz'"
+ )
+
+ parser.add_argument(
+ "filename",
+ help="supervisions.jsonl.gz",
+ nargs="+",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+ logging.info(vars(args))
+
+ total_duration = 0.0
+ total_n_utts = 0
+
+ for fname in args.filename:
+ if fname == "-":
+ fd = sys.stdin
+ elif re.match(r".*\.jsonl\.gz$", fname):
+ fd = gzip.open(fname, mode="r")
+ else:
+ fd = open(fname, mode="r")
+
+ fname_duration = 0.0
+ n_utts = 0
+ for line in fd:
+ js = json.loads(line)
+ fname_duration += js["duration"]
+ n_utts += 1
+
+ print(
+ f"Duration: {fname_duration/3600:7.2f} hours "
+ f"(eq. {fname_duration:7.0f} seconds, {n_utts} utts): {fname}"
+ )
+
+ if fd != sys.stdin:
+ fd.close()
+
+ total_duration += fname_duration
+ total_n_utts += n_utts
+
+ print(
+ f"Total duration: {total_duration/3600:7.2f} hours "
+ f"(eq. {total_duration:7.0f} seconds)"
+ )
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
diff --git a/egs/voxpopuli/ASR/local/filter_cuts.py b/egs/voxpopuli/ASR/local/filter_cuts.py
new file mode 120000
index 000000000..27aca1729
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/filter_cuts.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/filter_cuts.py
\ No newline at end of file
diff --git a/egs/voxpopuli/ASR/local/prepare_lang_bpe.py b/egs/voxpopuli/ASR/local/prepare_lang_bpe.py
new file mode 120000
index 000000000..36b40e7fc
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/prepare_lang_bpe.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_bpe.py
\ No newline at end of file
diff --git a/egs/voxpopuli/ASR/local/preprocess_voxpopuli.py b/egs/voxpopuli/ASR/local/preprocess_voxpopuli.py
new file mode 100755
index 000000000..4032537db
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/preprocess_voxpopuli.py
@@ -0,0 +1,178 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang)
+# 2023 Brno University of Technology (author: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Preprocess the database.
+- Convert RecordingSet and SupervisionSet to CutSet.
+- Apply text normalization to the transcripts.
+  - We take the re-normalized `orig_text` as the `text` transcripts.
+  - The text normalization separates punctuation from words.
+  - It also capitalizes the first letter of each sentence.
+
+The script is inspired by:
+ `egs/commonvoice/ASR/local/preprocess_commonvoice.py`
+
+Usage example:
+ python3 ./local/preprocess_voxpopuli.py \
+ --task asr --lang en
+
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Optional
+
+from lhotse import CutSet
+from lhotse.recipes.utils import read_manifests_if_cached
+
+# from local/
+from separate_punctuation import separate_punctuation
+from uppercase_begin_of_sentence import UpperCaseBeginOfSentence
+
+from icefall.utils import str2bool
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ help="""Dataset parts to compute fbank. If None, we will use all""",
+ default=None,
+ )
+
+ parser.add_argument(
+ "--task",
+ type=str,
+ help="""Task of VoxPopuli""",
+ default="asr",
+ )
+
+ parser.add_argument(
+ "--lang",
+ type=str,
+ help="""Language of VoxPopuli""",
+ required=True,
+ )
+
+ parser.add_argument(
+ "--use-original-text",
+ type=str2bool,
+ help="""Use 'original_text' from the annoattaion file,
+ otherwise 'normed_text' will be used
+ (see `data/manifests/${task}_${lang}.tsv.gz`).
+ """,
+ default=False,
+ )
+
+ return parser.parse_args()
+
+
+def normalize_text(utt: str) -> str:
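+    """
+    Separate punctuation from words and put a capital letter at the beginning
+    of each sentence, e.g.
+        "this is fine. Yes, you are right."
+    becomes
+        "This is fine . Yes , you are right ."
+    """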
+ utt = UpperCaseBeginOfSentence().process_line_text(separate_punctuation(utt))
+ return utt
+
+
+def preprocess_voxpopuli(
+ task: str,
+ language: str,
+ dataset: Optional[str] = None,
+ use_original_text: bool = False,
+):
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/fbank")
+ output_dir.mkdir(exist_ok=True)
+
+ if dataset is None:
+ dataset_parts = (
+ "dev",
+ "test",
+ "train",
+ )
+ else:
+ dataset_parts = dataset.split(" ", -1)
+
+ logging.info("Loading manifest")
+ prefix = f"voxpopuli-{task}-{language}"
+ suffix = "jsonl.gz"
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ suffix=suffix,
+ prefix=prefix,
+ )
+ assert manifests is not None
+
+ assert len(manifests) == len(dataset_parts), (
+ len(manifests),
+ len(dataset_parts),
+ list(manifests.keys()),
+ dataset_parts,
+ )
+
+ for partition, m in manifests.items():
+ logging.info(f"Processing {partition}")
+ raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}"
+ if raw_cuts_path.is_file():
+ logging.info(f"{partition} already exists - skipping")
+ continue
+
+ if use_original_text:
+ logging.info("Using 'original_text' from the annotation file.")
+ logging.info(f"Normalizing text in {partition}")
+ for sup in m["supervisions"]:
+ # `orig_text` includes punctuation and true-case
+ orig_text = str(sup.custom["orig_text"])
+ # we replace `text` by normalized `orig_text`
+ sup.text = normalize_text(orig_text)
+ else:
+ logging.info("Using 'normed_text' from the annotation file.")
+
+ # remove supervisions with empty 'text'
+ m["supervisions"] = m["supervisions"].filter(lambda sup: len(sup.text) > 0)
+
+ # Create cut manifest with long-recordings.
+ cut_set = CutSet.from_manifests(
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
+ ).resample(16000)
+
+ # Store the cut set incl. the resampling.
+ logging.info(f"Saving to {raw_cuts_path}")
+ cut_set.to_file(raw_cuts_path)
+
+
+def main():
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ args = get_args()
+ logging.info(vars(args))
+ preprocess_voxpopuli(
+ task=args.task,
+ language=args.lang,
+ dataset=args.dataset,
+ use_original_text=args.use_original_text,
+ )
+ logging.info("Done")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/voxpopuli/ASR/local/separate_punctuation.py b/egs/voxpopuli/ASR/local/separate_punctuation.py
new file mode 100755
index 000000000..706d6fcd5
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/separate_punctuation.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script chops the punctuation as standalone tokens.
+Example:
+ input: "This is fine. Yes, you are right."
+ output: "This is fine . Yes , you are right ."
+
+The script also handles exceptions in a hard-coded fashion.
+
+(same functionality could be done with `nltk.tokenize.word_tokenize()`,
+ but that would be an extra dependency)
+
+It can be used as a module, or as an executable script.
+
+Usage example #1:
+ `from separate_punctuation import separate_punctuation`
+
+Usage example #2:
+```
+ python3 ./local/separate_punctuation.py \
+ --ignore-columns 1 \
+ < ${kaldi_data}/text
+```
+"""
+
+import re
+import sys
+from argparse import ArgumentParser
+
+
+def separate_punctuation(text: str) -> str:
+ """
+ Text filtering function for separating punctuation.
+
+ Example:
+ input: "This is fine. Yes, you are right."
+ output: "This is fine . Yes , you are right ."
+
+ The exceptions for which the punctuation is
+    not split are hard-coded.
+ """
+
+ # remove non-desired punctuation symbols
+ text = re.sub('["„“«»]', "", text)
+
+ # separate [,.!?;] punctuation from words by space
+ text = re.sub(r"(\w)([,.!?;])", r"\1 \2", text)
+ text = re.sub(r"([,.!?;])(\w)", r"\1 \2", text)
+
+ # split to tokens
+ tokens = text.split()
+ tokens_out = []
+
+ # re-join the special cases of punctuation
+ for ii, tok in enumerate(tokens):
+ # no rewriting for 1st and last token
+ if ii > 0 and ii < len(tokens) - 1:
+ # **RULES ADDED FOR CZECH COMMON VOICE**
+
+ # fix "27 . dubna" -> "27. dubna", but keep punctuation separate,
+ if tok == "." and tokens[ii - 1].isdigit() and tokens[ii + 1].islower():
+ tokens_out[-1] = tokens_out[-1] + "."
+ continue
+
+ # fix "resp . pak" -> "resp. pak"
+ if tok == "." and tokens[ii - 1].isalpha() and tokens[ii + 1].islower():
+ tokens_out[-1] = tokens_out[-1] + "."
+ continue
+
+ # **RULES ADDED FOR ENGLISH COMMON VOICE**
+
+ # fix "A ." -> "A."
+ if tok == "." and re.match(r"^[A-Z]S", tokens[ii - 1]):
+ tokens_out[-1] = tokens_out[-1] + "."
+ continue
+
+ # fix "Mr ." -> "Mr."
+ exceptions = set(["Mr", "Mrs", "Ms"])
+ if tok == "." and tokens[ii - 1] in exceptions:
+ tokens_out[-1] = tokens_out[-1] + "."
+ continue
+
+ tokens_out.append(tok)
+
+ return " ".join(tokens_out)
+
+
+def get_args():
+ parser = ArgumentParser(
+ description="Separate punctuation from words: 'hello.' -> 'hello .'"
+ )
+ parser.add_argument(
+ "--ignore-columns", type=int, default=1, help="skip number of initial columns"
+ )
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+
+ max_split = args.ignore_columns
+
+ while True:
+ line = sys.stdin.readline()
+ if not line:
+ break
+
+ *key, text = line.strip().split(maxsplit=max_split)
+ text_norm = separate_punctuation(text)
+
+ print(" ".join(key), text_norm)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/voxpopuli/ASR/local/text_from_manifest.py b/egs/voxpopuli/ASR/local/text_from_manifest.py
new file mode 100755
index 000000000..d9ab53b5a
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/text_from_manifest.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Print the text contained in `supervisions.jsonl.gz` or `cuts.jsonl.gz`.
+
+Usage example:
+ python3 ./local/text_from_manifest.py \
+ data/manifests/voxpopuli-asr-en_supervisions_dev.jsonl.gz
+"""
+
+import argparse
+import gzip
+import json
+
+
+def get_args():
+ parser = argparse.ArgumentParser(
+ "Read the raw text from the 'supervisions.jsonl.gz'"
+ )
+ parser.add_argument("filename", help="supervisions.jsonl.gz")
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+
+ with gzip.open(args.filename, mode="r") as fd:
+ for line in fd:
+ js = json.loads(line)
+ if "text" in js:
+ print(js["text"]) # supervisions.jsonl.gz
+ elif "supervisions" in js:
+ for s in js["supervisions"]:
+ print(s["text"]) # cuts.jsonl.gz
+ else:
+ raise Exception(f"Unknown jsonl format of {args.filename}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/voxpopuli/ASR/local/train_bpe_model.py b/egs/voxpopuli/ASR/local/train_bpe_model.py
new file mode 120000
index 000000000..6fad36421
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/train_bpe_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/train_bpe_model.py
\ No newline at end of file
diff --git a/egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py b/egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py
new file mode 100755
index 000000000..8e9de905f
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script introduces an initial capital letter at the beginning of a sentence.
+It can be used as a module, or as an executable script.
+
+Usage example #1:
+ `from uppercase_begin_of_sentence import UpperCaseBeginOfSentence`
+
+Usage example #2:
+```
+ python3 ./local/uppercase_begin_of_sentence.py \
+ --ignore-columns 1 \
+ < ${kaldi_data}/text
+```
+"""
+
+import re
+import sys
+from argparse import ArgumentParser
+
+
+class UpperCaseBeginOfSentence:
+ """
+    This class introduces an initial capital letter at the beginning of a sentence.
+    A capital letter is used if the previous symbol was a punctuation token from
+    `set([".", "!", "?"])`.
+
+    Whether the previous token was punctuation is also remembered across
+    `process_line_text()` calls.
+ """
+
+ def __init__(self):
+ # The 1st word will have Title-case
+ # This variable transfers context from previous line
+ self.prev_token_is_punct = True
+
+ def process_line_text(self, line_text: str) -> str:
+ """
+ It is assumed that punctuation in `line_text` was already separated,
+ example: "This is fine . Yes , you are right ."
+ """
+
+ words = line_text.split()
+ punct_set = set([".", "!", "?"])
+
+ for ii, w in enumerate(words):
+ # punctuation ?
+ if w in punct_set:
+ self.prev_token_is_punct = True
+ continue
+
+ # change case of word...
+ if self.prev_token_is_punct:
+ if re.match("<", w):
+ continue # skip
+ # apply Title-case only on lowercase words.
+ if w.islower():
+ words[ii] = w.title()
+ # change state
+ self.prev_token_is_punct = False
+
+ line_text_uc = " ".join(words)
+
+ return line_text_uc
+
+
+def get_args():
+ parser = ArgumentParser(
+ description="Put upper-case at the beginning of a sentence."
+ )
+ parser.add_argument(
+ "--ignore-columns", type=int, default=4, help="skip number of initial columns"
+ )
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+
+ uc_bos = UpperCaseBeginOfSentence()
+ max_split = args.ignore_columns
+
+ while True:
+ line = sys.stdin.readline()
+ if not line:
+ break
+ line = line.strip()
+
+ if len(line.split()) > 1:
+ *key, text = line.strip().split(maxsplit=max_split) # parse,
+ text_uc = uc_bos.process_line_text(text) # process,
+ print(" ".join(key), text_uc) # print,
+ else:
+ print(line)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/voxpopuli/ASR/local/validate_bpe_lexicon.py b/egs/voxpopuli/ASR/local/validate_bpe_lexicon.py
new file mode 120000
index 000000000..721bb48e7
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/validate_bpe_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/validate_bpe_lexicon.py
\ No newline at end of file
diff --git a/egs/voxpopuli/ASR/local/validate_cutset_manifest.py b/egs/voxpopuli/ASR/local/validate_cutset_manifest.py
new file mode 100755
index 000000000..4659aa9cd
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/validate_cutset_manifest.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+# 2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script checks the following assumptions of the generated manifest:
+
+- Single supervision per cut
+- Supervision time bounds are within Cut time bounds
+- Duration of Cut and Supervision are equal
+
+We will add more checks later if needed.
+
+Usage example:
+
+    python3 ./local/validate_cutset_manifest.py \
+        ./data/fbank/voxpopuli-asr-en_cuts_dev.jsonl.gz
+
+(Based on: `librispeech/ASR/local/validate_manifest.py`)
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.cut import Cut
+from lhotse.dataset.speech_recognition import validate_for_asr
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "cutset_manifest",
+ type=Path,
+ help="Path to the manifest file",
+ )
+
+ return parser.parse_args()
+
+
+def validate_one_supervision_per_cut(c: Cut):
+ if len(c.supervisions) != 1:
+ raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions")
+
+
+def validate_supervision_and_cut_time_bounds(c: Cut):
+ tol = 2e-3 # same tolerance as in 'validate_for_asr()'
+ s = c.supervisions[0]
+
+ # Supervision start time is relative to Cut ...
+ # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
+ if s.start < -tol:
+ raise ValueError(
+ f"{c.id}: Supervision start time {s.start} must not be negative."
+ )
+ if s.start > tol:
+ raise ValueError(
+ f"{c.id}: Supervision start time {s.start} "
+ "is not at the beginning of the Cut. "
+ "Please apply `lhotse cut trim-to-supervisions`."
+ )
+ if c.start + s.end > c.end + tol:
+ raise ValueError(
+ f"{c.id}: Supervision end time {c.start+s.end} is larger "
+ f"than cut end time {c.end}"
+ )
+
+ if s.duration != c.duration:
+ raise ValueError(
+ f"{c.id}: Cut duration {c.duration} and supervision duration "
+ f"{s.duration} must be the same.\n"
+ f"The difference causes problems in the training code : "
+ f"+/- 1 frame in `x`, `x_lens` in `Zipformer::forward()`.\n"
+ f"Did you forget to apply `trim_to_supervisions()` ?"
+ )
+
+
+def main():
+ args = get_args()
+
+ manifest = args.cutset_manifest
+ logging.info(f"Validating {manifest}")
+
+ assert manifest.is_file(), f"{manifest} does not exist"
+ cut_set = load_manifest_lazy(manifest)
+ assert isinstance(cut_set, CutSet)
+
+ try:
+ for c in cut_set:
+ validate_one_supervision_per_cut(c)
+ validate_supervision_and_cut_time_bounds(c)
+
+ # Validation from K2 training
+ # - checks supervision start is 0
+ # - checks supervision.duration is not longer than cut.duration
+ # - there is tolerance 2ms
+ validate_for_asr(cut_set)
+ except BaseException as e:
+ logging.error(str(e))
+ raise
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
diff --git a/egs/voxpopuli/ASR/prepare.sh b/egs/voxpopuli/ASR/prepare.sh
new file mode 100755
index 000000000..7cddad756
--- /dev/null
+++ b/egs/voxpopuli/ASR/prepare.sh
@@ -0,0 +1,257 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -euxo pipefail
+
+nj=20
+stage=-1
+stop_stage=100
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+# - $dl_dir/voxpopuli/raw_audios/$lang/$year
+# This directory contains *.ogg files with audio downloaded and extracted from archives:
+# https://dl.fbaipublicfiles.com/voxpopuli/audios/${lang}_${year}.tar
+#
+# - Note: the voxpopuli transcripts are downloaded to a ${tmp} folder
+# as part of `lhotse prepare voxpopuli` from:
+# https://dl.fbaipublicfiles.com/voxpopuli/annotations/asr/asr_${lang}.tsv.gz
+#
+# - $dl_dir/musan
+# This directory contains the following directories downloaded from
+# http://www.openslr.org/17/
+#
+# - music
+# - noise
+# - speech
+
+dl_dir=$PWD/download
+#dl_dir=/mnt/matylda6/szoke/EU-ASR/DATA # BUT
+
+musan_dir=${dl_dir}/musan
+#musan_dir=/mnt/matylda2/data/MUSAN # BUT
+
+# Choose value from ASR_LANGUAGES:
+#
+# [ "en", "de", "fr", "es", "pl", "it", "ro", "hu", "cs", "nl", "fi", "hr",
+# "sk", "sl", "et", "lt" ]
+#
+# See ASR_LANGUAGES in:
+# https://github.com/lhotse-speech/lhotse/blob/c5f26afd100885b86e4244eeb33ca1986f3fa923/lhotse/recipes/voxpopuli.py#L54C4-L54C4
+lang=en
+
+task=asr
+
+. shared/parse_options.sh || exit 1
+
+# vocab size for sentence piece models.
+# It will generate data/${lang}/lang_bpe_xxx,
+# data/${lang}/lang_bpe_yyy if the array contains xxx, yyy
+vocab_sizes=(
+ # 5000
+ # 2000
+ # 1000
+ 500
+)
+
+# All files generated by this script are saved in "data/${lang}".
+# You can safely remove "data/${lang}" and rerun this script to regenerate it.
+mkdir -p data/${lang}
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+log "musan_dir: $musan_dir"
+log "task: $task, lang: $lang"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ # If you have pre-downloaded it to /path/to/$release,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/$release $dl_dir/$release
+ #
+ if [ ! -d $dl_dir/voxpopuli/raw_audios/${lang} ]; then
+ lhotse download voxpopuli --subset $lang $dl_dir/voxpopuli
+ fi
+
+ # If you have pre-downloaded it to /path/to/musan,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/musan $dl_dir/
+ #
+ if [ ! -d $musan_dir/musan ]; then
+ lhotse download musan $musan_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare VoxPopuli manifest"
+ # We assume that you have downloaded the VoxPopuli corpus
+ # to $dl_dir/voxpopuli
+ if [ ! -e data/manifests/.voxpopuli-${task}-${lang}.done ]; then
+    # Warning: it requires an Internet connection (it downloads transcripts to ${tmpdir})
+ lhotse prepare voxpopuli --task asr --lang $lang -j $nj $dl_dir/voxpopuli data/manifests
+ touch data/manifests/.voxpopuli-${task}-${lang}.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Prepare musan manifest"
+ # We assume that you have downloaded the musan corpus
+ # to data/musan
+ mkdir -p data/manifests
+ if [ ! -e data/manifests/.musan.done ]; then
+ #lhotse prepare musan $dl_dir/musan data/manifests
+ lhotse prepare musan $musan_dir/musan data/manifests
+ touch data/manifests/.musan.done
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Preprocess VoxPopuli manifest"
+ mkdir -p data/fbank
+ if [ ! -e data/fbank/.voxpopuli-${task}-${lang}-preprocess_complete ]; then
+ # recordings + supervisions -> cutset
+ ./local/preprocess_voxpopuli.py --task $task --lang $lang \
+ --use-original-text True
+ touch data/fbank/.voxpopuli-${task}-${lang}-preprocess_complete
+ fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Compute fbank for dev and test subsets of VoxPopuli"
+ mkdir -p data/fbank
+ for dataset in "dev" "test"; do
+ if [ ! -e data/fbank/.voxpopuli-${task}-${lang}-${dataset}.done ]; then
+ ./local/compute_fbank.py --src-dir data/fbank --output-dir data/fbank \
+ --num-jobs 50 --num-workers ${nj} \
+ --prefix "voxpopuli-${task}-${lang}" \
+ --dataset ${dataset} \
+ --trim-to-supervisions True
+ touch data/fbank/.voxpopuli-${task}-${lang}-${dataset}.done
+ fi
+ done
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Compute fbank for train set of VoxPopuli"
+ if [ ! -e data/fbank/.voxpopuli-${task}-${lang}-train.done ]; then
+ ./local/compute_fbank.py --src-dir data/fbank --output-dir data/fbank \
+ --num-jobs 100 --num-workers ${nj} \
+ --prefix "voxpopuli-${task}-${lang}" \
+ --dataset train \
+ --trim-to-supervisions True \
+ --speed-perturb True
+ touch data/fbank/.voxpopuli-${task}-${lang}-train.done
+ fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Validate fbank manifests for VoxPopuli"
+ for dataset in "dev" "test" "train"; do
+ mkdir -p data/fbank/log/
+ ./local/validate_cutset_manifest.py \
+      data/fbank/voxpopuli-${task}-${lang}_cuts_${dataset}.jsonl.gz \
+      2>&1 | tee data/fbank/log/validate_voxpopuli-${task}-${lang}_cuts_${dataset}.log
+ done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+ log "Stage 7: Compute fbank for musan"
+ mkdir -p data/fbank
+ if [ ! -e data/fbank/.musan.done ]; then
+ ./local/compute_fbank_musan.py
+ touch data/fbank/.musan.done
+ fi
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+ log "Stage 8: Prepare BPE based lang"
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}_${lang}
+ mkdir -p $lang_dir
+
+ if [ ! -f $lang_dir/transcript_words.txt ]; then
+ log "Generate data for BPE training"
+ file=$(
+ find "data/fbank/voxpopuli-${task}-${lang}_cuts_train.jsonl.gz"
+ )
+ local/text_from_manifest.py $file >$lang_dir/transcript_words.txt
+ # gunzip -c ${file} | awk -F '"' '{print $30}' > $lang_dir/transcript_words.txt
+
+ # Ensure space only appears once
+ #sed -i 's/\t/ /g' $lang_dir/transcript_words.txt
+ #sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt
+ fi
+
+ if [ ! -f $lang_dir/words.txt ]; then
+ cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \
+ | sort -u | sed '/^$/d' > $lang_dir/words.txt
+      (echo '!SIL'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
+        cat - $lang_dir/words.txt | sort | uniq | awk '
+        BEGIN {
+          print "<eps> 0";
+        }
+        {
+          if ($1 == "<s>") {
+            print "<s> is in the vocabulary!" | "cat 1>&2"
+            exit 1;
+          }
+          if ($1 == "</s>") {
+            print "</s> is in the vocabulary!" | "cat 1>&2"
+            exit 1;
+          }
+          printf("%s %d\n", $1, NR);
+        }
+        END {
+          printf("#0 %d\n", NR+1);
+          printf("<s> %d\n", NR+2);
+          printf("</s> %d\n", NR+3);
+        }' > $lang_dir/words || exit 1;
+ mv $lang_dir/words $lang_dir/words.txt
+ fi
+
+ if [ ! -f $lang_dir/bpe.model ]; then
+ ./local/train_bpe_model.py \
+ --lang-dir $lang_dir \
+ --vocab-size $vocab_size \
+ --transcript $lang_dir/transcript_words.txt
+ fi
+
+ if [ ! -f $lang_dir/L_disambig.pt ]; then
+ ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+
+ log "Validating $lang_dir/lexicon.txt"
+ ./local/validate_bpe_lexicon.py \
+ --lexicon $lang_dir/lexicon.txt \
+ --bpe-model $lang_dir/bpe.model
+ fi
+
+ if [ ! -f $lang_dir/L.fst ]; then
+ log "Converting L.pt to L.fst"
+ ./shared/convert-k2-to-openfst.py \
+ --olabels aux_labels \
+ $lang_dir/L.pt \
+ $lang_dir/L.fst
+ fi
+
+ if [ ! -f $lang_dir/L_disambig.fst ]; then
+ log "Converting L_disambig.pt to L_disambig.fst"
+ ./shared/convert-k2-to-openfst.py \
+ --olabels aux_labels \
+ $lang_dir/L_disambig.pt \
+ $lang_dir/L_disambig.fst
+ fi
+ done
+fi
diff --git a/egs/voxpopuli/ASR/shared b/egs/voxpopuli/ASR/shared
new file mode 120000
index 000000000..4c5e91438
--- /dev/null
+++ b/egs/voxpopuli/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file
From 666d69b20d53796420593d99b0c0d6e9cd2212cc Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Fri, 17 Nov 2023 18:12:59 +0800
Subject: [PATCH 004/123] Rename train2.py to avoid confusion (#1386)
---
.github/scripts/run-multi-zh_hans-zipformer.sh | 4 +++-
egs/aishell/ASR/prepare.sh | 5 ++---
.../{train2.py => do_not_use_it_directly.py} | 1 +
egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py | 2 +-
.../{train2.py => do_not_use_it_directly.py} | 1 +
.../ASR/pruned_transducer_stateless7_streaming/README.md | 4 ++--
.../{train2.py => do_not_use_it_directly.py} | 1 +
.../{train2.py => do_not_use_it_directly.py} | 1 +
.../export-for-ncnn.py | 2 +-
.../{train2.py => do_not_use_it_directly.py} | 1 +
.../conv_emformer_transducer_stateless2/export-for-ncnn.py | 2 +-
.../ASR/conv_emformer_transducer_stateless2/export-onnx.py | 2 +-
.../ASR/pruned_transducer_stateless7_streaming/README.md | 4 ++--
.../{train2.py => do_not_use_it_directly.py} | 1 +
.../export-for-ncnn-zh.py | 2 +-
.../export-for-ncnn.py | 2 +-
.../do_not_use_it_directly.py | 1 +
.../export-for-ncnn.py | 2 +-
.../pruned_transducer_stateless7_streaming_multi/train2.py | 1 -
19 files changed, 23 insertions(+), 16 deletions(-)
rename egs/aishell/ASR/pruned_transducer_stateless7/{train2.py => do_not_use_it_directly.py} (99%)
rename egs/aishell/ASR/pruned_transducer_stateless7_streaming/{train2.py => do_not_use_it_directly.py} (99%)
rename egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/{train2.py => do_not_use_it_directly.py} (99%)
rename egs/csj/ASR/pruned_transducer_stateless7_streaming/{train2.py => do_not_use_it_directly.py} (99%)
rename egs/librispeech/ASR/conv_emformer_transducer_stateless2/{train2.py => do_not_use_it_directly.py} (99%)
rename egs/librispeech/ASR/pruned_transducer_stateless7_streaming/{train2.py => do_not_use_it_directly.py} (99%)
create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/do_not_use_it_directly.py
delete mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py
diff --git a/.github/scripts/run-multi-zh_hans-zipformer.sh b/.github/scripts/run-multi-zh_hans-zipformer.sh
index dd32a94f8..cbd86a4d3 100755
--- a/.github/scripts/run-multi-zh_hans-zipformer.sh
+++ b/.github/scripts/run-multi-zh_hans-zipformer.sh
@@ -51,6 +51,8 @@ for method in modified_beam_search fast_beam_search; do
$repo/test_wavs/DEV_T0000000002.wav
done
+rm -rf $repo
+
log "==== Test icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 ===="
repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/
@@ -92,4 +94,4 @@ for method in modified_beam_search fast_beam_search; do
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
-done
\ No newline at end of file
+done
diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh
index d36dc5ed3..9f73a2073 100755
--- a/egs/aishell/ASR/prepare.sh
+++ b/egs/aishell/ASR/prepare.sh
@@ -261,10 +261,9 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
fi
if [ ! -f $lang_char_dir/HLG.fst ]; then
- lang_phone_dir=data/lang_phone
./local/prepare_lang_fst.py \
- --lang-dir $lang_phone_dir \
- --ngram-G ./data/lm/G_3_gram.fst.txt
+ --lang-dir $lang_char_dir \
+ --ngram-G ./data/lm/G_3_gram_char.fst.txt
fi
fi
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train2.py b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py
similarity index 99%
rename from egs/aishell/ASR/pruned_transducer_stateless7/train2.py
rename to egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py
index 057af297f..6027273b2 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/train2.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py
@@ -1234,6 +1234,7 @@ def scan_pessimistic_batches_for_oom(
def main():
+ raise RuntimeError("Please don't use this file directly!")
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py
index 2a9fc57d5..39d988cd0 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py
@@ -56,7 +56,7 @@ import torch.nn as nn
from decoder2 import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
-from train2 import add_model_arguments, get_params, get_transducer_model
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from zipformer import Zipformer
from icefall.checkpoint import (
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
similarity index 99%
rename from egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py
rename to egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index 88eb34104..3c13c19c6 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -1233,6 +1233,7 @@ def scan_pessimistic_batches_for_oom(
def main():
+ raise RuntimeError("Please don't use this file directly!")
parser = get_parser()
AishellAsrDataModule.add_arguments(parser)
args = parser.parse_args()
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md
index 991875aaa..6c20bab2c 100644
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md
@@ -4,6 +4,6 @@ See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer
[./emformer.py](./emformer.py) and [./train.py](./train.py)
are basically the same as
-[./emformer2.py](./emformer2.py) and [./train2.py](./train2.py).
-The only purpose of [./emformer2.py](./emformer2.py) and [./train2.py](./train2.py)
+[./emformer2.py](./emformer2.py) and [./do_not_use_it_directly.py](./do_not_use_it_directly.py).
+The only purpose of [./emformer2.py](./emformer2.py) and [./do_not_use_it_directly.py](./do_not_use_it_directly.py)
is for exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn).
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
similarity index 99%
rename from egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py
rename to egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index c09c9537c..61a3f27db 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -1237,6 +1237,7 @@ def scan_pessimistic_batches_for_oom(
def main():
+ raise RuntimeError("Please don't use this file directly!")
parser = get_parser()
CommonVoiceAsrDataModule.add_arguments(parser)
args = parser.parse_args()
diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
similarity index 99%
rename from egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py
rename to egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index 4c866ddd8..acde72d80 100755
--- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py
+++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -1274,6 +1274,7 @@ def scan_pessimistic_batches_for_oom(
def main():
+ raise RuntimeError("Please don't use this file directly!")
parser = get_parser()
CSJAsrDataModule.add_arguments(parser)
Tokenizer.add_arguments(parser)
diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
index ebdb596a5..b210430c6 100755
--- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
+++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
@@ -72,7 +72,7 @@ from pathlib import Path
import torch
from scaling_converter import convert_scaled_to_non_scaled
from tokenizer import Tokenizer
-from train2 import add_model_arguments, get_params, get_transducer_model
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py
similarity index 99%
rename from egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py
rename to egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py
index 420dc1065..d614f0914 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py
@@ -1099,6 +1099,7 @@ def scan_pessimistic_batches_for_oom(
def main():
+ raise RuntimeError("Please don't use this file directly!")
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py
index 85dbd4661..953f95c45 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py
@@ -39,8 +39,8 @@ from pathlib import Path
import k2
import torch
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from scaling_converter import convert_scaled_to_non_scaled
-from train2 import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py
index ab046557f..1e59e0858 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py
@@ -61,7 +61,7 @@ import torch.nn as nn
from decoder import Decoder
from emformer import Emformer
from scaling_converter import convert_scaled_to_non_scaled
-from train2 import add_model_arguments, get_params, get_transducer_model
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md
index d3691e647..0f3c63e75 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md
@@ -4,7 +4,7 @@ See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer
[./emformer.py](./emformer.py) and [./train.py](./train.py)
are basically the same as
-[./emformer2.py](./emformer2.py) and [./train2.py](./train2.py).
-The only purpose of [./emformer2.py](./emformer2.py) and [./train2.py](./train2.py)
+[./emformer2.py](./emformer2.py) and [./do_not_use_it_directly.py](./do_not_use_it_directly.py).
+The only purpose of [./emformer2.py](./emformer2.py) and [./do_not_use_it_directly.py](./do_not_use_it_directly.py)
is for exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn).
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
similarity index 99%
rename from egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py
rename to egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index aa6c0668a..cd26db6f3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -1234,6 +1234,7 @@ def scan_pessimistic_batches_for_oom(
def main():
+ raise RuntimeError("Please don't use this file directly!")
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
index 07de57a86..a7d06a5dd 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
@@ -68,8 +68,8 @@ from pathlib import Path
import k2
import torch
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from scaling_converter import convert_scaled_to_non_scaled
-from train2 import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
index 9a6b31268..8f2178b1d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
@@ -66,8 +66,8 @@ from pathlib import Path
import k2
import torch
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from scaling_converter import convert_scaled_to_non_scaled
-from train2 import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/do_not_use_it_directly.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/do_not_use_it_directly.py
new file mode 120000
index 000000000..beeffaa03
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/do_not_use_it_directly.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py
index 9a6b31268..8f2178b1d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py
@@ -66,8 +66,8 @@ from pathlib import Path
import k2
import torch
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from scaling_converter import convert_scaled_to_non_scaled
-from train2 import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py
deleted file mode 120000
index 3c3280b68..000000000
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py
+++ /dev/null
@@ -1 +0,0 @@
-../pruned_transducer_stateless7_streaming/train2.py
\ No newline at end of file
From 11d816d174076ec9485ab8b1d36af2592514e348 Mon Sep 17 00:00:00 2001
From: Wei Kang
Date: Sat, 18 Nov 2023 18:47:55 +0800
Subject: [PATCH 005/123] Add customized score for hotwords (#1385)
* add custom score for each hotword
* Add more comments
* Fix decode
* fix style
* minor fixes
---
.../pruned_transducer_stateless7/decode.py | 2 +-
.../decode.py | 2 +-
.../pruned_transducer_stateless4/decode.py | 4 +-
egs/librispeech/ASR/zipformer/decode.py | 4 +-
.../pruned_transducer_stateless5/decode.py | 2 +-
icefall/context_graph.py | 117 +++++++++++++-----
6 files changed, 92 insertions(+), 39 deletions(-)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py
index be58c4e43..696eea906 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py
@@ -641,7 +641,7 @@ def main():
contexts_text.append(line.strip())
contexts = graph_compiler.texts_to_ids(contexts_text)
context_graph = ContextGraph(params.context_score)
- context_graph.build(contexts)
+ context_graph.build([(c, 0.0) for c in contexts])
else:
context_graph = None
else:
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py
index f5ae836fd..99110d6b6 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py
@@ -686,7 +686,7 @@ def main():
contexts_text.append(line.strip())
contexts = graph_compiler.texts_to_ids(contexts_text)
context_graph = ContextGraph(params.context_score)
- context_graph.build(contexts)
+ context_graph.build([(c, 0.0) for c in contexts])
else:
context_graph = None
else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index 524366068..5195a4ef6 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -927,9 +927,9 @@ def main():
if os.path.exists(params.context_file):
contexts = []
for line in open(params.context_file).readlines():
- contexts.append(line.strip())
+ contexts.append((sp.encode(line.strip()), 0.0))
context_graph = ContextGraph(params.context_score)
- context_graph.build(sp.encode(contexts))
+ context_graph.build(contexts)
else:
context_graph = None
else:
diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py
index 3531d657f..339e253e6 100755
--- a/egs/librispeech/ASR/zipformer/decode.py
+++ b/egs/librispeech/ASR/zipformer/decode.py
@@ -1001,9 +1001,9 @@ def main():
if os.path.exists(params.context_file):
contexts = []
for line in open(params.context_file).readlines():
- contexts.append(line.strip())
+ contexts.append((sp.encode(line.strip()), 0.0))
context_graph = ContextGraph(params.context_score)
- context_graph.build(sp.encode(contexts))
+ context_graph.build(contexts)
else:
context_graph = None
else:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index 36b8a4b67..d665f3364 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -868,7 +868,7 @@ def main():
contexts_text.append(line.strip())
contexts = graph_compiler.texts_to_ids(contexts_text)
context_graph = ContextGraph(params.context_score)
- context_graph.build(contexts)
+ context_graph.build([(c, 0.0) for c in contexts])
else:
context_graph = None
else:
diff --git a/icefall/context_graph.py b/icefall/context_graph.py
index 0b7c42c0b..b3d7972a8 100644
--- a/icefall/context_graph.py
+++ b/icefall/context_graph.py
@@ -84,6 +84,9 @@ class ContextGraph:
context_score:
The bonus score for each token(note: NOT for each word/phrase, it means longer
word/phrase will have larger bonus score, they have to be matched though).
+        Note: This is only the default per-token score; users can manually specify
+        a customized context_score for each word/phrase (i.e. different phrases may
+        have different per-token scores).
"""
self.context_score = context_score
self.num_nodes = 0
@@ -133,7 +136,7 @@ class ContextGraph:
node.output_score += 0 if output is None else output.output_score
queue.append(node)
- def build(self, token_ids: List[List[int]]):
+ def build(self, token_ids: List[Tuple[List[int], float]]):
"""Build the ContextGraph from a list of token list.
It first build a trie from the given token lists, then fill the fail arc
for each trie node.
@@ -142,26 +145,46 @@ class ContextGraph:
Args:
token_ids:
- The given token lists to build the ContextGraph, it is a list of token list,
- each token list contains the token ids for a word/phrase. The token id
- could be an id of a char (modeling with single Chinese char) or an id
- of a BPE (modeling with BPEs).
+            The token lists used to build the ContextGraph. It is a list of tuples,
+            each containing a token list and its customized score; the token list
+            holds the token ids for a word/phrase. A token id could be the id of a
+            char (modeling with single Chinese chars) or the id of a BPE piece
+            (modeling with BPEs). The score is the total score for the current
+            token list; 0 means using the default value (i.e. self.context_score).
+
+        Note: Phrases may share states; the score of a shared state is the maximum
+            value among all the phrases sharing this state.
"""
- for tokens in token_ids:
+ for (tokens, score) in token_ids:
node = self.root
+            # If a customized score is given, distribute it evenly over the tokens;
+            # otherwise use the default per-token score.
+ context_score = (
+ self.context_score if score == 0.0 else round(score / len(tokens), 2)
+ )
for i, token in enumerate(tokens):
+ node_next = {}
if token not in node.next:
self.num_nodes += 1
+ node_id = self.num_nodes
+ token_score = context_score
is_end = i == len(tokens) - 1
- node_score = node.node_score + self.context_score
- node.next[token] = ContextState(
- id=self.num_nodes,
- token=token,
- token_score=self.context_score,
- node_score=node_score,
- output_score=node_score if is_end else 0,
- is_end=is_end,
- )
+ else:
+ # node exists, get the score of shared state.
+ token_score = max(context_score, node.next[token].token_score)
+ node_id = node.next[token].id
+ node_next = node.next[token].next
+ is_end = i == len(tokens) - 1 or node.next[token].is_end
+ node_score = node.node_score + token_score
+ node.next[token] = ContextState(
+ id=node_id,
+ token=token,
+ token_score=token_score,
+ node_score=node_score,
+ output_score=node_score if is_end else 0,
+ is_end=is_end,
+ )
+ node.next[token].next = node_next
node = node.next[token]
self._fill_fail_output()
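For readers skimming the hunk above, here is a minimal usage sketch of the updated `build` API (an editorial illustration, not part of the patch; it assumes icefall is installed so that `icefall.context_graph` is importable). A score of 0 falls back to the default `context_score`, a non-zero score is split evenly over the phrase's tokens via `round(score / len(tokens), 2)`, and prefixes shared by several phrases keep the maximum per-token score.

```python
from icefall.context_graph import ContextGraph

# Hotword phrases as (token_ids, score) tuples; character codes stand in for
# real token ids here for simplicity.
phrases = [
    ([ord(c) for c in "HE"], 0.0),     # 0.0 -> default context_score (1 per token)
    ([ord(c) for c in "HELLO"], 5.0),  # 5.0 -> round(5.0 / 5, 2) = 1.0 per token
]

graph = ContextGraph(context_score=1)
graph.build(phrases)

# Score a query token by token, the way decode.py drives the graph during search.
state = graph.root
total = 0.0
for ch in "HELLO":
    bonus, state = graph.forward_one_step(state, ord(ch))
    total += bonus
bonus, state = graph.finalize(state)
total += bonus
print(total)  # accumulated hotword bonus for the query
```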
@@ -343,7 +366,7 @@ class ContextGraph:
return dot
-if __name__ == "__main__":
+def _test(queries, score):
contexts_str = [
"S",
"HE",
@@ -355,9 +378,11 @@ if __name__ == "__main__":
"THIS",
"THEM",
]
+
+ # test default score (1)
contexts = []
for s in contexts_str:
- contexts.append([ord(x) for x in s])
+ contexts.append(([ord(x) for x in s], score))
context_graph = ContextGraph(context_score=1)
context_graph.build(contexts)
@@ -369,10 +394,28 @@ if __name__ == "__main__":
context_graph.draw(
title="Graph for: " + " / ".join(contexts_str),
- filename="context_graph.pdf",
+ filename=f"context_graph_{score}.pdf",
symbol_table=symbol_table,
)
+ for query, expected_score in queries.items():
+ total_scores = 0
+ state = context_graph.root
+ for q in query:
+ score, state = context_graph.forward_one_step(state, ord(q))
+ total_scores += score
+ score, state = context_graph.finalize(state)
+ assert state.token == -1, state.token
+ total_scores += score
+ assert round(total_scores, 2) == expected_score, (
+ total_scores,
+ expected_score,
+ query,
+ )
+
+
+if __name__ == "__main__":
+ # test default score
queries = {
"HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE"
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
@@ -384,17 +427,27 @@ if __name__ == "__main__":
"DHRHISQ": 4, # "HIS", "S"
"THEN": 2, # "HE"
}
- for query, expected_score in queries.items():
- total_scores = 0
- state = context_graph.root
- for q in query:
- score, state = context_graph.forward_one_step(state, ord(q))
- total_scores += score
- score, state = context_graph.finalize(state)
- assert state.token == -1, state.token
- total_scores += score
- assert total_scores == expected_score, (
- total_scores,
- expected_score,
- query,
- )
+ _test(queries, 0)
+
+ # test custom score (5)
+ # S : 5
+ # HE : 5 (2.5 + 2.5)
+ # SHE : 8.34 (5 + 1.67 + 1.67)
+ # SHELL : 10.34 (5 + 1.67 + 1.67 + 1 + 1)
+ # HIS : 5.84 (2.5 + 1.67 + 1.67)
+ # HERS : 7.5 (2.5 + 2.5 + 1.25 + 1.25)
+ # HELLO : 8 (2.5 + 2.5 + 1 + 1 + 1)
+ # THIS : 5 (1.25 + 1.25 + 1.25 + 1.25)
+ queries = {
+ "HEHERSHE": 35.84, # "HE", "HE", "HERS", "S", "SHE", "HE"
+ "HERSHE": 30.84, # "HE", "HERS", "S", "SHE", "HE"
+ "HISHE": 24.18, # "HIS", "S", "SHE", "HE"
+ "SHED": 18.34, # "S", "SHE", "HE"
+ "SHELF": 18.34, # "S", "SHE", "HE"
+ "HELL": 5, # "HE"
+ "HELLO": 13, # "HE", "HELLO"
+ "DHRHISQ": 10.84, # "HIS", "S"
+ "THEN": 5, # "HE"
+ }
+
+ _test(queries, 5)
From 238b45bea85deee1a07cfd0f55b485cc92f67135 Mon Sep 17 00:00:00 2001
From: Wei Kang
Date: Thu, 23 Nov 2023 01:22:57 +0800
Subject: [PATCH 006/123] Libriheavy recipe (zipformer) (#1261)
* initial commit for libriheavy
* Data prepare pipeline
* Fix train.py
* Fix decode.py
* Add results
* minor fixes
* black
* black
* Incorporate PR https://github.com/k2-fsa/icefall/pull/1269
---------
Co-authored-by: zr_jin
---
egs/libriheavy/ASR/README.md | 6 +
egs/libriheavy/ASR/RESULTS.md | 114 +-
.../ASR/local/compute_fbank_libriheavy.py | 242 +++
.../ASR/local/compute_fbank_musan.py | 1 +
egs/libriheavy/ASR/local/norm_text.py | 58 +
egs/libriheavy/ASR/local/prepare_manifest.py | 47 +
egs/libriheavy/ASR/local/train_bpe_model.py | 113 ++
egs/libriheavy/ASR/prepare.sh | 314 ++++
.../ASR/zipformer/asr_datamodule.py | 443 ++++++
egs/libriheavy/ASR/zipformer/beam_search.py | 1 +
egs/libriheavy/ASR/zipformer/decode.py | 794 +++++++++
egs/libriheavy/ASR/zipformer/decoder.py | 1 +
.../ASR/zipformer/encoder_interface.py | 1 +
egs/libriheavy/ASR/zipformer/export-onnx.py | 1 +
egs/libriheavy/ASR/zipformer/export.py | 1 +
.../ASR/zipformer/jit_pretrained.py | 1 +
egs/libriheavy/ASR/zipformer/joiner.py | 1 +
egs/libriheavy/ASR/zipformer/model.py | 1 +
egs/libriheavy/ASR/zipformer/onnx_decode.py | 1 +
.../ASR/zipformer/onnx_pretrained.py | 1 +
egs/libriheavy/ASR/zipformer/optim.py | 1 +
egs/libriheavy/ASR/zipformer/pretrained.py | 1 +
egs/libriheavy/ASR/zipformer/scaling.py | 1 +
.../ASR/zipformer/scaling_coverter.py | 1 +
egs/libriheavy/ASR/zipformer/subsampling.py | 1 +
.../ASR/zipformer/text_normalization.py | 50 +
egs/libriheavy/ASR/zipformer/train.py | 1415 +++++++++++++++++
egs/libriheavy/ASR/zipformer/zipformer.py | 1 +
requirements-ci.txt | 1 +
requirements.txt | 1 +
30 files changed, 3613 insertions(+), 2 deletions(-)
create mode 100644 egs/libriheavy/ASR/README.md
create mode 100755 egs/libriheavy/ASR/local/compute_fbank_libriheavy.py
create mode 120000 egs/libriheavy/ASR/local/compute_fbank_musan.py
create mode 100755 egs/libriheavy/ASR/local/norm_text.py
create mode 100755 egs/libriheavy/ASR/local/prepare_manifest.py
create mode 100755 egs/libriheavy/ASR/local/train_bpe_model.py
create mode 100755 egs/libriheavy/ASR/prepare.sh
create mode 100644 egs/libriheavy/ASR/zipformer/asr_datamodule.py
create mode 120000 egs/libriheavy/ASR/zipformer/beam_search.py
create mode 100644 egs/libriheavy/ASR/zipformer/decode.py
create mode 120000 egs/libriheavy/ASR/zipformer/decoder.py
create mode 120000 egs/libriheavy/ASR/zipformer/encoder_interface.py
create mode 120000 egs/libriheavy/ASR/zipformer/export-onnx.py
create mode 120000 egs/libriheavy/ASR/zipformer/export.py
create mode 120000 egs/libriheavy/ASR/zipformer/jit_pretrained.py
create mode 120000 egs/libriheavy/ASR/zipformer/joiner.py
create mode 120000 egs/libriheavy/ASR/zipformer/model.py
create mode 120000 egs/libriheavy/ASR/zipformer/onnx_decode.py
create mode 120000 egs/libriheavy/ASR/zipformer/onnx_pretrained.py
create mode 120000 egs/libriheavy/ASR/zipformer/optim.py
create mode 120000 egs/libriheavy/ASR/zipformer/pretrained.py
create mode 120000 egs/libriheavy/ASR/zipformer/scaling.py
create mode 120000 egs/libriheavy/ASR/zipformer/scaling_coverter.py
create mode 120000 egs/libriheavy/ASR/zipformer/subsampling.py
create mode 100644 egs/libriheavy/ASR/zipformer/text_normalization.py
create mode 100644 egs/libriheavy/ASR/zipformer/train.py
create mode 120000 egs/libriheavy/ASR/zipformer/zipformer.py
diff --git a/egs/libriheavy/ASR/README.md b/egs/libriheavy/ASR/README.md
new file mode 100644
index 000000000..2498d017f
--- /dev/null
+++ b/egs/libriheavy/ASR/README.md
@@ -0,0 +1,6 @@
+# Libriheavy: a 50,000 hours ASR corpus with punctuation casing and context
+
+Libriheavy is a labeled version of [Librilight](https://arxiv.org/pdf/1912.07875.pdf). Please refer to our repository [k2-fsa/libriheavy](https://github.com/k2-fsa/libriheavy) for more details. We also have a paper: *Libriheavy: a 50,000 hours ASR corpus with punctuation casing and context*, [Preprint available on arxiv](https://arxiv.org/abs/2309.08105).
+
+
+See [RESULTS](./RESULTS.md) for the results for icefall recipes.
diff --git a/egs/libriheavy/ASR/RESULTS.md b/egs/libriheavy/ASR/RESULTS.md
index 4fbedad98..513bbf72e 100644
--- a/egs/libriheavy/ASR/RESULTS.md
+++ b/egs/libriheavy/ASR/RESULTS.md
@@ -1,6 +1,116 @@
-## Results
+# Results
-### Zipformer PromptASR (zipformer + PromptASR + BERT text encoder)
+## zipformer (zipformer + pruned stateless transducer)
+
+See [zipformer](./zipformer) for more details.
+
+### Non-streaming
+
+#### Training on normalized text, i.e. Upper case without punctuation
+
+##### normal-scaled model, number of model parameters: 65805511, i.e., 65.81 M
+
+You can find a pretrained model and training logs at:
+
+
+Note: The repository above contains three models trained on different subsets of Libriheavy:
+exp (large set), exp_medium_subset (medium set), and exp_small_subset (small set).
+
+Results of models:
+
+| training set | decoding method | librispeech clean | librispeech other | libriheavy clean | libriheavy other | comment |
+|---------------|---------------------|-------------------|-------------------|------------------|------------------|--------------------|
+| small | greedy search | 4.19 | 9.99 | 4.75 | 10.25 |--epoch 90 --avg 20 |
+| small | modified beam search| 4.05 | 9.89 | 4.68 | 10.01 |--epoch 90 --avg 20 |
+| medium | greedy search | 2.39 | 4.85 | 2.90 | 6.6 |--epoch 60 --avg 20 |
+| medium | modified beam search| 2.35 | 4.82 | 2.90 | 6.57 |--epoch 60 --avg 20 |
+| large | greedy search | 1.67 | 3.32 | 2.24 | 5.61 |--epoch 16 --avg 3 |
+| large | modified beam search| 1.62 | 3.36 | 2.20 | 5.57 |--epoch 16 --avg 3 |
+
+The training command is:
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+# The command below is for the medium subset.
+# For the large subset, use: --num-epochs 16 --lr-hours 20000
+# For the small subset, use: --num-epochs 90 --lr-hours 5000
+python ./zipformer/train.py \
+  --world-size 4 \
+  --master-port 12365 \
+  --exp-dir zipformer/exp \
+  --num-epochs 60 \
+  --lr-hours 15000 \
+  --use-fp16 1 \
+  --start-epoch 1 \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --max-duration 1000 \
+  --subset medium
+```
+
+The decoding command is:
+```bash
+export CUDA_VISIBLE_DEVICES="0"
+for m in greedy_search modified_beam_search; do
+ ./zipformer/decode.py \
+ --epoch 16 \
+ --avg 3 \
+ --exp-dir zipformer/exp \
+ --max-duration 1000 \
+ --causal 0 \
+ --decoding-method $m
+done
+```
+
+#### Training on full formatted text, i.e. with casing and punctuation
+
+##### normal-scaled model, number of model parameters: 66074067, i.e., 66.07 M
+
+You can find a pretrained model and training logs at:
+
+
+Note: The repository above contains three models trained on different subsets of Libriheavy:
+exp (large set), exp_medium_subset (medium set), and exp_small_subset (small set).
+
+Results of models:
+
+| training set | decoding method | libriheavy clean (WER) | libriheavy other (WER) | libriheavy clean (CER) | libriheavy other (CER) | comment |
+|---------------|---------------------|-------------------|-------------------|------------------|------------------|--------------------|
+| small | modified beam search| 13.04 | 19.54 | 4.51 | 7.90 |--epoch 88 --avg 41 |
+| medium | modified beam search| 9.84 | 13.39 | 3.02 | 5.10 |--epoch 50 --avg 15 |
+| large | modified beam search| 7.76 | 11.32 | 2.41 | 4.22 |--epoch 16 --avg 2 |
+
+The training command is:
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+# The command below is for the medium subset.
+# For the large subset, use: --num-epochs 16 --lr-hours 20000
+# For the small subset, use: --num-epochs 90 --lr-hours 10000
+python ./zipformer/train.py \
+  --world-size 4 \
+  --master-port 12365 \
+  --exp-dir zipformer/exp \
+  --num-epochs 60 \
+  --lr-hours 15000 \
+  --use-fp16 1 \
+  --train-with-punctuation 1 \
+  --start-epoch 1 \
+  --bpe-model data/lang_punc_bpe_756/bpe.model \
+  --max-duration 1000 \
+  --subset medium
+```
+
+The decoding command is:
+```bash
+export CUDA_VISIBLE_DEVICES="0"
+for m in greedy_search modified_beam_search; do
+ ./zipformer/decode.py \
+ --epoch 16 \
+ --avg 3 \
+ --exp-dir zipformer/exp \
+ --max-duration 1000 \
+ --causal 0 \
+ --decoding-method $m
+done
+```
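Editorial note, not part of the patch: the `756` in `lang_punc_bpe_756` above comes from stage 10 of the Libriheavy `prepare.sh` later in this patch, where 256 is added to the base vocabulary size (500) to leave room for the byte symbols required by `--byte-fallback`:

```bash
# Quoted from prepare.sh stage 10, with vocab_sizes=(500):
new_vocab_size=$(($vocab_size + 256))          # 500 + 256 = 756
lang_dir=data/lang_punc_bpe_${new_vocab_size}  # -> data/lang_punc_bpe_756
```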
+
+## Zipformer PromptASR (zipformer + PromptASR + BERT text encoder)
#### [zipformer_prompt_asr](./zipformer_prompt_asr)
diff --git a/egs/libriheavy/ASR/local/compute_fbank_libriheavy.py b/egs/libriheavy/ASR/local/compute_fbank_libriheavy.py
new file mode 100755
index 000000000..010531db2
--- /dev/null
+++ b/egs/libriheavy/ASR/local/compute_fbank_libriheavy.py
@@ -0,0 +1,242 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the Libriheavy dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Optional
+
+import torch
+from lhotse import (
+ CutSet,
+ Fbank,
+ FbankConfig,
+ KaldifeatFbank,
+ KaldifeatFbankConfig,
+ LilcomChunkyWriter,
+)
+
+from icefall.utils import get_executor, str2bool
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--manifest-dir",
+ type=str,
+ help="""The source directory that contains raw manifests.
+ """,
+ default="data/manifests",
+ )
+
+ parser.add_argument(
+ "--fbank-dir",
+ type=str,
+ help="""Fbank output dir
+ """,
+ default="data/fbank",
+ )
+
+ parser.add_argument(
+ "--subset",
+ type=str,
+ help="""Dataset parts to compute fbank. If None, we will use all""",
+ )
+
+ parser.add_argument(
+ "--num-workers",
+ type=int,
+ default=20,
+ help="Number of dataloading workers used for reading the audio.",
+ )
+
+ parser.add_argument(
+ "--batch-duration",
+ type=float,
+ default=600.0,
+        help="The maximum number of audio seconds in a batch. "
+        "Determines the batch size dynamically.",
+ )
+
+ parser.add_argument(
+ "--perturb-speed",
+ type=str2bool,
+ default=False,
+ help="Whether to use speed perturbation.",
+ )
+
+ parser.add_argument(
+ "--use-splits",
+ type=str2bool,
+ default=False,
+ help="Whether to compute fbank on splits.",
+ )
+
+ parser.add_argument(
+ "--num-splits",
+ type=int,
+ help="""The number of splits of the medium and large subset.
+ Only needed when --use-splits is true.""",
+ )
+
+ parser.add_argument(
+ "--start",
+ type=int,
+ default=0,
+ help="""Process pieces starting from this number (inclusive).
+ Only needed when --use-splits is true.""",
+ )
+
+ parser.add_argument(
+ "--stop",
+ type=int,
+ default=-1,
+        help="""Stop processing pieces at this number (exclusive).
+ Only needed when --use-splits is true.""",
+ )
+
+ return parser.parse_args()
+
+
+def compute_fbank_libriheavy(args):
+ src_dir = Path(args.manifest_dir)
+ output_dir = Path(args.fbank_dir)
+ num_jobs = min(15, os.cpu_count())
+ num_mel_bins = 80
+ subset = args.subset
+
+ extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+ with get_executor() as ex: # Initialize the executor only once.
+ output_cuts_path = output_dir / f"libriheavy_cuts_{subset}.jsonl.gz"
+ if output_cuts_path.exists():
+ logging.info(f"{output_cuts_path} exists - skipping")
+ return
+
+ input_cuts_path = src_dir / f"libriheavy_cuts_{subset}.jsonl.gz"
+ assert input_cuts_path.exists(), f"{input_cuts_path} does not exist!"
+ logging.info(f"Loading {input_cuts_path}")
+ cut_set = CutSet.from_file(input_cuts_path)
+
+ logging.info("Computing features")
+
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/libriheavy_feats_{subset}",
+ # when an executor is specified, make more partitions
+ num_jobs=num_jobs if ex is None else 80,
+ executor=ex,
+ storage_type=LilcomChunkyWriter,
+ )
+
+ logging.info(f"Saving to {output_cuts_path}")
+ cut_set.to_file(output_cuts_path)
+
+
+def compute_fbank_libriheavy_splits(args):
+ num_splits = args.num_splits
+ subset = args.subset
+ src_dir = f"{args.manifest_dir}/libriheavy_{subset}_split"
+ src_dir = Path(src_dir)
+ output_dir = f"{args.fbank_dir}/libriheavy_{subset}_split"
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ start = args.start
+ stop = args.stop
+ if stop < start:
+ stop = num_splits
+
+ stop = min(stop, num_splits)
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+ extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
+ logging.info(f"device: {device}")
+
+ num_digits = 8 # num_digits is fixed by lhotse split-lazy
+ for i in range(start, stop):
+ idx = f"{i + 1}".zfill(num_digits)
+ logging.info(f"Processing {idx}/{num_splits}")
+
+ cuts_path = output_dir / f"libriheavy_cuts_{subset}.{idx}.jsonl.gz"
+ if cuts_path.is_file():
+ logging.info(f"{cuts_path} exists - skipping")
+ continue
+
+ raw_cuts_path = src_dir / f"libriheavy_cuts_{subset}.{idx}.jsonl.gz"
+ if not raw_cuts_path.is_file():
+ logging.info(f"{raw_cuts_path} does not exist - skipping it")
+ continue
+
+ logging.info(f"Loading {raw_cuts_path}")
+ cut_set = CutSet.from_file(raw_cuts_path)
+
+ logging.info("Computing features")
+ if (output_dir / f"libriheavy_feats_{subset}_{idx}.lca").exists():
+ logging.info(f"Removing {output_dir}/libriheavy_feats_{subset}_{idx}.lca")
+ os.remove(output_dir / f"libriheavy_feats_{subset}_{idx}.lca")
+
+ cut_set = cut_set.compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=f"{output_dir}/libriheavy_feats_{subset}_{idx}",
+ num_workers=args.num_workers,
+ batch_duration=args.batch_duration,
+ overwrite=True,
+ )
+
+ logging.info("About to split cuts into smaller chunks.")
+ cut_set = cut_set.trim_to_supervisions(
+ keep_overlapping=False, min_duration=None
+ )
+
+ logging.info(f"Saving to {cuts_path}")
+ cut_set.to_file(cuts_path)
+ logging.info(f"Saved to {cuts_path}")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ args = get_args()
+ logging.info(vars(args))
+
+ if args.use_splits:
+ assert args.num_splits is not None, "Please provide num_splits"
+ compute_fbank_libriheavy_splits(args)
+ else:
+ compute_fbank_libriheavy(args)
diff --git a/egs/libriheavy/ASR/local/compute_fbank_musan.py b/egs/libriheavy/ASR/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/libriheavy/ASR/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/local/norm_text.py b/egs/libriheavy/ASR/local/norm_text.py
new file mode 100755
index 000000000..c2fc0d92d
--- /dev/null
+++ b/egs/libriheavy/ASR/local/norm_text.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import codecs
+import sys
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--text",
+ type=str,
+ help="""Path to the input text.
+ """,
+ )
+ return parser.parse_args()
+
+
+def remove_punc_to_upper(text: str) -> str:
+ text = text.replace("‘", "'")
+ text = text.replace("’", "'")
+ tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
+ s_list = [x.upper() if x in tokens else " " for x in text]
+ s = " ".join("".join(s_list).split()).strip()
+ return s
+
+
+def main():
+ args = get_args()
+ if args.text:
+ f = codecs.open(args.text, encoding="utf-8")
+ else:
+ f = codecs.getreader("utf-8")(sys.stdin.buffer)
+
+ sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer)
+ line = f.readline()
+ while line:
+ print(remove_punc_to_upper(line))
+ line = f.readline()
+
+
+if __name__ == "__main__":
+ main()
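A quick editorial illustration (not part of the patch) of what `remove_punc_to_upper` produces, assuming `norm_text.py` is importable from the `local/` directory: every character outside letters, digits, and the apostrophe becomes a separator, and the remainder is upper-cased.

```python
from norm_text import remove_punc_to_upper

print(remove_punc_to_upper("Hello, world! It's 9 o'clock."))
# -> HELLO WORLD IT'S 9 O'CLOCK
```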
diff --git a/egs/libriheavy/ASR/local/prepare_manifest.py b/egs/libriheavy/ASR/local/prepare_manifest.py
new file mode 100755
index 000000000..42f392cae
--- /dev/null
+++ b/egs/libriheavy/ASR/local/prepare_manifest.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gzip
+import json
+import sys
+from pathlib import Path
+
+
+def simple_cleanup(text: str) -> str:
+ table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
+ text = text.translate(table)
+ return text.strip()
+
+
+# Assign text of the supervisions and remove unnecessary entries.
+def main():
+ assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR"
+ fname = Path(sys.argv[1]).name
+ oname = Path(sys.argv[2]) / fname
+ with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout:
+ for line in fin:
+ cut = json.loads(line)
+ cut["supervisions"][0]["text"] = simple_cleanup(
+ cut["supervisions"][0]["custom"]["texts"][0]
+ )
+ del cut["supervisions"][0]["custom"]
+ del cut["custom"]
+ fout.write((json.dumps(cut) + "\n").encode())
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/libriheavy/ASR/local/train_bpe_model.py b/egs/libriheavy/ASR/local/train_bpe_model.py
new file mode 100755
index 000000000..19caf43ab
--- /dev/null
+++ b/egs/libriheavy/ASR/local/train_bpe_model.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# You can install sentencepiece via:
+#
+# pip install sentencepiece
+#
+# Due to an issue reported in
+# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
+#
+# Please install a version >=0.1.96
+
+import argparse
+import shutil
+from pathlib import Path
+
+import sentencepiece as spm
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Input and output directory.
+ The generated bpe.model is saved to this directory.
+ """,
+ )
+
+ parser.add_argument(
+ "--byte-fallback",
+ action="store_true",
+ help="""Whether to enable byte_fallback when training bpe.""",
+ )
+
+ parser.add_argument(
+ "--character-coverage",
+ type=float,
+ default=1.0,
+ help="Character coverage in vocabulary.",
+ )
+
+ parser.add_argument(
+ "--transcript",
+ type=str,
+ help="Training transcript.",
+ )
+
+ parser.add_argument(
+ "--vocab-size",
+ type=int,
+ help="Vocabulary size for BPE training",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+ vocab_size = args.vocab_size
+ lang_dir = Path(args.lang_dir)
+
+ model_type = "unigram"
+
+ model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
+ train_text = args.transcript
+ input_sentence_size = 100000000
+
+    user_defined_symbols = ["<blk>", "<sos/eos>"]
+ unk_id = len(user_defined_symbols)
+ # Note: unk_id is fixed to 2.
+ # If you change it, you should also change other
+ # places that are using it.
+
+ model_file = Path(model_prefix + ".model")
+ if not model_file.is_file():
+ spm.SentencePieceTrainer.train(
+ input=train_text,
+ vocab_size=vocab_size,
+ model_type=model_type,
+ model_prefix=model_prefix,
+ input_sentence_size=input_sentence_size,
+ character_coverage=args.character_coverage,
+ user_defined_symbols=user_defined_symbols,
+ byte_fallback=args.byte_fallback,
+ unk_id=unk_id,
+ bos_id=-1,
+ eos_id=-1,
+ )
+ else:
+ print(f"{model_file} exists - skipping")
+ return
+
+ shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/libriheavy/ASR/prepare.sh b/egs/libriheavy/ASR/prepare.sh
new file mode 100755
index 000000000..af7e3c5b0
--- /dev/null
+++ b/egs/libriheavy/ASR/prepare.sh
@@ -0,0 +1,314 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+nj=15
+stage=-1
+stop_stage=100
+export CUDA_VISIBLE_DEVICES=""
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+# - $dl_dir/librilight
+# You can find small, medium, large, etc. inside it.
+#
+# - $dl_dir/libriheavy
+# You can find libriheavy_cuts_small.jsonl.gz, libriheavy_cuts_medium.jsonl.gz, etc. inside it.
+#
+# - $dl_dir/musan
+# This directory contains the following directories downloaded from
+# http://www.openslr.org/17/
+#
+# - music
+# - noise
+# - speech
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# vocab size for sentence piece models.
+# It will generate data/lang_bpe_xxx,
+# data/lang_bpe_yyy if the array contains xxx, yyy
+vocab_sizes=(
+ # 5000
+ # 2000
+ # 1000
+ 500
+)
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+fbank_dir=data/fbank
+manifests_dir=data/manifests
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+ log "Stage -1: Download audio data."
+ # If you have pre-downloaded it to /path/to/librilight,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/librilight $dl_dir/librilight
+ #
+ mkdir -p $dl_dir/librilight
+ for subset in small medium large; do
+ log "Downloading ${subset} subset."
+ if [ ! -d $dl_dir/librilight/${subset} ]; then
+ wget -P $dl_dir/librilight -c https://dl.fbaipublicfiles.com/librilight/data/${subset}.tar
+ tar xf $dl_dir/librilight/${subset}.tar -C $dl_dir/librilight
+ else
+ log "Skipping download, ${subset} subset exists."
+ fi
+ done
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download manifests from huggingface."
+
+ # If you have pre-downloaded it to /path/to/libriheavy,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/libriheavy $dl_dir/libriheavy
+ #
+ mkdir -p $dl_dir/libriheavy
+ for subset in small medium large dev test_clean test_other; do
+ if [ ! -e $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz ]; then
+ log "Downloading ${subset} subset."
+ wget -P $dl_dir/libriheavy -c https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_${subset}.jsonl.gz
+ else
+ log "Skipping download, ${subset} subset exists."
+ fi
+ done
+
+ # If you have pre-downloaded it to /path/to/musan,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/musan $dl_dir/
+ #
+ if [ ! -d $dl_dir/musan ]; then
+ lhotse download musan $dl_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Download manifests from modelscope"
+ mkdir -p $dl_dir/libriheavy
+ if [ ! -e $dl_dir/libriheavy/libriheavy_cuts_small.jsonl.gz ]; then
+ cd $dl_dir/libriheavy
+ GIT_LFS_SKIP_SMUDGE=1 git clone https://www.modelscope.cn/datasets/pkufool/Libriheavy.git
+ cd Libriheavy
+ git lfs pull --exclude "raw/*"
+ mv *.jsonl.gz ../
+ cd ..
+ rm -rf Libriheavy
+ cd ../../
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Prepare musan manifest"
+ # We assume that you have downloaded the musan corpus
+ # to $dl_dir/musan
+ mkdir -p $manifests_dir
+ if [ ! -e $manifests_dir/.musan.done ]; then
+ lhotse prepare musan $dl_dir/musan $manifests_dir
+ touch $manifests_dir/.musan.done
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Prepare Libriheavy manifests"
+ mkdir -p $manifests_dir
+ for subset in small medium large dev test_clean test_other; do
+ if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
+ log "Prepare manifest for subset : ${subset}"
+ ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir
+ fi
+ done
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Compute fbank for musan"
+ mkdir -p $fbank_dir
+ if [ ! -e $fbank_dir/.musan.done ]; then
+ ./local/compute_fbank_musan.py
+ touch $fbank_dir/.musan.done
+ fi
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Compute fbank for small subset and validation subsets"
+ for subset in test_clean test_other dev small; do
+ log "Computing $subset subset."
+ if [ ! -e $fbank_dir/.libriheavy.${subset}.done ]; then
+ ./local/compute_fbank_libriheavy.py \
+ --manifest-dir ${manifests_dir} \
+ --subset ${subset} \
+ --fbank-dir $fbank_dir \
+ --num-workers $nj
+ fi
+ done
+fi
+
+num_per_split=8000
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Split medium and large subsets."
+ for subset in medium large; do
+    log "Splitting subset: $subset"
+ split_dir=$manifests_dir/libriheavy_${subset}_split
+ mkdir -p $split_dir
+ if [ ! -e $split_dir/.split_completed ]; then
+ lhotse split-lazy $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz $split_dir $num_per_split
+ touch $split_dir/.split_completed
+ fi
+ done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+ log "Stage 7: Compute fbank for medium and large subsets"
+ mkdir -p $fbank_dir
+ chunk_size=20
+ for subset in medium large; do
+ if [ $subset == "large" ]; then
+ chunk_size=200
+ fi
+ num_splits=$(find $manifests_dir/libriheavy_${subset}_split -name "libriheavy_cuts_${subset}.*.jsonl.gz" | wc -l)
+ if [ ! -e $fbank_dir/.libriheavy.${subset}.done ]; then
+ for i in $(seq 0 1 6); do
+ start=$(( i * $chunk_size ))
+ end=$(( (i+1) * $chunk_size ))
+ ./local/compute_fbank_libriheavy.py \
+ --manifest-dir ${manifests_dir} \
+ --use-splits 1 \
+ --subset ${subset} \
+ --fbank-dir $fbank_dir \
+ --num-splits $num_splits \
+ --num-workers $nj \
+ --start $start \
+ --stop $end &
+ done
+ wait
+ touch $fbank_dir/.libriheavy.${subset}.done
+ fi
+ done
+fi
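+# Note on stage 7: the splits are processed by 7 background jobs, each
+# covering $chunk_size consecutive split indices, i.e. indices 0-139 for the
+# medium subset (chunk_size=20) and 0-1399 for the large subset
+# (chunk_size=200). If a subset ever has more splits than that, enlarge
+# chunk_size (or rerun the stage with adjusted --start/--stop values).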
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+ log "Stage 8: Combine features for medium and large subsets."
+ for subset in medium large; do
+ log "Combining $subset subset."
+ if [ ! -f $fbank_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
+ pieces=$(find $fbank_dir/libriheavy_${subset}_split -name "libriheavy_cuts_${subset}.*.jsonl.gz")
+ lhotse combine $pieces $fbank_dir/libriheavy_cuts_${subset}.jsonl.gz
+ fi
+ done
+fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+ log "Stage 9: Train BPE model for normalized text"
+
+ if [ ! -f data/texts ]; then
+ gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
+ | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
+ | ./local/norm_text.py > data/texts
+ fi
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}
+ mkdir -p $lang_dir
+
+ cp data/texts $lang_dir/text
+
+ if [ ! -f $lang_dir/bpe.model ]; then
+ ./local/train_bpe_model.py \
+ --lang-dir $lang_dir \
+ --vocab-size $vocab_size \
+ --transcript $lang_dir/text
+ fi
+ done
+fi
+
+
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+ log "Stage 10: Train BPE model for unnormalized text"
+ if [ ! -f data/punc_texts ]; then
+ gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
+ | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts
+ fi
+ for vocab_size in ${vocab_sizes[@]}; do
+    new_vocab_size=$(($vocab_size + 256))
+ lang_dir=data/lang_punc_bpe_${new_vocab_size}
+ mkdir -p $lang_dir
+
+ cp data/punc_texts $lang_dir/text
+
+ if [ ! -f $lang_dir/bpe.model ]; then
+ ./local/train_bpe_model.py \
+ --lang-dir $lang_dir \
+        --byte-fallback \
+        --vocab-size ${new_vocab_size} \
+ --character-coverage 0.99 \
+ --transcript $lang_dir/text
+ fi
+ done
+fi
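+# Note on stage 10: with --byte-fallback enabled, SentencePiece reserves 256
+# extra byte-level pieces (<0x00> ... <0xFF>), so the vocab size is increased
+# by 256 to keep roughly the same number of regular BPE pieces as in stage 9.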
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+ log "Stage 11: Prepare language model for normalized text"
+
+ for subset in small medium large; do
+ if [ ! -f $manifests_dir/texts_${subset} ]; then
+ gunzip -c $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz \
+ | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
+ | ./local/norm_text.py > $manifests_dir/texts_${subset}
+ fi
+ done
+
+ mkdir -p data/lm
+ if [ ! -f data/lm/text ]; then
+ cat $manifests_dir/texts_small $manifests_dir/texts_medium $manifests_dir/texts_large > data/lm/text
+ fi
+
+  (echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \
+ > data/lm/words.txt
+
+ cat data/lm/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
+ | awk '{print $1" "NR+3}' >> data/lm/words.txt
+
+ num_lines=$(< data/lm/words.txt wc -l)
+  (echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \
+ >> data/lm/words.txt
+
+ # Train LM on transcripts
+ if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
+ python3 ./shared/make_kn_lm.py \
+ -ngram-order 3 \
+ -text data/lm/text \
+ -lm data/lm/3-gram.unpruned.arpa
+ fi
+
+  # We assume you have installed kaldilm. If not, please install
+  # it with: pip install kaldilm
+  if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+ # It is used in building HLG
+ python3 -m kaldilm \
+ --read-symbol-table=data/lm/words.txt \
+ --disambig-symbol='#0' \
+ --max-order=3 \
+ data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
+ fi
+fi
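+# Example invocations (a sketch; assumes the usual icefall convention of
+# overriding $stage/$stop_stage on the command line via parse_options.sh):
+#
+#   ./prepare.sh --stage 0 --stop-stage 3    # download data, prepare manifests
+#   ./prepare.sh --stage 4 --stop-stage 8    # compute and combine fbank features
+#   ./prepare.sh --stage 9 --stop-stage 11   # train BPE models and the n-gram LM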
+
diff --git a/egs/libriheavy/ASR/zipformer/asr_datamodule.py b/egs/libriheavy/ASR/zipformer/asr_datamodule.py
new file mode 100644
index 000000000..df761c1b8
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/asr_datamodule.py
@@ -0,0 +1,443 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SimpleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
+ AudioSamples,
+ OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
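+
+# Each dataloader worker is seeded with `seed + worker_id`, so random
+# augmentations (e.g. MUSAN mixing, SpecAugment) differ across workers while
+# the run stays reproducible given the random state of the main process.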
+
+
+class LibriHeavyAsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+
+ This class should be derived for specific corpora used in ASR tasks.
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--subset",
+ type=str,
+ default="S",
+ help="""The subset to be used. Should be S, M or L. Note: S subset
+ includes libriheavy_cuts_small.jsonl.gz, M subset includes
+ libriheavy_cuts_small.jsonl.gz and libriheavy_cuts_medium.jsonl.gz,
+ L subset includes libriheavy_cuts_small.jsonl.gz,
+ libriheavy_cuts_medium.jsonl.gz and libriheavy_cuts_large.jsonl.gz.
+ """,
+ )
+
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+            help="The number of buckets for the DynamicBucketingSampler "
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+            help="When enabled, select noise from MUSAN and mix it "
+            "with the training dataset.",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ logging.info("About to get Musan cuts")
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+ transforms.append(
+ CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # In different Lhotse's versions, the default of num_frame_masks is
+ # different.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=num_frame_masks,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ train = K2SpeechRecognitionDataset(
+ input_strategy=eval(self.args.input_strategy)(),
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_small_cuts(self) -> CutSet:
+ logging.info("About to get small subset cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_medium_cuts(self) -> CutSet:
+ logging.info("About to get medium subset cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_large_cuts(self) -> CutSet:
+ logging.info("About to get large subset cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "libriheavy_cuts_large.jsonl.gz"
+ )
+
+ @lru_cache()
+ def dev_cuts(self) -> CutSet:
+ logging.info("About to get dev cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_clean_cuts(self) -> CutSet:
+ logging.info("About to get the test-clean cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "libriheavy_cuts_test_clean.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_other_cuts(self) -> CutSet:
+ logging.info("About to get the test-other cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "libriheavy_cuts_test_other.jsonl.gz"
+ )
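+
+
+# A minimal usage sketch (hypothetical; mirrors how train.py/decode.py consume
+# this datamodule, with placeholder argument values):
+#
+#   parser = argparse.ArgumentParser()
+#   LibriHeavyAsrDataModule.add_arguments(parser)
+#   args = parser.parse_args(["--subset", "S", "--max-duration", "200"])
+#   datamodule = LibriHeavyAsrDataModule(args)
+#   train_dl = datamodule.train_dataloaders(datamodule.train_small_cuts())
+#   for batch in train_dl:
+#       feats = batch["inputs"]        # (N, T, C) features
+#       sups = batch["supervisions"]   # texts, num_frames, (cuts if requested)
+#       break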
diff --git a/egs/libriheavy/ASR/zipformer/beam_search.py b/egs/libriheavy/ASR/zipformer/beam_search.py
new file mode 120000
index 000000000..e24eca39f
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/decode.py b/egs/libriheavy/ASR/zipformer/decode.py
new file mode 100644
index 000000000..1928e2635
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/decode.py
@@ -0,0 +1,794 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+"""
+
+
+import argparse
+import logging
+import math
+import warnings
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriHeavyAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from lhotse.cut import Cut
+from text_normalization import remove_punc_to_upper
+from train import add_model_arguments, get_model, get_params
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--train-with-punctuation",
+ type=str2bool,
+ default=False,
+ help="""Set to True, if the model was trained on texts with casing
+ and punctuation.""",
+ )
+
+ parser.add_argument(
+ "--post-normalization",
+ type=str2bool,
+ default=False,
+        help="""If True, upper-case the references and hypotheses and remove
+        all characters except ' and - before scoring.
+        """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+        or fast_beam_search_nbest_oracle.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
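+        # With pad_len = 30, an (N, T, C) batch becomes (N, T + 30, C); the
+        # padded frames are filled with LOG_EPS = log(1e-10) (near-zero energy
+        # in log space) and feature_lens is increased to match, presumably
+        # giving the causal encoder extra frames to flush out the final symbols.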
+
+ encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+
+ return {key: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+        or fast_beam_search_nbest_oracle.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+        Its value is a list of tuples. Each tuple contains three elements:
+        the cut ID, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ this_batch = []
+ if params.post_normalization and params.train_with_punctuation:
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = remove_punc_to_upper(ref_text).split()
+ hyp_words = remove_punc_to_upper(" ".join(hyp_words)).split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[f"{name}_norm"].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
+ )
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LibriHeavyAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ if "fast_beam_search" in params.decoding_method:
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ libriheavy = LibriHeavyAsrDataModule(args)
+
+ def normalize_text(c: Cut):
+ text = remove_punc_to_upper(c.supervisions[0].text)
+ c.supervisions[0].text = text
+ return c
+
+ test_clean_cuts = libriheavy.test_clean_cuts()
+ test_other_cuts = libriheavy.test_other_cuts()
+
+ if not params.train_with_punctuation:
+ test_clean_cuts = test_clean_cuts.map(normalize_text)
+ test_other_cuts = test_other_cuts.map(normalize_text)
+
+ test_clean_dl = libriheavy.test_dataloaders(test_clean_cuts)
+ test_other_dl = libriheavy.test_dataloaders(test_other_cuts)
+
+ test_sets = ["test-clean", "test-other"]
+    test_dls = [test_clean_dl, test_other_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/libriheavy/ASR/zipformer/decoder.py b/egs/libriheavy/ASR/zipformer/decoder.py
new file mode 120000
index 000000000..5a8018680
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/decoder.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/encoder_interface.py b/egs/libriheavy/ASR/zipformer/encoder_interface.py
new file mode 120000
index 000000000..c2eaca671
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/encoder_interface.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/export-onnx.py b/egs/libriheavy/ASR/zipformer/export-onnx.py
new file mode 120000
index 000000000..70a15683c
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/export-onnx.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/export.py b/egs/libriheavy/ASR/zipformer/export.py
new file mode 120000
index 000000000..dfc1bec08
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/jit_pretrained.py b/egs/libriheavy/ASR/zipformer/jit_pretrained.py
new file mode 120000
index 000000000..25108391f
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/jit_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/jit_pretrained.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/joiner.py b/egs/libriheavy/ASR/zipformer/joiner.py
new file mode 120000
index 000000000..5b8a36332
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/joiner.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/model.py b/egs/libriheavy/ASR/zipformer/model.py
new file mode 120000
index 000000000..cd7e07d72
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/model.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/onnx_decode.py b/egs/libriheavy/ASR/zipformer/onnx_decode.py
new file mode 120000
index 000000000..0573b88c5
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/onnx_decode.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_decode.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/onnx_pretrained.py b/egs/libriheavy/ASR/zipformer/onnx_pretrained.py
new file mode 120000
index 000000000..8f32f4ee7
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/onnx_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_pretrained.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/optim.py b/egs/libriheavy/ASR/zipformer/optim.py
new file mode 120000
index 000000000..5eaa3cffd
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/pretrained.py b/egs/libriheavy/ASR/zipformer/pretrained.py
new file mode 120000
index 000000000..0bd71dde4
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/pretrained.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/scaling.py b/egs/libriheavy/ASR/zipformer/scaling.py
new file mode 120000
index 000000000..6f398f431
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/scaling_converter.py b/egs/libriheavy/ASR/zipformer/scaling_converter.py
new file mode 120000
index 000000000..b0ecee05e
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling_converter.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/subsampling.py b/egs/libriheavy/ASR/zipformer/subsampling.py
new file mode 120000
index 000000000..01ae9002c
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/subsampling.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/text_normalization.py b/egs/libriheavy/ASR/zipformer/text_normalization.py
new file mode 100644
index 000000000..92590769c
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/text_normalization.py
@@ -0,0 +1,50 @@
+from num2words import num2words
+
+
+def remove_punc_to_upper(text: str) -> str:
+ text = text.replace("‘", "'")
+ text = text.replace("’", "'")
+ tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
+ s_list = [x.upper() if x in tokens else " " for x in text]
+ s = " ".join("".join(s_list).split()).strip()
+ return s
+
+
+def word_normalization(word: str) -> str:
+ # 1. Use full word for some abbreviation
+ # 2. Convert digits to english words
+ # 3. Convert ordinal number to english words
+ if word == "MRS":
+ return "MISSUS"
+ if word == "MR":
+ return "MISTER"
+ if word == "ST":
+ return "SAINT"
+ if word == "ECT":
+ return "ET CETERA"
+
+ if word[-2:] in ("ST", "ND", "RD", "TH") and word[:-2].isnumeric(): # e.g 9TH, 6TH
+ word = num2words(word[:-2], to="ordinal")
+ word = word.replace("-", " ")
+
+ if word.isnumeric():
+ num = int(word)
+ if num > 1500 and num < 2030:
+ word = num2words(word, to="year")
+ else:
+ word = num2words(word)
+ word = word.replace("-", " ")
+ return word.upper()
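+
+# Illustrative examples (consistent with the checks in __main__ below):
+#   word_normalization("MR")   -> "MISTER"
+#   word_normalization("21ST") -> "TWENTY FIRST"   (ordinal, hyphen dropped)
+#   word_normalization("99")   -> "NINETY NINE"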
+
+
+def text_normalization(text: str) -> str:
+ text = text.upper()
+ return " ".join([word_normalization(x) for x in text.split()])
+
+
+if __name__ == "__main__":
+ assert remove_punc_to_upper("I like this 《book>") == "I LIKE THIS BOOK"
+ assert (
+ text_normalization("Hello Mrs st 21st world 3rd she 99th MR")
+ == "HELLO MISSUS SAINT TWENTY FIRST WORLD THIRD SHE NINETY NINTH MISTER"
+ )
diff --git a/egs/libriheavy/ASR/zipformer/train.py b/egs/libriheavy/ASR/zipformer/train.py
new file mode 100644
index 000000000..c97da4a11
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/train.py
@@ -0,0 +1,1415 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey,
+# Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+# For non-streaming model training:
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --full-libri 1 \
+ --max-duration 1000
+
+# For streaming model training:
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --causal 1 \
+ --full-libri 1 \
+ --max-duration 1000
+
+It supports training with:
+ - transducer loss (default), with `--use-transducer True --use-ctc False`
+ - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
+ - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
+"""
+
+
+import argparse
+import copy
+import logging
+import random
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriHeavyAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from text_normalization import remove_punc_to_upper
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
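+
+# Worked example (illustrative): with --max-duration 1000, --world-size 4 and
+# the default --ref-duration 600, each real batch counts as
+# 1000 * 4 / 600 ~= 6.7 reference batches when setting module schedules.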
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+        default=48,
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--causal",
+ type=str2bool,
+ default=False,
+ help="If True, use causal version of model.",
+ )
+
+ parser.add_argument(
+ "--chunk-size",
+ type=str,
+ default="16,32,64,-1",
+ help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+ " Must be just -1 if --causal=False",
+ )
+
+ parser.add_argument(
+ "--left-context-frames",
+ type=str,
+ default="64,128,256,-1",
+ help="Maximum left-contexts for causal training, measured in frames which will "
+ "be converted to a number of chunks. If splitting into chunks, "
+ "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+ )
+
+ parser.add_argument(
+ "--use-transducer",
+ type=str2bool,
+ default=True,
+ help="If True, use Transducer head.",
+ )
+
+ parser.add_argument(
+ "--use-ctc",
+ type=str2bool,
+ default=False,
+ help="If True, use CTC head.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+        help="Whether to log various information to tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.045, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-hours",
+ type=float,
+ default=30000,
+ help="""Number of hours that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+        help="The prune range for the rnnt loss; it is the number of symbols "
+        "(context) we use to compute the loss.",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+        help="To get pruning ranges, we will calculate a simple version of the "
+        "loss (the joiner is just an addition); this simple loss is also used "
+        "for training (as a regularization term). We scale the simple loss "
+        "with this parameter before adding it to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+        help="""Save checkpoint after processing this number of batches
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ parser.add_argument(
+ "--train-with-punctuation",
+ type=str2bool,
+ default=False,
+ help="If True, the training text will include casing and punctuation.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+                           contains the number of batches trained so far
+                           across epochs.
+
+        - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
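+    # e.g. (illustrative) "2,4,3,2,4" -> (2, 4, 3, 2, 4)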
+ return tuple(map(int, s.split(",")))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+ # encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, (T - 7) // 2, encoder_dims).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> (T - 7) // 2
+ # (2) embedding: num_features -> encoder_dims
+ # In the normal configuration, we will downsample once more at the end
+ # by a factor of 2, and most of the encoder stacks will run at a lower
+ # sampling rate.
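+    # For example (illustrative values): T = 100 input frames become
+    # (100 - 7) // 2 = 46 frames here, and roughly 23 after the final
+    # 2x output downsampling.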
+ encoder_embed = Conv2dSubsampling(
+ in_channels=params.feature_dim,
+ out_channels=_to_int_tuple(params.encoder_dim)[0],
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ )
+ return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Zipformer2(
+ output_downsampling_factor=2,
+ downsampling_factor=_to_int_tuple(params.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+ encoder_dim=_to_int_tuple(params.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(params.query_head_dim),
+ pos_head_dim=_to_int_tuple(params.pos_head_dim),
+ value_head_dim=_to_int_tuple(params.value_head_dim),
+ pos_dim=params.pos_dim,
+ num_heads=_to_int_tuple(params.num_heads),
+ feedforward_dim=_to_int_tuple(params.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ causal=params.causal,
+ chunk_size=_to_int_tuple(params.chunk_size),
+ left_context_frames=_to_int_tuple(params.left_context_frames),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
+ f"but got params.use_transducer={params.use_transducer}, "
+ f"params.use_ctc={params.use_ctc}"
+ )
+
+ encoder_embed = get_encoder_embed(params)
+ encoder = get_encoder_model(params)
+
+ if params.use_transducer:
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+ else:
+ decoder = None
+ joiner = None
+
+ model = AsrModel(
+ encoder_embed=encoder_embed,
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ use_transducer=params.use_transducer,
+ use_ctc=params.use_ctc,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, ctc_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+        # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss_scale by warm_step.
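+        # For example (illustrative), with s = 0.5 and warm_step = 2000, the
+        # scale is 1.0 - (1000 / 2000) * (1.0 - 0.5) = 0.75 at batch 1000 and
+        # stays at 0.5 once batch_idx_train >= warm_step.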
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+ # Use the number of hours of speech to adjust the learning rate
+ scheduler.step_epoch(
+ params.batch_idx_train * params.max_duration * params.world_size / 3600
+ )
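+            # For example (illustrative settings), with --max-duration 600 and
+            # --world-size 4, the argument above grows by 600 * 4 / 3600, i.e.
+            # about 0.67 "hours", per batch.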
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_hours)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ def normalize_text(c: Cut):
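+        # remove_punc_to_upper strips punctuation and upper-cases the text,
+        # e.g. (illustrative) "Hello, world!" -> "HELLO WORLD".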
+ text = remove_punc_to_upper(c.supervisions[0].text)
+ c.supervisions[0].text = text
+ return c
+
+ def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 2 seconds and 30 seconds
+        #
+        # Caution: There is a reason to select 30.0 here. Please see
+        # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 2.0 or c.duration > 30.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
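+        # For example (illustrative), a 10-second cut at 100 frames per second
+        # has num_frames = 1000, so T = ((1000 - 7) // 2 + 1) // 2 = 248.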
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ libriheavy = LibriHeavyAsrDataModule(args)
+
+ train_cuts = libriheavy.train_small_cuts()
+ if params.subset == "M" or params.subset == "L":
+ train_cuts += libriheavy.train_medium_cuts()
+ if params.subset == "L":
+ train_cuts += libriheavy.train_large_cuts()
+
+ if not params.train_with_punctuation:
+ train_cuts = train_cuts.map(normalize_text)
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = libriheavy.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = libriheavy.dev_cuts()
+
+ if not params.train_with_punctuation:
+ valid_cuts = valid_cuts.map(normalize_text)
+
+ valid_dl = libriheavy.valid_dataloaders(valid_cuts)
+
+ # if not params.print_diagnostics:
+ # scan_pessimistic_batches_for_oom(
+ # model=model,
+ # train_dl=train_dl,
+ # optimizer=optimizer,
+ # sp=sp,
+ # params=params,
+ # )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ LibriHeavyAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/libriheavy/ASR/zipformer/zipformer.py b/egs/libriheavy/ASR/zipformer/zipformer.py
new file mode 120000
index 000000000..23011dda7
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/zipformer.py
\ No newline at end of file
diff --git a/requirements-ci.txt b/requirements-ci.txt
index e1232a768..6c74f688c 100644
--- a/requirements-ci.txt
+++ b/requirements-ci.txt
@@ -17,6 +17,7 @@ six
git+https://github.com/lhotse-speech/lhotse
kaldilm==1.11
kaldialign==0.7.1
+num2words
sentencepiece==0.1.96
tensorboard==2.8.0
typeguard==2.13.3
diff --git a/requirements.txt b/requirements.txt
index 5a8326619..9502fcbd2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
kaldifst
kaldilm
kaldialign
+num2words
kaldi-decoder
sentencepiece>=0.1.96
tensorboard
From ae67f75e9c429d35e8a84d6d70cc8050eae37c86 Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Sun, 26 Nov 2023 10:04:15 +0800
Subject: [PATCH 007/123] a bilingual recipe similar to the `multi-zh_hans`
(#1265)
---
...rmer.sh => run-multi-corpora-zipformer.sh} | 38 +
...er.yml => run-multi-corpora-zipformer.yml} | 10 +-
egs/multi_zh_en/ASR/README.md | 19 +
egs/multi_zh_en/ASR/RESULTS.md | 44 +
egs/multi_zh_en/ASR/local/compile_lg.py | 1 +
egs/multi_zh_en/ASR/local/prepare_char.py | 1 +
.../ASR/local/prepare_for_bpe_model.py | 65 +
egs/multi_zh_en/ASR/local/prepare_lang.py | 1 +
.../ASR/local/prepare_lang_bbpe.py | 1 +
egs/multi_zh_en/ASR/local/prepare_lang_bpe.py | 1 +
egs/multi_zh_en/ASR/local/prepare_words.py | 1 +
egs/multi_zh_en/ASR/local/text2segments.py | 1 +
egs/multi_zh_en/ASR/local/text2token.py | 1 +
egs/multi_zh_en/ASR/local/train_bbpe_model.py | 1 +
.../ASR/local/validate_bpe_lexicon.py | 1 +
egs/multi_zh_en/ASR/prepare.sh | 149 ++
egs/multi_zh_en/ASR/shared | 1 +
.../ASR/zipformer/asr_datamodule.py | 385 +++++
egs/multi_zh_en/ASR/zipformer/beam_search.py | 1 +
egs/multi_zh_en/ASR/zipformer/decode.py | 851 ++++++++++
egs/multi_zh_en/ASR/zipformer/decoder.py | 1 +
.../ASR/zipformer/encoder_interface.py | 1 +
.../ASR/zipformer/export-onnx-streaming.py | 1 +
egs/multi_zh_en/ASR/zipformer/export-onnx.py | 1 +
egs/multi_zh_en/ASR/zipformer/export.py | 541 +++++++
.../ASR/zipformer/generate_averaged_model.py | 193 +++
.../ASR/zipformer/jit_pretrained.py | 1 +
.../ASR/zipformer/jit_pretrained_ctc.py | 1 +
.../ASR/zipformer/jit_pretrained_streaming.py | 1 +
egs/multi_zh_en/ASR/zipformer/joiner.py | 1 +
egs/multi_zh_en/ASR/zipformer/model.py | 1 +
.../ASR/zipformer/multi_dataset.py | 247 +++
egs/multi_zh_en/ASR/zipformer/onnx_check.py | 1 +
egs/multi_zh_en/ASR/zipformer/onnx_decode.py | 1 +
.../zipformer/onnx_pretrained-streaming.py | 1 +
.../ASR/zipformer/onnx_pretrained.py | 1 +
egs/multi_zh_en/ASR/zipformer/optim.py | 1 +
egs/multi_zh_en/ASR/zipformer/pretrained.py | 378 +++++
egs/multi_zh_en/ASR/zipformer/scaling.py | 1 +
.../ASR/zipformer/scaling_converter.py | 1 +
.../ASR/zipformer/streaming_beam_search.py | 1 +
.../ASR/zipformer/streaming_decode.py | 1 +
egs/multi_zh_en/ASR/zipformer/subsampling.py | 1 +
egs/multi_zh_en/ASR/zipformer/train.py | 1416 +++++++++++++++++
egs/multi_zh_en/ASR/zipformer/zipformer.py | 1 +
45 files changed, 4363 insertions(+), 5 deletions(-)
rename .github/scripts/{run-multi-zh_hans-zipformer.sh => run-multi-corpora-zipformer.sh} (66%)
rename .github/workflows/{run-multi-zh_hans-zipformer.yml => run-multi-corpora-zipformer.yml} (91%)
create mode 100644 egs/multi_zh_en/ASR/README.md
create mode 100644 egs/multi_zh_en/ASR/RESULTS.md
create mode 120000 egs/multi_zh_en/ASR/local/compile_lg.py
create mode 120000 egs/multi_zh_en/ASR/local/prepare_char.py
create mode 100755 egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py
create mode 120000 egs/multi_zh_en/ASR/local/prepare_lang.py
create mode 120000 egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py
create mode 120000 egs/multi_zh_en/ASR/local/prepare_lang_bpe.py
create mode 120000 egs/multi_zh_en/ASR/local/prepare_words.py
create mode 120000 egs/multi_zh_en/ASR/local/text2segments.py
create mode 120000 egs/multi_zh_en/ASR/local/text2token.py
create mode 120000 egs/multi_zh_en/ASR/local/train_bbpe_model.py
create mode 120000 egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py
create mode 100755 egs/multi_zh_en/ASR/prepare.sh
create mode 120000 egs/multi_zh_en/ASR/shared
create mode 100644 egs/multi_zh_en/ASR/zipformer/asr_datamodule.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/beam_search.py
create mode 100755 egs/multi_zh_en/ASR/zipformer/decode.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/decoder.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/encoder_interface.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/export-onnx.py
create mode 100755 egs/multi_zh_en/ASR/zipformer/export.py
create mode 100755 egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/jit_pretrained.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/joiner.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/model.py
create mode 100644 egs/multi_zh_en/ASR/zipformer/multi_dataset.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/onnx_check.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/onnx_decode.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/optim.py
create mode 100755 egs/multi_zh_en/ASR/zipformer/pretrained.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/scaling.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/scaling_converter.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/streaming_decode.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/subsampling.py
create mode 100755 egs/multi_zh_en/ASR/zipformer/train.py
create mode 120000 egs/multi_zh_en/ASR/zipformer/zipformer.py
diff --git a/.github/scripts/run-multi-zh_hans-zipformer.sh b/.github/scripts/run-multi-corpora-zipformer.sh
similarity index 66%
rename from .github/scripts/run-multi-zh_hans-zipformer.sh
rename to .github/scripts/run-multi-corpora-zipformer.sh
index cbd86a4d3..90f859f43 100755
--- a/.github/scripts/run-multi-zh_hans-zipformer.sh
+++ b/.github/scripts/run-multi-corpora-zipformer.sh
@@ -95,3 +95,41 @@ for method in modified_beam_search fast_beam_search; do
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
done
+
+rm -rf $repo
+
+cd ../../../egs/multi_zh_en/ASR
+log "==== Test icefall-asr-zipformer-multi-zh-en-2023-11-22 ===="
+repo_url=https://huggingface.co/zrjin/icefall-asr-zipformer-multi-zh-en-2023-11-22/
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+ls -lh $repo/test_wavs/*.wav
+
+./zipformer/pretrained.py \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bbpe_2000/bbpe.model \
+ --method greedy_search \
+  $repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_29.wav \
+  $repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_55.wav \
+  $repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_75.wav
+
+for method in modified_beam_search fast_beam_search; do
+ log "$method"
+
+ ./zipformer/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bbpe_2000/bbpe.model \
+ $repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_29.wav \
+ $repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_55.wav \
+ $repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_75.wav
+done
+
+rm -rf $repo
diff --git a/.github/workflows/run-multi-zh_hans-zipformer.yml b/.github/workflows/run-multi-corpora-zipformer.yml
similarity index 91%
rename from .github/workflows/run-multi-zh_hans-zipformer.yml
rename to .github/workflows/run-multi-corpora-zipformer.yml
index 72c0775a7..38f7eb908 100644
--- a/.github/workflows/run-multi-zh_hans-zipformer.yml
+++ b/.github/workflows/run-multi-corpora-zipformer.yml
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-name: run-multi-zh_hans-zipformer
+name: run-multi-corpora-zipformer
on:
push:
@@ -24,12 +24,12 @@ on:
types: [labeled]
concurrency:
- group: run_multi-zh_hans_zipformer-${{ github.ref }}
+ group: run_multi-corpora_zipformer-${{ github.ref }}
cancel-in-progress: true
jobs:
- run_multi-zh_hans_zipformer:
- if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans' || github.event.label.name == 'zipformer'
+ run_multi-corpora_zipformer:
+ if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans' || github.event.label.name == 'zipformer' || github.event.label.name == 'multi-corpora'
runs-on: ${{ matrix.os }}
strategy:
matrix:
@@ -81,4 +81,4 @@ jobs:
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
- .github/scripts/run-multi-zh_hans-zipformer.sh
+ .github/scripts/run-multi-corpora-zipformer.sh
diff --git a/egs/multi_zh_en/ASR/README.md b/egs/multi_zh_en/ASR/README.md
new file mode 100644
index 000000000..29341571d
--- /dev/null
+++ b/egs/multi_zh_en/ASR/README.md
@@ -0,0 +1,19 @@
+# Introduction
+
+This recipe includes scripts for training a Zipformer model on both English and Chinese datasets.
+
+# Included Training Sets
+
+1. LibriSpeech (English)
+2. AiShell-2 (Chinese)
+3. TAL-CSASR (Code-Switching, Chinese and English)
+
+|Dataset| Number of hours| URL|
+|---|---:|---|
+|**TOTAL**|2,547|---|
+|LibriSpeech|960|https://www.openslr.org/12/|
+|AiShell-2|1,000|http://www.aishelltech.com/aishell_2|
+|TAL-CSASR|587|https://ai.100tal.com/openData/voice|
+
+
+
diff --git a/egs/multi_zh_en/ASR/RESULTS.md b/egs/multi_zh_en/ASR/RESULTS.md
new file mode 100644
index 000000000..3562d6ac3
--- /dev/null
+++ b/egs/multi_zh_en/ASR/RESULTS.md
@@ -0,0 +1,44 @@
+## Results
+
+### Zh-En datasets BPE-based training results (non-streaming) on Zipformer model
+
+This is the [pull request #1265](https://github.com/k2-fsa/icefall/pull/1265) in icefall.
+
+#### Non-streaming (Byte-Level BPE vocab_size=2000)
+
+Best results (number of parameters: ~69M):
+
+The training command:
+
+```
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 35 \
+ --use-fp16 1 \
+ --max-duration 1000 \
+ --num-workers 8
+```
+
+The decoding command:
+
+```
+for method in greedy_search modified_beam_search fast_beam_search; do
+ ./zipformer/decode.py \
+ --epoch 34 \
+ --avg 19 \
+ --decoding-method $method
+done
+```
+
+The word error rates (WERs) listed below are produced with the decoding setup above (epoch 34, avg 19) and the byte-level BPE model (vocab size 2000).
+
+| Datasets | TAL-CSASR | TAL-CSASR | AiShell-2 | AiShell-2 | LibriSpeech | LibriSpeech |
+|----------------------|-----------|-----------|-----------|-----------|-------------|-------------|
+| Zipformer WER (%) | dev | test | dev | test | test-clean | test-other |
+| greedy_search | 6.65 | 6.69 | 6.57 | 7.03 | 2.43 | 5.70 |
+| modified_beam_search | 6.46 | 6.51 | 6.18 | 6.60 | 2.41 | 5.57 |
+| fast_beam_search | 6.57 | 6.68 | 6.40 | 6.74 | 2.40 | 5.56 |
+
+The pre-trained model can be found at https://huggingface.co/zrjin/icefall-asr-zipformer-multi-zh-en-2023-11-22. It is trained on the LibriSpeech 960-hour training set (with speed perturbation), the TAL-CSASR training set (with speed perturbation), and AiShell-2 (without speed perturbation).
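+
+To fetch it locally (a minimal sketch; it assumes `git` and `git-lfs` are installed):
+
+```
+git lfs install
+git clone https://huggingface.co/zrjin/icefall-asr-zipformer-multi-zh-en-2023-11-22
+```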
+
+
diff --git a/egs/multi_zh_en/ASR/local/compile_lg.py b/egs/multi_zh_en/ASR/local/compile_lg.py
new file mode 120000
index 000000000..462d6d3fb
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/compile_lg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_lg.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/local/prepare_char.py b/egs/multi_zh_en/ASR/local/prepare_char.py
new file mode 120000
index 000000000..42743b544
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/prepare_char.py
@@ -0,0 +1 @@
+../../../aishell/ASR/local/prepare_char.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py b/egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py
new file mode 100755
index 000000000..00514e6bb
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script tokenizes the training transcript by CJK characters
+# and saves the result to transcript_chars.txt, which is used
+# to train the BPE model later.
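+#
+# For example (illustrative), the input line "你好 hello world" would be
+# written out as "你 好 hello world" by tokenize_by_CJK_char.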
+
+import argparse
+from pathlib import Path
+
+from tqdm.auto import tqdm
+
+from icefall.utils import tokenize_by_CJK_char
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Output directory.
+ The generated transcript_chars.txt is saved to this directory.
+ """,
+ )
+
+ parser.add_argument(
+ "--text",
+ type=str,
+ help="Training transcript.",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+ lang_dir = Path(args.lang_dir)
+ text = Path(args.text)
+
+ assert lang_dir.exists() and text.exists(), f"{lang_dir} or {text} does not exist!"
+
+ transcript_path = lang_dir / "transcript_chars.txt"
+
+ with open(text, "r", encoding="utf-8") as fin:
+ with open(transcript_path, "w+", encoding="utf-8") as fout:
+ for line in tqdm(fin):
+ fout.write(tokenize_by_CJK_char(line) + "\n")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/multi_zh_en/ASR/local/prepare_lang.py b/egs/multi_zh_en/ASR/local/prepare_lang.py
new file mode 120000
index 000000000..747f2ab39
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/prepare_lang.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py b/egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py
new file mode 120000
index 000000000..9a0b44642
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py
@@ -0,0 +1 @@
+../../../aishell/ASR/local/prepare_lang_bbpe.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/local/prepare_lang_bpe.py b/egs/multi_zh_en/ASR/local/prepare_lang_bpe.py
new file mode 120000
index 000000000..36b40e7fc
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/prepare_lang_bpe.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_bpe.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/local/prepare_words.py b/egs/multi_zh_en/ASR/local/prepare_words.py
new file mode 120000
index 000000000..ef2b4eaf3
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/prepare_words.py
@@ -0,0 +1 @@
+../../../aishell2/ASR/local/prepare_words.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/local/text2segments.py b/egs/multi_zh_en/ASR/local/text2segments.py
new file mode 120000
index 000000000..7d68a39c3
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/text2segments.py
@@ -0,0 +1 @@
+../../../wenetspeech/ASR/local/text2segments.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/local/text2token.py b/egs/multi_zh_en/ASR/local/text2token.py
new file mode 120000
index 000000000..ce5cfd537
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/text2token.py
@@ -0,0 +1 @@
+../../../wenetspeech/ASR/local/text2token.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/local/train_bbpe_model.py b/egs/multi_zh_en/ASR/local/train_bbpe_model.py
new file mode 120000
index 000000000..7fb4a9f9d
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/train_bbpe_model.py
@@ -0,0 +1 @@
+../../../aishell/ASR/local/train_bbpe_model.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py b/egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py
new file mode 120000
index 000000000..721bb48e7
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/validate_bpe_lexicon.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/prepare.sh b/egs/multi_zh_en/ASR/prepare.sh
new file mode 100755
index 000000000..9f2be5a5c
--- /dev/null
+++ b/egs/multi_zh_en/ASR/prepare.sh
@@ -0,0 +1,149 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
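+
+# Example invocation (a sketch; stage numbers refer to the stages below):
+#   ./prepare.sh --stage 4 --stop-stage 4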
+
+vocab_sizes=(
+ 2000
+)
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+log "Dataset: musan"
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Soft link fbank of musan"
+ mkdir -p data/fbank
+ if [ -e ../../librispeech/ASR/data/fbank/.musan.done ]; then
+ cd data/fbank
+ ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_feats) .
+ ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_cuts.jsonl.gz) .
+ cd ../..
+ else
+ log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 4 --stop-stage 4"
+ exit 1
+ fi
+fi
+
+log "Dataset: LibriSpeech"
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Soft link fbank of LibriSpeech"
+ mkdir -p data/fbank
+ if [ -e ../../librispeech/ASR/data/fbank/.librispeech.done ]; then
+ cd data/fbank
+ ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts*) .
+ ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats*) .
+ cd ../..
+ else
+ log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 3 --stop-stage 3"
+ exit 1
+ fi
+fi
+
+log "Dataset: AiShell-2"
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Soft link fbank of AiShell-2"
+ mkdir -p data/fbank
+ if [ -e ../../aishell2/ASR/data/fbank/.aishell2.done ]; then
+ cd data/fbank
+ ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts*) .
+ ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_feats*) .
+ cd ../..
+ else
+ log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3"
+ exit 1
+ fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Prepare Byte BPE based lang"
+ mkdir -p data/fbank
+ if [ ! -d ../../aishell2/ASR/data/lang_char ] && [ ! -d ./data/lang_char ]; then
+ log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3"
+ exit 1
+ fi
+
+ if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ] && [ ! -d ./data/lang_bpe_500 ]; then
+ log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 6 --stop-stage 6"
+ exit 1
+ fi
+
+ cd data/
+ if [ ! -d ./lang_char ]; then
+ ln -svf $(realpath ../../../aishell2/ASR/data/lang_char) .
+ fi
+ if [ ! -d ./lang_bpe_500 ]; then
+ ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) .
+ fi
+ cd ../
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bbpe_${vocab_size}
+ mkdir -p $lang_dir
+
+ cat data/lang_char/text data/lang_bpe_500/transcript_words.txt \
+ > $lang_dir/text
+
+ if [ ! -f $lang_dir/transcript_chars.txt ]; then
+ ./local/prepare_for_bpe_model.py \
+ --lang-dir ./$lang_dir \
+ --text $lang_dir/text
+ fi
+
+ if [ ! -f $lang_dir/text_words_segmentation ]; then
+ python3 ./local/text2segments.py \
+ --input-file ./data/lang_char/text \
+ --output-file $lang_dir/text_words_segmentation
+
+ cat ./data/lang_bpe_500/transcript_words.txt \
+ >> $lang_dir/text_words_segmentation
+
+ cat ./data/lang_char/text \
+ >> $lang_dir/text
+ fi
+
+ cat $lang_dir/text_words_segmentation | sed 's/ /\n/g' \
+ | sort -u | sed '/^$/d' | uniq > $lang_dir/words_no_ids.txt
+
+ if [ ! -f $lang_dir/words.txt ]; then
+ python3 ./local/prepare_words.py \
+ --input-file $lang_dir/words_no_ids.txt \
+ --output-file $lang_dir/words.txt
+ fi
+
+ if [ ! -f $lang_dir/bbpe.model ]; then
+ ./local/train_bbpe_model.py \
+ --lang-dir $lang_dir \
+ --vocab-size $vocab_size \
+ --transcript $lang_dir/text
+ fi
+
+ if [ ! -f $lang_dir/L_disambig.pt ]; then
+ ./local/prepare_lang_bbpe.py --lang-dir $lang_dir
+
+ log "Validating $lang_dir/lexicon.txt"
+ ./local/validate_bpe_lexicon.py \
+ --lexicon $lang_dir/lexicon.txt \
+ --bpe-model $lang_dir/bbpe.model
+ fi
+ done
+fi
+
diff --git a/egs/multi_zh_en/ASR/shared b/egs/multi_zh_en/ASR/shared
new file mode 120000
index 000000000..4cbd91a7e
--- /dev/null
+++ b/egs/multi_zh_en/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py
new file mode 100644
index 000000000..be6e94472
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py
@@ -0,0 +1,385 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest
+from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ SimpleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
+ OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class AsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+
+ This class should be derived for specific corpora used in ASR tasks.
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=300.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+            help="The number of buckets for the DynamicBucketingSampler "
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+            help="When enabled, select noise from MUSAN and mix it "
+            "with the training dataset.",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ logging.info("About to get Musan cuts")
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+ transforms.append(
+ CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # In different Lhotse's versions, the default of num_frame_masks is
+ # different.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=num_frame_masks,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ train = K2SpeechRecognitionDataset(
+ input_strategy=eval(self.args.input_strategy)(),
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
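+        # _SeedWorkers is expected to re-seed each dataloader worker based on
+        # this value, so workers use different but reproducible randomness for
+        # augmentation.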
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=True,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
diff --git a/egs/multi_zh_en/ASR/zipformer/beam_search.py b/egs/multi_zh_en/ASR/zipformer/beam_search.py
new file mode 120000
index 000000000..8e2c0a65c
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/beam_search.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/decode.py b/egs/multi_zh_en/ASR/zipformer/decode.py
new file mode 100755
index 000000000..e21e8f052
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/decode.py
@@ -0,0 +1,851 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import AsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from lhotse.cut import Cut
+from multi_dataset import MultiDataset
+from train import add_model_arguments, get_model, get_params
+
+from icefall import byte_encode, smart_byte_decode, tokenize_by_CJK_char
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+        help="Whether to load the averaged model. Currently it only supports "
+        "using --epoch. If True, it will decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
+        "Actually only the models with epoch numbers `epoch-avg` and "
+        "`epoch` are loaded for averaging.",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bbpe_2000/bbpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bbpe_2000",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding_method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--use-tal-csasr",
+ type=str2bool,
+ default=False,
+ help="Whether to use TAL-CSASR training data.",
+ )
+
+ parser.add_argument(
+ "--use-librispeech",
+ type=str2bool,
+ default=False,
+ help="Whether to use LibriSpeech training data.",
+ )
+
+ parser.add_argument(
+ "--use-aishell2",
+ type=str2bool,
+ default=False,
+ help="Whether to use Aishell-2 training data.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+ only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
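+        # Appending ~30 frames of near-silence (LOG_EPS features) gives the
+        # causal model trailing context so it can emit symbols for the last
+        # real frames.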
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(
+ byte_encode(tokenize_by_CJK_char(supervisions["text"]))
+ ),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(smart_byte_decode(sp.decode(hyp)).split())
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+ only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ texts = [tokenize_by_CJK_char(str(text)).split() for text in texts]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ word_table=word_table,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ this_batch.append((cut_id, ref_text, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ AsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
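+            # A trivial graph places no constraint on the token sequence;
+            # vocab_size - 1 is the largest token ID in the vocabulary.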
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ data_module = AsrDataModule(args)
+ multi_dataset = MultiDataset(args)
+
+ def remove_short_utt(c: Cut):
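+        # T estimates the number of encoder output frames after subsampling
+        # ((num_frames - 7) // 2 in encoder_embed, then roughly halved again);
+        # cuts too short to produce any output frame are dropped.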
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ if T <= 0:
+ logging.warning(
+ f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}"
+ )
+ return T > 0
+
+ test_sets_cuts = multi_dataset.test_cuts()
+
+ test_sets = test_sets_cuts.keys()
+    test_dls = [
+ data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
+ for cuts_name in test_sets
+ ]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
+ logging.info(f"Start decoding test set: {test_set}")
+
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/multi_zh_en/ASR/zipformer/decoder.py b/egs/multi_zh_en/ASR/zipformer/decoder.py
new file mode 120000
index 000000000..5a8018680
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/decoder.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/encoder_interface.py b/egs/multi_zh_en/ASR/zipformer/encoder_interface.py
new file mode 120000
index 000000000..c2eaca671
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/encoder_interface.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py b/egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py
new file mode 120000
index 000000000..2962eb784
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx-streaming.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/export-onnx.py b/egs/multi_zh_en/ASR/zipformer/export-onnx.py
new file mode 120000
index 000000000..70a15683c
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/export-onnx.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/export.py b/egs/multi_zh_en/ASR/zipformer/export.py
new file mode 100755
index 000000000..fbd9ce0dd
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/export.py
@@ -0,0 +1,541 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+
+Usage:
+
+Note: This is an example for the LibriSpeech dataset; if you are using a
+different dataset, change the argument values accordingly.
+
+(1) Export to torchscript model using torch.jit.script()
+
+- For non-streaming model:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp \
+ --tokens data/lang_bbpe_2000/tokens.txt \
+ --epoch 20 \
+ --avg 1 \
+ --jit 1
+
+It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("jit_script.pt")`.
+
+Check ./jit_pretrained.py for its usage.
+
+Check https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+- For streaming model:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --tokens data/lang_bbpe_2000/tokens.txt \
+ --epoch 20 \
+ --avg 1 \
+ --jit 1
+
+It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`.
+You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`.
+
+Check ./jit_pretrained_streaming.py for its usage.
+
+Check https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+(2) Export `model.state_dict()`
+
+- For non-streaming model:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp \
+ --tokens data/lang_bbpe_2000/tokens.txt \
+ --epoch 20 \
+ --avg 1
+
+- For streaming model:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp \
+ --causal 1 \
+ --tokens data/lang_bbpe_2000/tokens.txt \
+ --epoch 20 \
+ --avg 1
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
+- For non-streaming model:
+
+To use the generated file with `zipformer/decode.py`,
+you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/librispeech/ASR
+ ./zipformer/decode.py \
+ --exp-dir ./zipformer/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --decoding-method greedy_search \
+ --bpe-model data/lang_bbpe_2000/bpe.model
+
+- For streaming model:
+
+To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/librispeech/ASR
+
+ # simulated streaming decoding
+ ./zipformer/decode.py \
+ --exp-dir ./zipformer/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --decoding-method greedy_search \
+ --bpe-model data/lang_bbpe_2000/bpe.model
+
+ # chunk-wise streaming decoding
+ ./zipformer/streaming_decode.py \
+ --exp-dir ./zipformer/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --decoding-method greedy_search \
+ --bpe-model data/lang_bbpe_2000/bpe.model
+
+Check ./pretrained.py for its usage.
+
+Note: If you don't want to train a model from scratch, we have
+provided one for you. You can get it at
+
+- non-streaming model:
+https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/
+
+with the following commands:
+
+ sudo apt-get install git-lfs
+ git lfs install
+ git clone https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/
+ # You will find the pre-trained models in exp dir
+"""
+
+import argparse
+import logging
+import re
+from pathlib import Path
+from typing import List, Tuple
+
+import k2
+import torch
+from scaling_converter import convert_scaled_to_non_scaled
+from torch import Tensor, nn
+from train import add_model_arguments, get_model, get_params
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import make_pad_mask, str2bool
+
+
+def num_tokens(
+ token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
+) -> int:
+ """Return the number of tokens excluding those from
+ disambiguation symbols.
+
+ Caution:
+ 0 is not a token ID so it is excluded from the return value.
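+
+    Example (hypothetical token table): for entries `<blk> 0`, `a 1`, `b 2`
+    and the disambiguation symbol `#0 3`, this function returns 2.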
+ """
+ symbols = token_table.symbols
+ ans = []
+ for s in symbols:
+ if not disambig_pattern.match(s):
+ ans.append(token_table[s])
+ num_tokens = len(ans)
+ if 0 in ans:
+ num_tokens -= 1
+ return num_tokens
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=20,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=1,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+        help="Whether to load the averaged model. Currently it only supports "
+        "using --epoch. If True, it will decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
+        "Actually only the models with epoch numbers `epoch-avg` and "
+        "`epoch` are loaded for averaging.",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/lang_bbpe_2000/tokens.txt",
+ help="Path to the tokens.txt",
+ )
+
+ parser.add_argument(
+ "--jit",
+ type=str2bool,
+ default=False,
+ help="""True to save a model after applying torch.jit.script.
+ It will generate a file named jit_script.pt.
+ Check ./jit_pretrained.py for how to use it.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+class EncoderModel(nn.Module):
+ """A wrapper for encoder and encoder_embed"""
+
+ def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
+ super().__init__()
+ self.encoder = encoder
+ self.encoder_embed = encoder_embed
+
+ def forward(
+ self, features: Tensor, feature_lengths: Tensor
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ features: (N, T, C)
+ feature_lengths: (N,)
+ """
+ x, x_lens = self.encoder_embed(features, feature_lengths)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+
+ encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ return encoder_out, encoder_out_lens
+
+
+class StreamingEncoderModel(nn.Module):
+ """A wrapper for encoder and encoder_embed"""
+
+ def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
+ super().__init__()
+ assert len(encoder.chunk_size) == 1, encoder.chunk_size
+ assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
+ self.chunk_size = encoder.chunk_size[0]
+ self.left_context_len = encoder.left_context_frames[0]
+
+        # The encoder_embed subsamples the input from T frames to (T - 7) // 2 frames.
+ # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
+ self.pad_length = 7 + 2 * 3
+
+ self.encoder = encoder
+ self.encoder_embed = encoder_embed
+
+ def forward(
+ self, features: Tensor, feature_lengths: Tensor, states: List[Tensor]
+ ) -> Tuple[Tensor, Tensor, List[Tensor]]:
+ """Streaming forward for encoder_embed and encoder.
+
+ Args:
+ features: (N, T, C)
+ feature_lengths: (N,)
+ states: a list of Tensors
+
+ Returns encoder outputs, output lengths, and updated states.
+ """
+ chunk_size = self.chunk_size
+ left_context_len = self.left_context_len
+
+ cached_embed_left_pad = states[-2]
+ x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
+ x=features,
+ x_lens=feature_lengths,
+ cached_left_pad=cached_embed_left_pad,
+ )
+ assert x.size(1) == chunk_size, (x.size(1), chunk_size)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+
+ # processed_mask is used to mask out initial states
+ processed_mask = torch.arange(left_context_len, device=x.device).expand(
+ x.size(0), left_context_len
+ )
+ processed_lens = states[-1] # (batch,)
+ # (batch, left_context_size)
+ processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
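+        # After the flip, left-context positions that have not yet been filled
+        # with processed frames are marked True (treated as padding).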
+ # Update processed lengths
+ new_processed_lens = processed_lens + x_lens
+
+ # (batch, left_context_size + chunk_size)
+ src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
+
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+ encoder_states = states[:-2]
+
+ (
+ encoder_out,
+ encoder_out_lens,
+ new_encoder_states,
+ ) = self.encoder.streaming_forward(
+ x=x,
+ x_lens=x_lens,
+ states=encoder_states,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ new_states = new_encoder_states + [
+ new_cached_embed_left_pad,
+ new_processed_lens,
+ ]
+ return encoder_out, encoder_out_lens, new_states
+
+ @torch.jit.export
+ def get_init_states(
+ self,
+ batch_size: int = 1,
+ device: torch.device = torch.device("cpu"),
+ ) -> List[torch.Tensor]:
+ """
+ Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+ is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ states[-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+        states[-1] is processed_lens of shape (batch,), which records the number
+        of processed frames (at a 50 Hz frame rate, after encoder_embed) for each sample in the batch.
+ """
+ states = self.encoder.get_init_states(batch_size, device)
+
+ embed_states = self.encoder_embed.get_init_states(batch_size, device)
+ states.append(embed_states)
+
+ processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
+ states.append(processed_lens)
+
+ return states
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ # if torch.cuda.is_available():
+ # device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
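+    # +1 accounts for the blank token (ID 0), which num_tokens() excludes.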
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.eval()
+
+ if params.jit is True:
+ convert_scaled_to_non_scaled(model, inplace=True)
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+
+ # Wrap encoder and encoder_embed as a module
+ if params.causal:
+ model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
+ chunk_size = model.encoder.chunk_size
+ left_context_len = model.encoder.left_context_len
+ filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
+ else:
+ model.encoder = EncoderModel(model.encoder, model.encoder_embed)
+ filename = "jit_script.pt"
+
+ logging.info("Using torch.jit.script")
+ model = torch.jit.script(model)
+ model.save(str(params.exp_dir / filename))
+ logging.info(f"Saved to {filename}")
+ else:
+ logging.info("Not using torchscript. Export model.state_dict()")
+ # Save it using a format so that it can be loaded
+ # by :func:`load_checkpoint`
+ filename = params.exp_dir / "pretrained.pt"
+ torch.save({"model": model.state_dict()}, str(filename))
+ logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py b/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py
new file mode 100755
index 000000000..68111fad7
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py
@@ -0,0 +1,193 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) use the checkpoint exp_dir/epoch-xxx.pt
+./zipformer/generate_averaged_model.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp
+
+It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
+You can later load it by `torch.load("epoch-28-avg-15.pt")`.
+
+(2) use the checkpoint exp_dir/checkpoint-iter.pt
+./zipformer/generate_averaged_model.py \
+ --iter 22000 \
+ --avg 5 \
+ --exp-dir ./zipformer/exp
+
+It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
+You can later load it by `torch.load("iter-22000-avg-5.pt")`.
+"""
+
+
+import argparse
+from pathlib import Path
+
+import k2
+import torch
+from train import add_model_arguments, get_model, get_params
+
+from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=9,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/lang_bpe_500/tokens.txt",
+ help="Path to the tokens.txt",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ print("Script started")
+
+ device = torch.device("cpu")
+ print(f"Device: {device}")
+
+ symbol_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = symbol_table["<blk>"]
+    params.unk_id = symbol_table["<unk>"]
+ params.vocab_size = len(symbol_table)
+
+ print("About to create model")
+ model = get_model(params)
+
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ print(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
+ torch.save({"model": model.state_dict()}, filename)
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ print(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
+ torch.save({"model": model.state_dict()}, filename)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ print(f"Number of model parameters: {num_param}")
+
+ print("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/multi_zh_en/ASR/zipformer/jit_pretrained.py b/egs/multi_zh_en/ASR/zipformer/jit_pretrained.py
new file mode 120000
index 000000000..25108391f
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/jit_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/jit_pretrained.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py
new file mode 120000
index 000000000..9a8da5844
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py
new file mode 120000
index 000000000..1962351e9
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/joiner.py b/egs/multi_zh_en/ASR/zipformer/joiner.py
new file mode 120000
index 000000000..5b8a36332
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/joiner.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/model.py b/egs/multi_zh_en/ASR/zipformer/model.py
new file mode 120000
index 000000000..cd7e07d72
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/model.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/multi_dataset.py b/egs/multi_zh_en/ASR/zipformer/multi_dataset.py
new file mode 100644
index 000000000..1155a3dcc
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/multi_dataset.py
@@ -0,0 +1,247 @@
+# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Dict
+
+from lhotse import CutSet, load_manifest_lazy
+
+
+class MultiDataset:
+ def __init__(self, args: argparse.Namespace):
+ """
+ Args:
+          args:
+            Parsed command-line arguments. `args.manifest_dir` is expected to
+            contain manifests such as:
+            - aishell2_cuts_train.jsonl.gz
+ """
+ self.fbank_dir = Path(args.manifest_dir)
+ self.use_tal_csasr = args.use_tal_csasr
+ self.use_librispeech = args.use_librispeech
+ self.use_aishell2 = args.use_aishell2
+
+ def train_cuts(self) -> CutSet:
+ logging.info("About to get multidataset train cuts")
+
+ # AISHELL-2
+ if self.use_aishell2:
+ logging.info("Loading Aishell-2 in lazy mode")
+ aishell_2_cuts = load_manifest_lazy(
+ self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
+ )
+
+ # TAL-CSASR
+ if self.use_tal_csasr:
+ logging.info("Loading TAL-CSASR in lazy mode")
+ tal_csasr_cuts = load_manifest_lazy(
+ self.fbank_dir / "tal_csasr_cuts_train_set.jsonl.gz"
+ )
+
+ # LibriSpeech
+ if self.use_librispeech:
+ logging.info("Loading LibriSpeech in lazy mode")
+ train_clean_100_cuts = self.train_clean_100_cuts()
+ train_clean_360_cuts = self.train_clean_360_cuts()
+ train_other_500_cuts = self.train_other_500_cuts()
+
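+        # Mux the enabled corpora with weights proportional to their sizes so
+        # training batches roughly follow the natural data distribution.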
+ if self.use_tal_csasr and self.use_librispeech and self.use_aishell2:
+ return CutSet.mux(
+ aishell_2_cuts,
+ train_clean_100_cuts,
+ train_clean_360_cuts,
+ train_other_500_cuts,
+ tal_csasr_cuts,
+ weights=[
+ len(aishell_2_cuts),
+ len(train_clean_100_cuts),
+ len(train_clean_360_cuts),
+ len(train_other_500_cuts),
+ len(tal_csasr_cuts),
+ ],
+ )
+ elif not self.use_tal_csasr and self.use_librispeech and self.use_aishell2:
+ return CutSet.mux(
+ aishell_2_cuts,
+ train_clean_100_cuts,
+ train_clean_360_cuts,
+ train_other_500_cuts,
+ weights=[
+ len(aishell_2_cuts),
+ len(train_clean_100_cuts),
+ len(train_clean_360_cuts),
+ len(train_other_500_cuts),
+ ],
+ )
+ elif self.use_tal_csasr and not self.use_librispeech and self.use_aishell2:
+ return CutSet.mux(
+ aishell_2_cuts,
+ tal_csasr_cuts,
+ weights=[
+ len(aishell_2_cuts),
+ len(tal_csasr_cuts),
+ ],
+ )
+ elif self.use_tal_csasr and self.use_librispeech and not self.use_aishell2:
+ return CutSet.mux(
+ train_clean_100_cuts,
+ train_clean_360_cuts,
+ train_other_500_cuts,
+ tal_csasr_cuts,
+ weights=[
+ len(train_clean_100_cuts),
+ len(train_clean_360_cuts),
+ len(train_other_500_cuts),
+ len(tal_csasr_cuts),
+ ],
+ )
+ else:
+ raise NotImplementedError(
+ f"""Not implemented for
+ use_aishell2: {self.use_aishell2}
+ use_librispeech: {self.use_librispeech}
+ use_tal_csasr: {self.use_tal_csasr}"""
+ )
+
+ def dev_cuts(self) -> CutSet:
+ logging.info("About to get multidataset dev cuts")
+
+ # AISHELL-2
+ logging.info("Loading Aishell-2 DEV set in lazy mode")
+ aishell2_dev_cuts = load_manifest_lazy(
+ self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
+ )
+
+ # LibriSpeech
+ dev_clean_cuts = self.dev_clean_cuts()
+ dev_other_cuts = self.dev_other_cuts()
+
+ logging.info("Loading TAL-CSASR set in lazy mode")
+ tal_csasr_dev_cuts = load_manifest_lazy(
+ self.fbank_dir / "tal_csasr_cuts_dev_set.jsonl.gz"
+ )
+
+ return CutSet.mux(
+ aishell2_dev_cuts,
+ dev_clean_cuts,
+ dev_other_cuts,
+ tal_csasr_dev_cuts,
+ weights=[
+ len(aishell2_dev_cuts),
+ len(dev_clean_cuts),
+ len(dev_other_cuts),
+ len(tal_csasr_dev_cuts),
+ ],
+ )
+
+ def test_cuts(self) -> Dict[str, CutSet]:
+ logging.info("About to get multidataset test cuts")
+
+ # AISHELL-2
+ if self.use_aishell2:
+ logging.info("Loading Aishell-2 set in lazy mode")
+ aishell2_test_cuts = load_manifest_lazy(
+ self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
+ )
+ aishell2_dev_cuts = load_manifest_lazy(
+ self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
+ )
+
+ # LibriSpeech
+ if self.use_librispeech:
+ test_clean_cuts = self.test_clean_cuts()
+ test_other_cuts = self.test_other_cuts()
+
+ logging.info("Loading TAL-CSASR set in lazy mode")
+ tal_csasr_test_cuts = load_manifest_lazy(
+ self.fbank_dir / "tal_csasr_cuts_test_set.jsonl.gz"
+ )
+ tal_csasr_dev_cuts = load_manifest_lazy(
+ self.fbank_dir / "tal_csasr_cuts_dev_set.jsonl.gz"
+ )
+
+ test_cuts = {
+ "tal_csasr_test": tal_csasr_test_cuts,
+ "tal_csasr_dev": tal_csasr_dev_cuts,
+ }
+
+ if self.use_aishell2:
+ test_cuts.update(
+ {
+ "aishell-2_test": aishell2_test_cuts,
+ "aishell-2_dev": aishell2_dev_cuts,
+ }
+ )
+ if self.use_librispeech:
+ test_cuts.update(
+ {
+ "librispeech_test_clean": test_clean_cuts,
+ "librispeech_test_other": test_other_cuts,
+ }
+ )
+ return test_cuts
+
+ @lru_cache()
+ def train_clean_100_cuts(self) -> CutSet:
+ logging.info("About to get train-clean-100 cuts")
+ return load_manifest_lazy(
+ self.fbank_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_clean_360_cuts(self) -> CutSet:
+ logging.info("About to get train-clean-360 cuts")
+ return load_manifest_lazy(
+ self.fbank_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_other_500_cuts(self) -> CutSet:
+ logging.info("About to get train-other-500 cuts")
+ return load_manifest_lazy(
+ self.fbank_dir / "librispeech_cuts_train-other-500.jsonl.gz"
+ )
+
+ @lru_cache()
+ def dev_clean_cuts(self) -> CutSet:
+ logging.info("About to get dev-clean cuts")
+ return load_manifest_lazy(
+ self.fbank_dir / "librispeech_cuts_dev-clean.jsonl.gz"
+ )
+
+ @lru_cache()
+ def dev_other_cuts(self) -> CutSet:
+ logging.info("About to get dev-other cuts")
+ return load_manifest_lazy(
+ self.fbank_dir / "librispeech_cuts_dev-other.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_clean_cuts(self) -> CutSet:
+ logging.info("About to get test-clean cuts")
+ return load_manifest_lazy(
+ self.fbank_dir / "librispeech_cuts_test-clean.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_other_cuts(self) -> CutSet:
+ logging.info("About to get test-other cuts")
+ return load_manifest_lazy(
+ self.fbank_dir / "librispeech_cuts_test-other.jsonl.gz"
+ )
diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_check.py b/egs/multi_zh_en/ASR/zipformer/onnx_check.py
new file mode 120000
index 000000000..f3dd42004
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/onnx_check.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_check.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_decode.py b/egs/multi_zh_en/ASR/zipformer/onnx_decode.py
new file mode 120000
index 000000000..0573b88c5
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/onnx_decode.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_decode.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py
new file mode 120000
index 000000000..cfea104c2
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py
new file mode 120000
index 000000000..8f32f4ee7
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_pretrained.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/optim.py b/egs/multi_zh_en/ASR/zipformer/optim.py
new file mode 120000
index 000000000..5eaa3cffd
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/pretrained.py b/egs/multi_zh_en/ASR/zipformer/pretrained.py
new file mode 100755
index 000000000..676272e1f
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/pretrained.py
@@ -0,0 +1,378 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads a checkpoint and uses it to decode waves.
+You can generate the checkpoint with the following command:
+
+Note: This is an example for the LibriSpeech dataset. If you are using a
+different dataset, you should change the argument values accordingly.
+
+- For non-streaming model:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp \
+ --tokens data/lang_bbpe_2000/tokens.txt \
+ --epoch 23 \
+ --avg 1
+
+- For streaming model:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp \
+ --causal 1 \
+ --tokens data/lang_bbpe_2000/tokens.txt \
+ --epoch 23 \
+ --avg 1
+
+Usage of this script:
+
+- For non-streaming model:
+
+(1) greedy search
+./zipformer/pretrained.py \
+ --checkpoint ./zipformer/exp/pretrained.pt \
+ --tokens data/lang_bbpe_2000/tokens.txt \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(2) modified beam search
+./zipformer/pretrained.py \
+ --checkpoint ./zipformer/exp/pretrained.pt \
+ --tokens ./data/lang_bbpe_2000/tokens.txt \
+ --method modified_beam_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) fast beam search
+./zipformer/pretrained.py \
+ --checkpoint ./zipformer/exp/pretrained.pt \
+ --tokens ./data/lang_bbpe_2000/tokens.txt \
+ --method fast_beam_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+- For streaming model:
+
+(1) greedy search
+./zipformer/pretrained.py \
+ --checkpoint ./zipformer/exp/pretrained.pt \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --tokens ./data/lang_bbpe_2000/tokens.txt \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(2) modified beam search
+./zipformer/pretrained.py \
+ --checkpoint ./zipformer/exp/pretrained.pt \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --tokens ./data/lang_bbpe_2000/tokens.txt \
+ --method modified_beam_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) fast beam search
+./zipformer/pretrained.py \
+ --checkpoint ./zipformer/exp/pretrained.pt \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --tokens ./data/lang_bbpe_2000/tokens.txt \
+ --method fast_beam_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+
+You can also use `./zipformer/exp/epoch-xx.pt`.
+
+Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import (
+ fast_beam_search_one_best,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from export import num_tokens
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_model, get_params
+
+from icefall import smart_byte_decode
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="Path to the checkpoint. "
+ "The checkpoint is assumed to be saved by "
+ "icefall.checkpoint.save_checkpoint().",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ help="""Path to byte-level bpe model.""",
+ )
+
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=int,
+ default=16000,
+ help="The sample rate of the input sound file",
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame. Used only when
+ --method is greedy_search.
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+ # We use only the first channel
+ ans.append(wave[0].contiguous())
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = get_params()
+
+ params.update(vars(args))
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(f"{params}")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+
+ logging.info("Creating model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ model.load_state_dict(checkpoint["model"], strict=False)
+ model.to(device)
+ model.eval()
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = params.sample_rate
+ opts.mel_opts.num_bins = params.feature_dim
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {params.sound_files}")
+ waves = read_sound_files(
+ filenames=params.sound_files, expected_sample_rate=params.sample_rate
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lengths = [f.size(0) for f in features]
+
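+    # Pad to a common length with log(1e-10), i.e. the log of a tiny energy,
+    # so padded frames look like (near-)silence in the log-fbank domain.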
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+ feature_lengths = torch.tensor(feature_lengths, device=device)
+
+ # model forward
+ encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
+
+ hyps = []
+ msg = f"Using {params.method}"
+ logging.info(msg)
+
+ if params.method == "fast_beam_search":
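+        # A trivial (single-state) graph that accepts any token sequence;
+        # fast_beam_search intersects the decoding lattice with it.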
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ else:
+ raise ValueError(f"Unsupported method: {params.method}")
+
+ s = "\n"
+ for filename, hyp in zip(params.sound_files, hyps):
+ s += f"{filename}:\n{hyp}\n\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/multi_zh_en/ASR/zipformer/scaling.py b/egs/multi_zh_en/ASR/zipformer/scaling.py
new file mode 120000
index 000000000..6f398f431
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/scaling_converter.py b/egs/multi_zh_en/ASR/zipformer/scaling_converter.py
new file mode 120000
index 000000000..b0ecee05e
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling_converter.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py b/egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py
new file mode 120000
index 000000000..b1ed54557
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/streaming_beam_search.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/streaming_decode.py b/egs/multi_zh_en/ASR/zipformer/streaming_decode.py
new file mode 120000
index 000000000..13fd02a78
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/streaming_decode.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/streaming_decode.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/subsampling.py b/egs/multi_zh_en/ASR/zipformer/subsampling.py
new file mode 120000
index 000000000..01ae9002c
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/subsampling.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/train.py b/egs/multi_zh_en/ASR/zipformer/train.py
new file mode 100755
index 000000000..310c8fe59
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/train.py
@@ -0,0 +1,1416 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+# For non-streaming model training:
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --max-duration 1000
+
+# For streaming model training:
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --causal 1 \
+ --max-duration 1000
+
+It supports training with:
+ - transducer loss (default), with `--use-transducer True --use-ctc False`
+ - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
+ - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import AsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from multi_dataset import MultiDataset
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import byte_encode, diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+ tokenize_by_CJK_char,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
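+    # For example, with --max-duration 1000, --world-size 4 and
+    # --ref-duration 600, every real batch counts as about
+    # 1000 * 4 / 600 = 6.7 reference batches.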
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--causal",
+ type=str2bool,
+ default=False,
+ help="If True, use causal version of model.",
+ )
+
+ parser.add_argument(
+ "--chunk-size",
+ type=str,
+ default="16,32,64,-1",
+ help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+ " Must be just -1 if --causal=False",
+ )
+
+ parser.add_argument(
+ "--left-context-frames",
+ type=str,
+ default="64,128,256,-1",
+ help="Maximum left-contexts for causal training, measured in frames which will "
+ "be converted to a number of chunks. If splitting into chunks, "
+ "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+ )
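+    # For reference: at the 50 Hz frame rate used inside the encoder,
+    # --chunk-size 16 corresponds to 320 ms chunks and
+    # --left-context-frames 128 to roughly 2.56 s of left context.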
+
+ parser.add_argument(
+ "--use-transducer",
+ type=str2bool,
+ default=True,
+ help="If True, use Transducer head.",
+ )
+
+ parser.add_argument(
+ "--use-ctc",
+ type=str2bool,
+ default=False,
+ help="If True, use CTC head.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bbpe_2000/bbpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.045, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)" "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
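+    # For example, at batch_idx_train=1000 with --average-period 200, the
+    # update above amounts to model_avg = 0.2 * model + 0.8 * model_avg.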
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ parser.add_argument(
+ "--use-tal-csasr",
+ type=str2bool,
+ default=False,
+ help="Whether to use TAL-CSASR training data.",
+ )
+
+ parser.add_argument(
+ "--use-librispeech",
+ type=str2bool,
+ default=False,
+ help="Whether to use LibriSpeech training data.",
+ )
+
+ parser.add_argument(
+ "--use-aishell2",
+ type=str2bool,
+ default=False,
+ help="Whether to use Aishell-2 training data.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+                       contains the number of batches trained so far
+                       across epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+ # encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, (T - 7) // 2, encoder_dims).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> (T - 7) // 2
+ # (2) embedding: num_features -> encoder_dims
+ # In the normal configuration, we will downsample once more at the end
+ # by a factor of 2, and most of the encoder stacks will run at a lower
+ # sampling rate.
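+    # For example, a 10-second utterance has about T = 1000 feature frames,
+    # so the embed output has (1000 - 7) // 2 = 496 frames, and after the
+    # final downsampling by 2 the encoder output has about 248 frames.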
+ encoder_embed = Conv2dSubsampling(
+ in_channels=params.feature_dim,
+ out_channels=_to_int_tuple(params.encoder_dim)[0],
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ )
+ return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Zipformer2(
+ output_downsampling_factor=2,
+ downsampling_factor=_to_int_tuple(params.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+ encoder_dim=_to_int_tuple(params.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(params.query_head_dim),
+ pos_head_dim=_to_int_tuple(params.pos_head_dim),
+ value_head_dim=_to_int_tuple(params.value_head_dim),
+ pos_dim=params.pos_dim,
+ num_heads=_to_int_tuple(params.num_heads),
+ feedforward_dim=_to_int_tuple(params.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ causal=params.causal,
+ chunk_size=_to_int_tuple(params.chunk_size),
+ left_context_frames=_to_int_tuple(params.left_context_frames),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
+ f"but got params.use_transducer={params.use_transducer}, "
+ f"params.use_ctc={params.use_ctc}"
+ )
+
+ encoder_embed = get_encoder_embed(params)
+ encoder = get_encoder_model(params)
+
+ if params.use_transducer:
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+ else:
+ decoder = None
+ joiner = None
+
+ model = AsrModel(
+ encoder_embed=encoder_embed,
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ use_transducer=params.use_transducer,
+ use_ctc=params.use_ctc,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, ctc_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+            # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
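+            # For example, halfway through warm-up with s = 0.5 these are
+            # simple_loss_scale = 0.75 and pruned_loss_scale = 0.55.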
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
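+    # Eden decays the learning rate as a function of both the number of
+    # batches seen (lr_batches) and the number of epochs seen (lr_epochs);
+    # see optim.py for the exact schedule.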
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ data_module = AsrDataModule(args)
+ multi_dataset = MultiDataset(args)
+
+ train_cuts = multi_dataset.train_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with a duration of at least 1 second.
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset and select
+        # a suitable threshold.
+ if c.duration < 1.0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
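+        # For example, a 1-second cut has about 100 feature frames, giving
+        # T = ((100 - 7) // 2 + 1) // 2 = 23 frames after subsampling.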
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ def tokenize_and_encode_text(c: Cut):
+ # Text normalize for each sample
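+        # e.g. a code-switched sentence is first split into individual CJK
+        # characters plus whole English words, then byte-encoded so that the
+        # byte-level BPE (bbpe) model can tokenize it.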
+ text = c.supervisions[0].text
+ text = byte_encode(tokenize_by_CJK_char(text))
+ c.supervisions[0].text = text
+ return c
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ train_cuts = train_cuts.map(tokenize_and_encode_text)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = data_module.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = multi_dataset.dev_cuts()
+ valid_dl = data_module.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
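+    # find_pessimistic_batches returns, for each sampling criterion (e.g. the
+    # largest total duration), the single most demanding batch the sampler
+    # could produce, so we can try it once before training starts.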
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ AsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/multi_zh_en/ASR/zipformer/zipformer.py b/egs/multi_zh_en/ASR/zipformer/zipformer.py
new file mode 120000
index 000000000..23011dda7
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/zipformer.py
\ No newline at end of file
From 0622dea30deacf2680dcca0549f7a05c0b965066 Mon Sep 17 00:00:00 2001
From: Zengwei Yao
Date: Wed, 29 Nov 2023 21:28:38 +0800
Subject: [PATCH 008/123] Add a TTS recipe VITS on LJSpeech dataset (#1372)
* first commit
* replace phonimizer with g2p
* use Conformer as text encoder
* modify training script, clean codes
* rename directory
* convert text to tokens in data preparation stage
* fix tts_datamodule.py
* support onnx export and testing the exported onnx model
* add doc
* add README.md
* fix style
---
.flake8 | 2 +-
docs/source/recipes/TTS/index.rst | 7 +
docs/source/recipes/TTS/ljspeech/vits.rst | 113 +++
docs/source/recipes/index.rst | 3 +-
.../TTS/local/compute_spectrogram_ljspeech.py | 106 ++
.../TTS/local/display_manifest_statistics.py | 73 ++
egs/ljspeech/TTS/local/prepare_token_file.py | 104 ++
.../TTS/local/prepare_tokens_ljspeech.py | 59 ++
egs/ljspeech/TTS/local/validate_manifest.py | 70 ++
egs/ljspeech/TTS/prepare.sh | 117 +++
egs/ljspeech/TTS/shared/parse_options.sh | 1 +
egs/ljspeech/TTS/vits/README.md | 3 +
egs/ljspeech/TTS/vits/duration_predictor.py | 194 ++++
egs/ljspeech/TTS/vits/export-onnx.py | 261 +++++
egs/ljspeech/TTS/vits/flow.py | 312 ++++++
egs/ljspeech/TTS/vits/generator.py | 531 ++++++++++
egs/ljspeech/TTS/vits/hifigan.py | 933 ++++++++++++++++++
egs/ljspeech/TTS/vits/infer.py | 233 +++++
egs/ljspeech/TTS/vits/loss.py | 336 +++++++
.../TTS/vits/monotonic_align/__init__.py | 81 ++
.../TTS/vits/monotonic_align/core.pyx | 51 +
.../TTS/vits/monotonic_align/setup.py | 31 +
egs/ljspeech/TTS/vits/posterior_encoder.py | 117 +++
egs/ljspeech/TTS/vits/residual_coupling.py | 229 +++++
egs/ljspeech/TTS/vits/test_onnx.py | 123 +++
egs/ljspeech/TTS/vits/text_encoder.py | 662 +++++++++++++
egs/ljspeech/TTS/vits/tokenizer.py | 106 ++
egs/ljspeech/TTS/vits/train.py | 893 +++++++++++++++++
egs/ljspeech/TTS/vits/transform.py | 218 ++++
egs/ljspeech/TTS/vits/tts_datamodule.py | 325 ++++++
egs/ljspeech/TTS/vits/utils.py | 265 +++++
egs/ljspeech/TTS/vits/vits.py | 610 ++++++++++++
egs/ljspeech/TTS/vits/wavenet.py | 349 +++++++
pyproject.toml | 1 +
34 files changed, 7517 insertions(+), 2 deletions(-)
create mode 100644 docs/source/recipes/TTS/index.rst
create mode 100644 docs/source/recipes/TTS/ljspeech/vits.rst
create mode 100755 egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py
create mode 100755 egs/ljspeech/TTS/local/display_manifest_statistics.py
create mode 100755 egs/ljspeech/TTS/local/prepare_token_file.py
create mode 100755 egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py
create mode 100755 egs/ljspeech/TTS/local/validate_manifest.py
create mode 100755 egs/ljspeech/TTS/prepare.sh
create mode 120000 egs/ljspeech/TTS/shared/parse_options.sh
create mode 100644 egs/ljspeech/TTS/vits/README.md
create mode 100644 egs/ljspeech/TTS/vits/duration_predictor.py
create mode 100755 egs/ljspeech/TTS/vits/export-onnx.py
create mode 100644 egs/ljspeech/TTS/vits/flow.py
create mode 100644 egs/ljspeech/TTS/vits/generator.py
create mode 100644 egs/ljspeech/TTS/vits/hifigan.py
create mode 100755 egs/ljspeech/TTS/vits/infer.py
create mode 100644 egs/ljspeech/TTS/vits/loss.py
create mode 100644 egs/ljspeech/TTS/vits/monotonic_align/__init__.py
create mode 100644 egs/ljspeech/TTS/vits/monotonic_align/core.pyx
create mode 100644 egs/ljspeech/TTS/vits/monotonic_align/setup.py
create mode 100644 egs/ljspeech/TTS/vits/posterior_encoder.py
create mode 100644 egs/ljspeech/TTS/vits/residual_coupling.py
create mode 100755 egs/ljspeech/TTS/vits/test_onnx.py
create mode 100644 egs/ljspeech/TTS/vits/text_encoder.py
create mode 100644 egs/ljspeech/TTS/vits/tokenizer.py
create mode 100755 egs/ljspeech/TTS/vits/train.py
create mode 100644 egs/ljspeech/TTS/vits/transform.py
create mode 100644 egs/ljspeech/TTS/vits/tts_datamodule.py
create mode 100644 egs/ljspeech/TTS/vits/utils.py
create mode 100644 egs/ljspeech/TTS/vits/vits.py
create mode 100644 egs/ljspeech/TTS/vits/wavenet.py
diff --git a/.flake8 b/.flake8
index 410cb5482..cf276d0ba 100644
--- a/.flake8
+++ b/.flake8
@@ -15,7 +15,7 @@ per-file-ignores =
egs/librispeech/ASR/zipformer_mmi/*.py: E501, E203
egs/librispeech/ASR/zipformer/*.py: E501, E203
egs/librispeech/ASR/RESULTS.md: E999,
-
+ egs/ljspeech/TTS/vits/*.py: E501, E203
# invalid escape sequence (cause by tex formular), W605
icefall/utils.py: E501, W605
diff --git a/docs/source/recipes/TTS/index.rst b/docs/source/recipes/TTS/index.rst
new file mode 100644
index 000000000..aa891c072
--- /dev/null
+++ b/docs/source/recipes/TTS/index.rst
@@ -0,0 +1,7 @@
+TTS
+======
+
+.. toctree::
+ :maxdepth: 2
+
+ ljspeech/vits
diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst
new file mode 100644
index 000000000..385fd3c70
--- /dev/null
+++ b/docs/source/recipes/TTS/ljspeech/vits.rst
@@ -0,0 +1,113 @@
+VITS
+===============
+
+This tutorial shows you how to train a VITS model
+with the `LJSpeech `_ dataset.
+
+.. note::
+
+ The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_
+
+
+Data preparation
+----------------
+
+.. code-block:: bash
+
+ $ cd egs/ljspeech/TTS
+ $ ./prepare.sh
+
+To run stage 1 to stage 5, use
+
+.. code-block:: bash
+
+ $ ./prepare.sh --stage 1 --stop_stage 5
+
+
+Build Monotonic Alignment Search
+--------------------------------
+
+.. code-block:: bash
+
+ $ cd vits/monotonic_align
+ $ python setup.py build_ext --inplace
+ $ cd ../../
+
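+A quick way to check that the extension was built correctly (assuming the layout
+of this recipe, where ``vits/monotonic_align/__init__.py`` exposes ``maximum_path``):
+
+.. code-block:: bash
+
+  $ cd vits
+  $ python3 -c "from monotonic_align import maximum_path; print('ok')"
+  $ cd ..
+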
+
+Training
+--------
+
+.. code-block:: bash
+
+ $ export CUDA_VISIBLE_DEVICES="0,1,2,3"
+ $ ./vits/train.py \
+ --world-size 4 \
+ --num-epochs 1000 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir vits/exp \
+    --tokens data/tokens.txt \
+    --max-duration 500
+
+.. note::
+
+ You can adjust the hyper-parameters to control the size of the VITS model and
+ the training configurations. For more details, please run ``./vits/train.py --help``.
+
+.. note::
+
+ The training can take a long time (usually a couple of days).
+
+Training logs, checkpoints and tensorboard logs are saved in ``vits/exp``.
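+
+You can monitor the training progress with TensorBoard. Assuming the usual
+icefall layout, where the event files are written to ``vits/exp/tensorboard``:
+
+.. code-block:: bash
+
+  $ tensorboard --logdir vits/exp/tensorboard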
+
+
+Inference
+---------
+
+Inference uses the checkpoints saved during training, so you have to run the
+training stage first. It saves both the ground-truth and the generated wavs to the
+directory ``vits/exp/infer/epoch-*/wav``, e.g., ``vits/exp/infer/epoch-1000/wav``.
+
+.. code-block:: bash
+
+ $ export CUDA_VISIBLE_DEVICES="0"
+ $ ./vits/infer.py \
+ --epoch 1000 \
+ --exp-dir vits/exp \
+    --tokens data/tokens.txt \
+    --max-duration 500
+
+.. note::
+
+ For more details, please run ``./vits/infer.py --help``.
+
+
+Export models
+-------------
+
+Currently, we only support exporting the model to ONNX. It will generate two files in the given ``exp-dir``:
+``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
+
+.. code-block:: bash
+
+ $ ./vits/export-onnx.py \
+ --epoch 1000 \
+ --exp-dir vits/exp \
+ --tokens data/tokens.txt
+
+You can test the exported ONNX model with:
+
+.. code-block:: bash
+
+ $ ./vits/test_onnx.py \
+ --model-filename vits/exp/vits-epoch-1000.onnx \
+ --tokens data/tokens.txt
+
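+If you want to call the exported model directly, the following is a minimal
+sketch (not part of the recipe) using ``onnxruntime``. Depending on how the
+scalar controls were traced during export, the graph may expose additional
+inputs besides ``tokens`` and ``tokens_lens``; inspect ``sess.get_inputs()``
+to be sure. The token IDs below are placeholders; you would normally obtain
+them from ``data/tokens.txt`` (e.g., via ``vits/tokenizer.py``):
+
+.. code-block:: python
+
+  import numpy as np
+  import onnxruntime as ort
+
+  sess = ort.InferenceSession("vits/exp/vits-epoch-1000.onnx")
+  print([i.name for i in sess.get_inputs()])  # check the exported input names
+
+  token_ids = [10, 21, 32]  # placeholder phoneme IDs, for illustration only
+  tokens = np.array([token_ids], dtype=np.int64)
+  tokens_lens = np.array([tokens.shape[1]], dtype=np.int64)
+
+  audio = sess.run(["audio"], {"tokens": tokens, "tokens_lens": tokens_lens})[0]
+  # audio: float32 waveform of shape (1, T_wav) at 22050 Hz
+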
+
+Download pretrained models
+--------------------------
+
+If you don't want to train from scratch, you can download the pretrained models
+by visiting the following link:
+
+  - `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29>`_
diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst
index 7265e1cf6..8df61f0d0 100644
--- a/docs/source/recipes/index.rst
+++ b/docs/source/recipes/index.rst
@@ -2,7 +2,7 @@ Recipes
=======
This page contains various recipes in ``icefall``.
-Currently, only speech recognition recipes are provided.
+Currently, we provide recipes for speech recognition, language modeling, and speech synthesis.
We may add recipes for other tasks as well in the future.
@@ -16,3 +16,4 @@ We may add recipes for other tasks as well in the future.
Non-streaming-ASR/index
Streaming-ASR/index
RNN-LM/index
+ TTS/index
diff --git a/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py
new file mode 100755
index 000000000..97c9008fc
--- /dev/null
+++ b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes spectrogram features of the LJSpeech dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated spectrogram features are saved in data/spectrogram.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import (
+ CutSet,
+ LilcomChunkyWriter,
+ Spectrogram,
+ SpectrogramConfig,
+ load_manifest,
+)
+from lhotse.audio import RecordingSet
+from lhotse.supervision import SupervisionSet
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_spectrogram_ljspeech():
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/spectrogram")
+ num_jobs = min(4, os.cpu_count())
+
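+    # LJSpeech audio is sampled at 22.05 kHz; the values below correspond to a
+    # 1024-sample FFT window (~46 ms) and a 256-sample hop (~11.6 ms).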
+ sampling_rate = 22050
+ frame_length = 1024 / sampling_rate # (in second)
+ frame_shift = 256 / sampling_rate # (in second)
+ use_fft_mag = True
+
+ prefix = "ljspeech"
+ suffix = "jsonl.gz"
+ partition = "all"
+
+ recordings = load_manifest(
+ src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet
+ )
+ supervisions = load_manifest(
+ src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
+ )
+
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ use_fft_mag=use_fft_mag,
+ )
+ extractor = Spectrogram(config)
+
+ with get_executor() as ex: # Initialize the executor only once.
+ cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+ if (output_dir / cuts_filename).is_file():
+ logging.info(f"{cuts_filename} already exists - skipping.")
+ return
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=recordings, supervisions=supervisions
+ )
+
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+ # when an executor is specified, make more partitions
+ num_jobs=num_jobs if ex is None else 80,
+ executor=ex,
+ storage_type=LilcomChunkyWriter,
+ )
+ cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ compute_spectrogram_ljspeech()
diff --git a/egs/ljspeech/TTS/local/display_manifest_statistics.py b/egs/ljspeech/TTS/local/display_manifest_statistics.py
new file mode 100755
index 000000000..93f0044f0
--- /dev/null
+++ b/egs/ljspeech/TTS/local/display_manifest_statistics.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file displays duration statistics of utterances in a manifest.
+You can use the displayed value to choose minimum/maximum duration
+to remove short and long utterances during the training.
+
+See the function `remove_short_and_long_utt()` in vits/train.py
+for usage.
+"""
+
+
+from lhotse import load_manifest_lazy
+
+
+def main():
+ path = "./data/spectrogram/ljspeech_cuts_all.jsonl.gz"
+ cuts = load_manifest_lazy(path)
+ cuts.describe()
+
+
+if __name__ == "__main__":
+ main()
+
+"""
+Cut statistics:
+ ╒═══════════════════════════╤══════════╕
+ │ Cuts count: │ 13100 │
+ ├───────────────────────────┼──────────┤
+ │ Total duration (hh:mm:ss) │ 23:55:18 │
+ ├───────────────────────────┼──────────┤
+ │ mean │ 6.6 │
+ ├───────────────────────────┼──────────┤
+ │ std │ 2.2 │
+ ├───────────────────────────┼──────────┤
+ │ min │ 1.1 │
+ ├───────────────────────────┼──────────┤
+ │ 25% │ 5.0 │
+ ├───────────────────────────┼──────────┤
+ │ 50% │ 6.8 │
+ ├───────────────────────────┼──────────┤
+ │ 75% │ 8.4 │
+ ├───────────────────────────┼──────────┤
+ │ 99% │ 10.0 │
+ ├───────────────────────────┼──────────┤
+ │ 99.5% │ 10.1 │
+ ├───────────────────────────┼──────────┤
+ │ 99.9% │ 10.1 │
+ ├───────────────────────────┼──────────┤
+ │ max │ 10.1 │
+ ├───────────────────────────┼──────────┤
+ │ Recordings available: │ 13100 │
+ ├───────────────────────────┼──────────┤
+ │ Features available: │ 13100 │
+ ├───────────────────────────┼──────────┤
+ │ Supervisions available: │ 13100 │
+ ╘═══════════════════════════╧══════════╛
+"""
diff --git a/egs/ljspeech/TTS/local/prepare_token_file.py b/egs/ljspeech/TTS/local/prepare_token_file.py
new file mode 100755
index 000000000..df976804a
--- /dev/null
+++ b/egs/ljspeech/TTS/local/prepare_token_file.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file reads the texts in the given manifest and generates the file that maps tokens to IDs.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict
+
+from lhotse import load_manifest
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--manifest-file",
+ type=Path,
+ default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"),
+ help="Path to the manifest file",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=Path,
+ default=Path("data/tokens.txt"),
+ help="Path to the tokens",
+ )
+
+ return parser.parse_args()
+
+
+def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
+ """Write a symbol to ID mapping to a file.
+
+ Note:
+ No need to implement `read_mapping` as it can be done
+ through :func:`k2.SymbolTable.from_file`.
+
+ Args:
+ filename:
+ Filename to save the mapping.
+ sym2id:
+ A dict mapping symbols to IDs.
+ Returns:
+ Return None.
+ """
+ with open(filename, "w", encoding="utf-8") as f:
+ for sym, i in sym2id.items():
+ f.write(f"{sym} {i}\n")
+
+
+def get_token2id(manifest_file: Path) -> Dict[str, int]:
+ """Return a dict that maps token to IDs."""
+ extra_tokens = [
+ "", # 0 for blank
+ "", # 1 for sos and eos symbols.
+ "", # 2 for OOV
+ ]
+ all_tokens = set()
+
+ cut_set = load_manifest(manifest_file)
+
+ for cut in cut_set:
+        # Each cut only contains one supervision
+ assert len(cut.supervisions) == 1, len(cut.supervisions)
+ for t in cut.tokens:
+ all_tokens.add(t)
+
+ all_tokens = extra_tokens + list(all_tokens)
+
+ token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
+ return token2id
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ args = get_args()
+ manifest_file = Path(args.manifest_file)
+ out_file = Path(args.tokens)
+
+ token2id = get_token2id(manifest_file)
+ write_mapping(out_file, token2id)
diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py
new file mode 100755
index 000000000..fcd0137a0
--- /dev/null
+++ b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file reads the texts in the given manifest and saves the new cuts with phoneme tokens.
+"""
+
+import logging
+from pathlib import Path
+
+import g2p_en
+import tacotron_cleaner.cleaners
+from lhotse import CutSet, load_manifest
+
+
+def prepare_tokens_ljspeech():
+ output_dir = Path("data/spectrogram")
+ prefix = "ljspeech"
+ suffix = "jsonl.gz"
+ partition = "all"
+
+ cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
+ g2p = g2p_en.G2p()
+
+ new_cuts = []
+ for cut in cut_set:
+ # Each cut only contains one supervision
+ assert len(cut.supervisions) == 1, len(cut.supervisions)
+ text = cut.supervisions[0].normalized_text
+ # Text normalization
+ text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
+ # Convert to phonemes
+ cut.tokens = g2p(text)
+ new_cuts.append(cut)
+
+ new_cut_set = CutSet.from_cuts(new_cuts)
+ new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ prepare_tokens_ljspeech()
diff --git a/egs/ljspeech/TTS/local/validate_manifest.py b/egs/ljspeech/TTS/local/validate_manifest.py
new file mode 100755
index 000000000..68159ae03
--- /dev/null
+++ b/egs/ljspeech/TTS/local/validate_manifest.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script checks the following assumptions of the generated manifest:
+
+- Single supervision per cut
+
+We will add more checks later if needed.
+
+Usage example:
+
+ python3 ./local/validate_manifest.py \
+ ./data/spectrogram/ljspeech_cuts_all.jsonl.gz
+
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.dataset.speech_synthesis import validate_for_tts
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "manifest",
+ type=Path,
+ help="Path to the manifest file",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+
+ manifest = args.manifest
+ logging.info(f"Validating {manifest}")
+
+ assert manifest.is_file(), f"{manifest} does not exist"
+ cut_set = load_manifest_lazy(manifest)
+ assert isinstance(cut_set, CutSet), type(cut_set)
+
+ validate_for_tts(cut_set)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh
new file mode 100755
index 000000000..8ee40896e
--- /dev/null
+++ b/egs/ljspeech/TTS/prepare.sh
@@ -0,0 +1,117 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+nj=1
+stage=-1
+stop_stage=100
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ # The directory $dl_dir/LJSpeech-1.1 will contain:
+ # - wavs, which contains the audio files
+ # - metadata.csv, which provides the transcript text for each audio clip
+
+ # If you have pre-downloaded it to /path/to/LJSpeech-1.1, you can create a symlink
+ #
+ # ln -sfv /path/to/LJSpeech-1.1 $dl_dir/LJSpeech-1.1
+ #
+ if [ ! -d $dl_dir/LJSpeech-1.1 ]; then
+ lhotse download ljspeech $dl_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare LJSpeech manifest"
+ # We assume that you have downloaded the LJSpeech corpus
+  # to $dl_dir/LJSpeech-1.1
+ mkdir -p data/manifests
+ if [ ! -e data/manifests/.ljspeech.done ]; then
+ lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests
+ touch data/manifests/.ljspeech.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Compute spectrogram for LJSpeech"
+ mkdir -p data/spectrogram
+ if [ ! -e data/spectrogram/.ljspeech.done ]; then
+ ./local/compute_spectrogram_ljspeech.py
+ touch data/spectrogram/.ljspeech.done
+ fi
+
+ if [ ! -e data/spectrogram/.ljspeech-validated.done ]; then
+ log "Validating data/spectrogram for LJSpeech"
+ python3 ./local/validate_manifest.py \
+ data/spectrogram/ljspeech_cuts_all.jsonl.gz
+ touch data/spectrogram/.ljspeech-validated.done
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Prepare phoneme tokens for LJSpeech"
+ if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
+ ./local/prepare_tokens_ljspeech.py
+ mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \
+ data/spectrogram/ljspeech_cuts_all.jsonl.gz
+ touch data/spectrogram/.ljspeech_with_token.done
+ fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Split the LJSpeech cuts into train, valid and test sets"
+ if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
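+    # Hold out the last 600 cuts: the first 100 of them become the valid set
+    # and the remaining 500 become the test set; all earlier cuts form the
+    # train set.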
+ lhotse subset --last 600 \
+ data/spectrogram/ljspeech_cuts_all.jsonl.gz \
+ data/spectrogram/ljspeech_cuts_validtest.jsonl.gz
+ lhotse subset --first 100 \
+ data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \
+ data/spectrogram/ljspeech_cuts_valid.jsonl.gz
+ lhotse subset --last 500 \
+ data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \
+ data/spectrogram/ljspeech_cuts_test.jsonl.gz
+
+ rm data/spectrogram/ljspeech_cuts_validtest.jsonl.gz
+
+ n=$(( $(gunzip -c data/spectrogram/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 ))
+ lhotse subset --first $n \
+ data/spectrogram/ljspeech_cuts_all.jsonl.gz \
+ data/spectrogram/ljspeech_cuts_train.jsonl.gz
+ touch data/spectrogram/.ljspeech_split.done
+ fi
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Generate token file"
+ # We assume you have installed g2p_en and espnet_tts_frontend.
+ # If not, please install them with:
+ # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
+  # - espnet_tts_frontend: `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
+ if [ ! -e data/tokens.txt ]; then
+ ./local/prepare_token_file.py \
+ --manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \
+ --tokens data/tokens.txt
+ fi
+fi
+
+
diff --git a/egs/ljspeech/TTS/shared/parse_options.sh b/egs/ljspeech/TTS/shared/parse_options.sh
new file mode 120000
index 000000000..e4665e7de
--- /dev/null
+++ b/egs/ljspeech/TTS/shared/parse_options.sh
@@ -0,0 +1 @@
+../../../librispeech/ASR/shared/parse_options.sh
\ No newline at end of file
diff --git a/egs/ljspeech/TTS/vits/README.md b/egs/ljspeech/TTS/vits/README.md
new file mode 100644
index 000000000..1141326b9
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/README.md
@@ -0,0 +1,3 @@
+See https://k2-fsa.github.io/icefall/recipes/TTS/ljspeech/vits.html for detailed tutorials.
+
+Training logs, Tensorboard logs, and checkpoints are uploaded to https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29.
diff --git a/egs/ljspeech/TTS/vits/duration_predictor.py b/egs/ljspeech/TTS/vits/duration_predictor.py
new file mode 100644
index 000000000..c29a28479
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/duration_predictor.py
@@ -0,0 +1,194 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Stochastic duration predictor modules in VITS.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+from flow import (
+ ConvFlow,
+ DilatedDepthSeparableConv,
+ ElementwiseAffineFlow,
+ FlipFlow,
+ LogFlow,
+)
+
+
+class StochasticDurationPredictor(torch.nn.Module):
+ """Stochastic duration predictor module.
+
+ This is a module of stochastic duration predictor described in `Conditional
+ Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
+
+ .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
+    Text-to-Speech`: https://arxiv.org/abs/2106.06103
+
+ """
+
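+    # Rough usage (argument shapes are documented in forward() below):
+    #   training:   nll  = predictor(x, x_mask, w=durations)    # -> (B,)
+    #   inference:  logw = predictor(x, x_mask, inverse=True)   # -> (B, 1, T_text)
+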
+ def __init__(
+ self,
+ channels: int = 192,
+ kernel_size: int = 3,
+ dropout_rate: float = 0.5,
+ flows: int = 4,
+ dds_conv_layers: int = 3,
+ global_channels: int = -1,
+ ):
+ """Initialize StochasticDurationPredictor module.
+
+ Args:
+ channels (int): Number of channels.
+ kernel_size (int): Kernel size.
+ dropout_rate (float): Dropout rate.
+ flows (int): Number of flows.
+ dds_conv_layers (int): Number of conv layers in DDS conv.
+ global_channels (int): Number of global conditioning channels.
+
+ """
+ super().__init__()
+
+ self.pre = torch.nn.Conv1d(channels, channels, 1)
+ self.dds = DilatedDepthSeparableConv(
+ channels,
+ kernel_size,
+ layers=dds_conv_layers,
+ dropout_rate=dropout_rate,
+ )
+ self.proj = torch.nn.Conv1d(channels, channels, 1)
+
+ self.log_flow = LogFlow()
+ self.flows = torch.nn.ModuleList()
+ self.flows += [ElementwiseAffineFlow(2)]
+ for i in range(flows):
+ self.flows += [
+ ConvFlow(
+ 2,
+ channels,
+ kernel_size,
+ layers=dds_conv_layers,
+ )
+ ]
+ self.flows += [FlipFlow()]
+
+ self.post_pre = torch.nn.Conv1d(1, channels, 1)
+ self.post_dds = DilatedDepthSeparableConv(
+ channels,
+ kernel_size,
+ layers=dds_conv_layers,
+ dropout_rate=dropout_rate,
+ )
+ self.post_proj = torch.nn.Conv1d(channels, channels, 1)
+ self.post_flows = torch.nn.ModuleList()
+ self.post_flows += [ElementwiseAffineFlow(2)]
+ for i in range(flows):
+ self.post_flows += [
+ ConvFlow(
+ 2,
+ channels,
+ kernel_size,
+ layers=dds_conv_layers,
+ )
+ ]
+ self.post_flows += [FlipFlow()]
+
+ if global_channels > 0:
+ self.global_conv = torch.nn.Conv1d(global_channels, channels, 1)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ w: Optional[torch.Tensor] = None,
+ g: Optional[torch.Tensor] = None,
+ inverse: bool = False,
+ noise_scale: float = 1.0,
+ ) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, channels, T_text).
+ x_mask (Tensor): Mask tensor (B, 1, T_text).
+ w (Optional[Tensor]): Duration tensor (B, 1, T_text).
+ g (Optional[Tensor]): Global conditioning tensor (B, channels, 1)
+ inverse (bool): Whether to inverse the flow.
+ noise_scale (float): Noise scale value.
+
+ Returns:
+ Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,).
+ If inverse, log-duration tensor (B, 1, T_text).
+
+ """
+ x = x.detach() # stop gradient
+ x = self.pre(x)
+ if g is not None:
+ x = x + self.global_conv(g.detach()) # stop gradient
+ x = self.dds(x, x_mask)
+ x = self.proj(x) * x_mask
+
+ if not inverse:
+ assert w is not None, "w must be provided."
+ h_w = self.post_pre(w)
+ h_w = self.post_dds(h_w, x_mask)
+ h_w = self.post_proj(h_w) * x_mask
+ e_q = (
+ torch.randn(
+ w.size(0),
+ 2,
+ w.size(2),
+ ).to(device=x.device, dtype=x.dtype)
+ * x_mask
+ )
+ z_q = e_q
+ logdet_tot_q = 0.0
+ for flow in self.post_flows:
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
+ logdet_tot_q += logdet_q
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
+ u = torch.sigmoid(z_u) * x_mask
+ z0 = (w - u) * x_mask
+ logdet_tot_q += torch.sum(
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
+ )
+ logq = (
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
+ - logdet_tot_q
+ )
+
+ logdet_tot = 0
+ z0, logdet = self.log_flow(z0, x_mask)
+ logdet_tot += logdet
+ z = torch.cat([z0, z1], 1)
+ for flow in self.flows:
+ z, logdet = flow(z, x_mask, g=x, inverse=inverse)
+ logdet_tot = logdet_tot + logdet
+ nll = (
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
+ - logdet_tot
+ )
+ return nll + logq # (B,)
+ else:
+ flows = list(reversed(self.flows))
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
+ z = (
+ torch.randn(
+ x.size(0),
+ 2,
+ x.size(2),
+ ).to(device=x.device, dtype=x.dtype)
+ * noise_scale
+ )
+ for flow in flows:
+ z = flow(z, x_mask, g=x, inverse=inverse)
+ z0, z1 = z.split(1, 1)
+ logw = z0
+ return logw
diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py
new file mode 100755
index 000000000..154de4bf4
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/export-onnx.py
@@ -0,0 +1,261 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script exports a VITS model from PyTorch to ONNX.
+
+Export the model to ONNX:
+./vits/export-onnx.py \
+ --epoch 1000 \
+ --exp-dir vits/exp \
+ --tokens data/tokens.txt
+
+It will generate two files inside vits/exp:
+ - vits-epoch-1000.onnx
+  - vits-epoch-1000.int8.onnx (quantized model)
+
+See ./test_onnx.py for how to use the exported ONNX models.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict, Tuple
+
+import onnx
+import torch
+import torch.nn as nn
+from onnxruntime.quantization import QuantType, quantize_dynamic
+from tokenizer import Tokenizer
+from train import get_model, get_params
+
+from icefall.checkpoint import load_checkpoint
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=1000,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="vits/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ return parser
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, str]):
+ """Add meta data to an ONNX model. It is changed in-place.
+
+ Args:
+ filename:
+ Filename of the ONNX model to be changed.
+ meta_data:
+ Key-value pairs.
+ """
+ model = onnx.load(filename)
+ for key, value in meta_data.items():
+ meta = model.metadata_props.add()
+ meta.key = key
+ meta.value = value
+
+ onnx.save(model, filename)
+
+
+class OnnxModel(nn.Module):
+ """A wrapper for VITS generator."""
+
+ def __init__(self, model: nn.Module):
+ """
+ Args:
+ model:
+ A VITS generator.
+ frame_shift:
+ The frame shift in samples.
+ """
+ super().__init__()
+ self.model = model
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ tokens_lens: torch.Tensor,
+ noise_scale: float = 0.667,
+ noise_scale_dur: float = 0.8,
+ alpha: float = 1.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Please see the help information of VITS.inference_batch
+
+ Args:
+ tokens:
+ Input text token indexes (1, T_text)
+ tokens_lens:
+ Number of tokens of shape (1,)
+ noise_scale (float):
+ Noise scale parameter for flow.
+ noise_scale_dur (float):
+ Noise scale parameter for duration predictor.
+ alpha (float):
+ Alpha parameter to control the speed of generated speech.
+
+ Returns:
+ Return a tuple containing:
+              - audio, generated waveform tensor, (B, T_wav)
+ """
+ audio, _, _ = self.model.inference(
+ text=tokens,
+ text_lengths=tokens_lens,
+ noise_scale=noise_scale,
+ noise_scale_dur=noise_scale_dur,
+ alpha=alpha,
+ )
+ return audio
+
+
+def export_model_onnx(
+ model: nn.Module,
+ model_filename: str,
+ opset_version: int = 11,
+) -> None:
+ """Export the given generator model to ONNX format.
+    The exported model has the following inputs (see ``input_names`` below):
+
+      - tokens, a tensor of shape (1, T_text); dtype is torch.int64
+      - tokens_lens, a tensor of shape (1,); dtype is torch.int64
+      - noise_scale, noise_scale_dur and alpha, scalar float controls
+
+    and it has one output:
+
+ - audio, a tensor of shape (1, T'); dtype is torch.float32
+
+ Args:
+ model:
+ The VITS generator.
+ model_filename:
+ The filename to save the exported ONNX model.
+ opset_version:
+ The opset version to use.
+ """
+ tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
+ tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
+ noise_scale = torch.tensor([1], dtype=torch.float32)
+ noise_scale_dur = torch.tensor([1], dtype=torch.float32)
+ alpha = torch.tensor([1], dtype=torch.float32)
+
+ torch.onnx.export(
+ model,
+ (tokens, tokens_lens, noise_scale, noise_scale_dur, alpha),
+ model_filename,
+ verbose=False,
+ opset_version=opset_version,
+ input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"],
+ output_names=["audio"],
+ dynamic_axes={
+ "tokens": {0: "N", 1: "T"},
+ "tokens_lens": {0: "N"},
+ "audio": {0: "N", 1: "T"},
+ },
+ )
+
+ meta_data = {
+ "model_type": "VITS",
+ "version": "1",
+ "model_author": "k2-fsa",
+ "comment": "VITS generator",
+ }
+ logging.info(f"meta_data: {meta_data}")
+
+ add_meta_data(filename=model_filename, meta_data=meta_data)
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ tokenizer = Tokenizer(params.tokens)
+ params.blank_id = tokenizer.blank_id
+ params.oov_id = tokenizer.oov_id
+ params.vocab_size = tokenizer.vocab_size
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+
+ model = model.generator
+ model.to("cpu")
+ model.eval()
+
+ model = OnnxModel(model=model)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"generator parameters: {num_param}")
+
+ suffix = f"epoch-{params.epoch}"
+
+ opset_version = 13
+
+ logging.info("Exporting encoder")
+ model_filename = params.exp_dir / f"vits-{suffix}.onnx"
+ export_model_onnx(
+ model,
+ model_filename,
+ opset_version=opset_version,
+ )
+ logging.info(f"Exported generator to {model_filename}")
+
+ # Generate int8 quantization models
+ # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
+
+ logging.info("Generate int8 quantization models")
+
+ model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
+ quantize_dynamic(
+ model_input=model_filename,
+ model_output=model_filename_int8,
+ weight_type=QuantType.QUInt8,
+ )
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/ljspeech/TTS/vits/flow.py b/egs/ljspeech/TTS/vits/flow.py
new file mode 100644
index 000000000..206bd5e3e
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/flow.py
@@ -0,0 +1,312 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Basic Flow modules used in VITS.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+
+from transform import piecewise_rational_quadratic_transform
+
+
+class FlipFlow(torch.nn.Module):
+ """Flip flow module."""
+
+ def forward(
+ self, x: torch.Tensor, *args, inverse: bool = False, **kwargs
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, channels, T).
+ inverse (bool): Whether to inverse the flow.
+
+ Returns:
+ Tensor: Flipped tensor (B, channels, T).
+ Tensor: Log-determinant tensor for NLL (B,) if not inverse.
+
+ """
+ x = torch.flip(x, [1])
+ if not inverse:
+ logdet = x.new_zeros(x.size(0))
+ return x, logdet
+ else:
+ return x
+
+
+class LogFlow(torch.nn.Module):
+ """Log flow module."""
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ inverse: bool = False,
+ eps: float = 1e-5,
+ **kwargs
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, channels, T).
+ x_mask (Tensor): Mask tensor (B, 1, T).
+ inverse (bool): Whether to inverse the flow.
+ eps (float): Epsilon for log.
+
+ Returns:
+ Tensor: Output tensor (B, channels, T).
+ Tensor: Log-determinant tensor for NLL (B,) if not inverse.
+
+ """
+ if not inverse:
+ y = torch.log(torch.clamp_min(x, eps)) * x_mask
+ logdet = torch.sum(-y, [1, 2])
+ return y, logdet
+ else:
+ x = torch.exp(x) * x_mask
+ return x
+
+
+class ElementwiseAffineFlow(torch.nn.Module):
+ """Elementwise affine flow module."""
+
+ def __init__(self, channels: int):
+ """Initialize ElementwiseAffineFlow module.
+
+ Args:
+ channels (int): Number of channels.
+
+ """
+ super().__init__()
+ self.channels = channels
+ self.register_parameter("m", torch.nn.Parameter(torch.zeros(channels, 1)))
+ self.register_parameter("logs", torch.nn.Parameter(torch.zeros(channels, 1)))
+
+ def forward(
+ self, x: torch.Tensor, x_mask: torch.Tensor, inverse: bool = False, **kwargs
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, channels, T).
+            x_mask (Tensor): Mask tensor (B, 1, T).
+ inverse (bool): Whether to inverse the flow.
+
+ Returns:
+ Tensor: Output tensor (B, channels, T).
+ Tensor: Log-determinant tensor for NLL (B,) if not inverse.
+
+ """
+ if not inverse:
+ y = self.m + torch.exp(self.logs) * x
+ y = y * x_mask
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
+ return y, logdet
+ else:
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
+ return x
+
+
+class Transpose(torch.nn.Module):
+ """Transpose module for torch.nn.Sequential()."""
+
+ def __init__(self, dim1: int, dim2: int):
+ """Initialize Transpose module."""
+ super().__init__()
+ self.dim1 = dim1
+ self.dim2 = dim2
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Transpose."""
+ return x.transpose(self.dim1, self.dim2)
+
+
+class DilatedDepthSeparableConv(torch.nn.Module):
+ """Dilated depth-separable conv module."""
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ layers: int,
+ dropout_rate: float = 0.0,
+ eps: float = 1e-5,
+ ):
+ """Initialize DilatedDepthSeparableConv module.
+
+ Args:
+ channels (int): Number of channels.
+ kernel_size (int): Kernel size.
+ layers (int): Number of layers.
+ dropout_rate (float): Dropout rate.
+ eps (float): Epsilon for layer norm.
+
+ """
+ super().__init__()
+
+ self.convs = torch.nn.ModuleList()
+ for i in range(layers):
+ dilation = kernel_size**i
+ padding = (kernel_size * dilation - dilation) // 2
+ self.convs += [
+ torch.nn.Sequential(
+ torch.nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ groups=channels,
+ dilation=dilation,
+ padding=padding,
+ ),
+ Transpose(1, 2),
+ torch.nn.LayerNorm(
+ channels,
+ eps=eps,
+ elementwise_affine=True,
+ ),
+ Transpose(1, 2),
+ torch.nn.GELU(),
+ torch.nn.Conv1d(
+ channels,
+ channels,
+ 1,
+ ),
+ Transpose(1, 2),
+ torch.nn.LayerNorm(
+ channels,
+ eps=eps,
+ elementwise_affine=True,
+ ),
+ Transpose(1, 2),
+ torch.nn.GELU(),
+ torch.nn.Dropout(dropout_rate),
+ )
+ ]
+
+ def forward(
+ self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, in_channels, T).
+ x_mask (Tensor): Mask tensor (B, 1, T).
+ g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+
+ Returns:
+ Tensor: Output tensor (B, channels, T).
+
+ """
+ if g is not None:
+ x = x + g
+ for f in self.convs:
+ y = f(x * x_mask)
+ x = x + y
+ return x * x_mask
+
+
+class ConvFlow(torch.nn.Module):
+ """Convolutional flow module."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ hidden_channels: int,
+ kernel_size: int,
+ layers: int,
+ bins: int = 10,
+ tail_bound: float = 5.0,
+ ):
+ """Initialize ConvFlow module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ hidden_channels (int): Number of hidden channels.
+ kernel_size (int): Kernel size.
+ layers (int): Number of layers.
+ bins (int): Number of bins.
+ tail_bound (float): Tail bound value.
+
+ """
+ super().__init__()
+ self.half_channels = in_channels // 2
+ self.hidden_channels = hidden_channels
+ self.bins = bins
+ self.tail_bound = tail_bound
+
+ self.input_conv = torch.nn.Conv1d(
+ self.half_channels,
+ hidden_channels,
+ 1,
+ )
+ self.dds_conv = DilatedDepthSeparableConv(
+ hidden_channels,
+ kernel_size,
+ layers,
+ dropout_rate=0.0,
+ )
+ self.proj = torch.nn.Conv1d(
+ hidden_channels,
+ self.half_channels * (bins * 3 - 1),
+ 1,
+ )
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ g: Optional[torch.Tensor] = None,
+ inverse: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, channels, T).
+            x_mask (Tensor): Mask tensor (B, 1, T).
+ g (Optional[Tensor]): Global conditioning tensor (B, channels, 1).
+ inverse (bool): Whether to inverse the flow.
+
+ Returns:
+ Tensor: Output tensor (B, channels, T).
+ Tensor: Log-determinant tensor for NLL (B,) if not inverse.
+
+ """
+ xa, xb = x.split(x.size(1) // 2, 1)
+ h = self.input_conv(xa)
+ h = self.dds_conv(h, x_mask, g=g)
+ h = self.proj(h) * x_mask # (B, half_channels * (bins * 3 - 1), T)
+
+ b, c, t = xa.shape
+ # (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1)
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)
+
+ # TODO(kan-bayashi): Understand this calculation
+ denom = math.sqrt(self.hidden_channels)
+ unnorm_widths = h[..., : self.bins] / denom
+ unnorm_heights = h[..., self.bins : 2 * self.bins] / denom
+ unnorm_derivatives = h[..., 2 * self.bins :]
+ xb, logdet_abs = piecewise_rational_quadratic_transform(
+ xb,
+ unnorm_widths,
+ unnorm_heights,
+ unnorm_derivatives,
+ inverse=inverse,
+ tails="linear",
+ tail_bound=self.tail_bound,
+ )
+ x = torch.cat([xa, xb], 1) * x_mask
+ logdet = torch.sum(logdet_abs * x_mask, [1, 2])
+ if not inverse:
+ return x, logdet
+ else:
+ return x
diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py
new file mode 100644
index 000000000..efb0e254c
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/generator.py
@@ -0,0 +1,531 @@
+# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Generator module in VITS.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+
+import math
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from icefall.utils import make_pad_mask
+
+from duration_predictor import StochasticDurationPredictor
+from hifigan import HiFiGANGenerator
+from posterior_encoder import PosteriorEncoder
+from residual_coupling import ResidualAffineCouplingBlock
+from text_encoder import TextEncoder
+from utils import get_random_segments
+
+
+class VITSGenerator(torch.nn.Module):
+ """Generator module in VITS, `Conditional Variational Autoencoder
+ with Adversarial Learning for End-to-End Text-to-Speech`.
+ """
+
+ def __init__(
+ self,
+ vocabs: int,
+ aux_channels: int = 513,
+ hidden_channels: int = 192,
+ spks: Optional[int] = None,
+ langs: Optional[int] = None,
+ spk_embed_dim: Optional[int] = None,
+ global_channels: int = -1,
+ segment_size: int = 32,
+ text_encoder_attention_heads: int = 2,
+ text_encoder_ffn_expand: int = 4,
+ text_encoder_cnn_module_kernel: int = 5,
+ text_encoder_blocks: int = 6,
+ text_encoder_dropout_rate: float = 0.1,
+ decoder_kernel_size: int = 7,
+ decoder_channels: int = 512,
+ decoder_upsample_scales: List[int] = [8, 8, 2, 2],
+ decoder_upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
+ decoder_resblock_kernel_sizes: List[int] = [3, 7, 11],
+ decoder_resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ use_weight_norm_in_decoder: bool = True,
+ posterior_encoder_kernel_size: int = 5,
+ posterior_encoder_layers: int = 16,
+ posterior_encoder_stacks: int = 1,
+ posterior_encoder_base_dilation: int = 1,
+ posterior_encoder_dropout_rate: float = 0.0,
+ use_weight_norm_in_posterior_encoder: bool = True,
+ flow_flows: int = 4,
+ flow_kernel_size: int = 5,
+ flow_base_dilation: int = 1,
+ flow_layers: int = 4,
+ flow_dropout_rate: float = 0.0,
+ use_weight_norm_in_flow: bool = True,
+ use_only_mean_in_flow: bool = True,
+ stochastic_duration_predictor_kernel_size: int = 3,
+ stochastic_duration_predictor_dropout_rate: float = 0.5,
+ stochastic_duration_predictor_flows: int = 4,
+ stochastic_duration_predictor_dds_conv_layers: int = 3,
+ ):
+ """Initialize VITS generator module.
+
+ Args:
+ vocabs (int): Input vocabulary size.
+ aux_channels (int): Number of acoustic feature channels.
+ hidden_channels (int): Number of hidden channels.
+ spks (Optional[int]): Number of speakers. If set to > 1, assume that the
+ sids will be provided as the input and use sid embedding layer.
+ langs (Optional[int]): Number of languages. If set to > 1, assume that the
+                lids will be provided as the input and use lid embedding layer.
+ spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0,
+ assume that spembs will be provided as the input.
+ global_channels (int): Number of global conditioning channels.
+ segment_size (int): Segment size for decoder.
+ text_encoder_attention_heads (int): Number of heads in conformer block
+ of text encoder.
+ text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block
+ of text encoder.
+ text_encoder_cnn_module_kernel (int): Convolution kernel size in text encoder.
+ text_encoder_blocks (int): Number of conformer blocks in text encoder.
+ text_encoder_dropout_rate (float): Dropout rate in conformer block of
+ text encoder.
+ decoder_kernel_size (int): Decoder kernel size.
+ decoder_channels (int): Number of decoder initial channels.
+ decoder_upsample_scales (List[int]): List of upsampling scales in decoder.
+ decoder_upsample_kernel_sizes (List[int]): List of kernel size for
+ upsampling layers in decoder.
+ decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks
+ in decoder.
+ decoder_resblock_dilations (List[List[int]]): List of list of dilations for
+ resblocks in decoder.
+ use_weight_norm_in_decoder (bool): Whether to apply weight normalization in
+ decoder.
+ posterior_encoder_kernel_size (int): Posterior encoder kernel size.
+ posterior_encoder_layers (int): Number of layers of posterior encoder.
+ posterior_encoder_stacks (int): Number of stacks of posterior encoder.
+ posterior_encoder_base_dilation (int): Base dilation of posterior encoder.
+ posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder.
+ use_weight_norm_in_posterior_encoder (bool): Whether to apply weight
+ normalization in posterior encoder.
+ flow_flows (int): Number of flows in flow.
+ flow_kernel_size (int): Kernel size in flow.
+ flow_base_dilation (int): Base dilation in flow.
+ flow_layers (int): Number of layers in flow.
+ flow_dropout_rate (float): Dropout rate in flow
+ use_weight_norm_in_flow (bool): Whether to apply weight normalization in
+ flow.
+ use_only_mean_in_flow (bool): Whether to use only mean in flow.
+ stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic
+ duration predictor.
+ stochastic_duration_predictor_dropout_rate (float): Dropout rate in
+ stochastic duration predictor.
+ stochastic_duration_predictor_flows (int): Number of flows in stochastic
+ duration predictor.
+ stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv
+ layers in stochastic duration predictor.
+
+ """
+ super().__init__()
+ self.segment_size = segment_size
+ self.text_encoder = TextEncoder(
+ vocabs=vocabs,
+ d_model=hidden_channels,
+ num_heads=text_encoder_attention_heads,
+ dim_feedforward=hidden_channels * text_encoder_ffn_expand,
+ cnn_module_kernel=text_encoder_cnn_module_kernel,
+ num_layers=text_encoder_blocks,
+ dropout=text_encoder_dropout_rate,
+ )
+ self.decoder = HiFiGANGenerator(
+ in_channels=hidden_channels,
+ out_channels=1,
+ channels=decoder_channels,
+ global_channels=global_channels,
+ kernel_size=decoder_kernel_size,
+ upsample_scales=decoder_upsample_scales,
+ upsample_kernel_sizes=decoder_upsample_kernel_sizes,
+ resblock_kernel_sizes=decoder_resblock_kernel_sizes,
+ resblock_dilations=decoder_resblock_dilations,
+ use_weight_norm=use_weight_norm_in_decoder,
+ )
+ self.posterior_encoder = PosteriorEncoder(
+ in_channels=aux_channels,
+ out_channels=hidden_channels,
+ hidden_channels=hidden_channels,
+ kernel_size=posterior_encoder_kernel_size,
+ layers=posterior_encoder_layers,
+ stacks=posterior_encoder_stacks,
+ base_dilation=posterior_encoder_base_dilation,
+ global_channels=global_channels,
+ dropout_rate=posterior_encoder_dropout_rate,
+ use_weight_norm=use_weight_norm_in_posterior_encoder,
+ )
+ self.flow = ResidualAffineCouplingBlock(
+ in_channels=hidden_channels,
+ hidden_channels=hidden_channels,
+ flows=flow_flows,
+ kernel_size=flow_kernel_size,
+ base_dilation=flow_base_dilation,
+ layers=flow_layers,
+ global_channels=global_channels,
+ dropout_rate=flow_dropout_rate,
+ use_weight_norm=use_weight_norm_in_flow,
+ use_only_mean=use_only_mean_in_flow,
+ )
+ # TODO(kan-bayashi): Add deterministic version as an option
+ self.duration_predictor = StochasticDurationPredictor(
+ channels=hidden_channels,
+ kernel_size=stochastic_duration_predictor_kernel_size,
+ dropout_rate=stochastic_duration_predictor_dropout_rate,
+ flows=stochastic_duration_predictor_flows,
+ dds_conv_layers=stochastic_duration_predictor_dds_conv_layers,
+ global_channels=global_channels,
+ )
+
+ self.upsample_factor = int(np.prod(decoder_upsample_scales))
+ self.spks = None
+ if spks is not None and spks > 1:
+ assert global_channels > 0
+ self.spks = spks
+ self.global_emb = torch.nn.Embedding(spks, global_channels)
+ self.spk_embed_dim = None
+ if spk_embed_dim is not None and spk_embed_dim > 0:
+ assert global_channels > 0
+ self.spk_embed_dim = spk_embed_dim
+ self.spemb_proj = torch.nn.Linear(spk_embed_dim, global_channels)
+ self.langs = None
+ if langs is not None and langs > 1:
+ assert global_channels > 0
+ self.langs = langs
+ self.lang_emb = torch.nn.Embedding(langs, global_channels)
+
+ # delayed import
+ from monotonic_align import maximum_path
+
+ self.maximum_path = maximum_path
+
+ def forward(
+ self,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ feats: torch.Tensor,
+ feats_lengths: torch.Tensor,
+ sids: Optional[torch.Tensor] = None,
+ spembs: Optional[torch.Tensor] = None,
+ lids: Optional[torch.Tensor] = None,
+ ) -> Tuple[
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ Tuple[
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ ],
+ ]:
+ """Calculate forward propagation.
+
+ Args:
+ text (Tensor): Text index tensor (B, T_text).
+ text_lengths (Tensor): Text length tensor (B,).
+ feats (Tensor): Feature tensor (B, aux_channels, T_feats).
+ feats_lengths (Tensor): Feature length tensor (B,).
+ sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
+ spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
+ lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
+
+ Returns:
+ Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
+ Tensor: Duration negative log-likelihood (NLL) tensor (B,).
+ Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text).
+ Tensor: Segments start index tensor (B,).
+ Tensor: Text mask tensor (B, 1, T_text).
+ Tensor: Feature mask tensor (B, 1, T_feats).
+ tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+ - Tensor: Posterior encoder hidden representation (B, H, T_feats).
+ - Tensor: Flow hidden representation (B, H, T_feats).
+ - Tensor: Expanded text encoder projected mean (B, H, T_feats).
+ - Tensor: Expanded text encoder projected scale (B, H, T_feats).
+ - Tensor: Posterior encoder projected mean (B, H, T_feats).
+ - Tensor: Posterior encoder projected scale (B, H, T_feats).
+
+ """
+ # forward text encoder
+ x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
+
+ # calculate global conditioning
+ g = None
+ if self.spks is not None:
+ # speaker one-hot vector embedding: (B, global_channels, 1)
+ g = self.global_emb(sids.view(-1)).unsqueeze(-1)
+ if self.spk_embed_dim is not None:
+ # pretreined speaker embedding, e.g., X-vector (B, global_channels, 1)
+ g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
+ if g is None:
+ g = g_
+ else:
+ g = g + g_
+ if self.langs is not None:
+ # language one-hot vector embedding: (B, global_channels, 1)
+ g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
+ if g is None:
+ g = g_
+ else:
+ g = g + g_
+
+ # forward posterior encoder
+ z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
+
+ # forward flow
+ z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats)
+
+ # monotonic alignment search
+ with torch.no_grad():
+ # negative cross-entropy
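+            # The four terms below add up to log N(z_p | m_p, exp(logs_p)) summed
+            # over the hidden dimension H, evaluated for every (t_feats, t_text)
+            # pair; maximum_path then searches for the best monotonic alignment
+            # under this score. s_p_sq_r is the reciprocal prior variance.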
+ s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text)
+ # (B, 1, T_text)
+ neg_x_ent_1 = torch.sum(
+ -0.5 * math.log(2 * math.pi) - logs_p,
+ [1],
+ keepdim=True,
+ )
+ # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
+ neg_x_ent_2 = torch.matmul(
+ -0.5 * (z_p**2).transpose(1, 2),
+ s_p_sq_r,
+ )
+ # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
+ neg_x_ent_3 = torch.matmul(
+ z_p.transpose(1, 2),
+ (m_p * s_p_sq_r),
+ )
+ # (B, 1, T_text)
+ neg_x_ent_4 = torch.sum(
+ -0.5 * (m_p**2) * s_p_sq_r,
+ [1],
+ keepdim=True,
+ )
+ # (B, T_feats, T_text)
+ neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
+ # (B, 1, T_feats, T_text)
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+ # monotonic attention weight: (B, 1, T_feats, T_text)
+ attn = (
+ self.maximum_path(
+ neg_x_ent,
+ attn_mask.squeeze(1),
+ )
+ .unsqueeze(1)
+ .detach()
+ )
+
+ # forward duration predictor
+ w = attn.sum(2) # (B, 1, T_text)
+ dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
+ dur_nll = dur_nll / torch.sum(x_mask)
+
+ # expand the length to match with the feature sequence
+ # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
+ # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
+
+ # get random segments
+ z_segments, z_start_idxs = get_random_segments(
+ z,
+ feats_lengths,
+ self.segment_size,
+ )
+
+ # forward decoder with random segments
+ wav = self.decoder(z_segments, g=g)
+
+ return (
+ wav,
+ dur_nll,
+ attn,
+ z_start_idxs,
+ x_mask,
+ y_mask,
+ (z, z_p, m_p, logs_p, m_q, logs_q),
+ )
+
+ def inference(
+ self,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ feats: Optional[torch.Tensor] = None,
+ feats_lengths: Optional[torch.Tensor] = None,
+ sids: Optional[torch.Tensor] = None,
+ spembs: Optional[torch.Tensor] = None,
+ lids: Optional[torch.Tensor] = None,
+ dur: Optional[torch.Tensor] = None,
+ noise_scale: float = 0.667,
+ noise_scale_dur: float = 0.8,
+ alpha: float = 1.0,
+ max_len: Optional[int] = None,
+ use_teacher_forcing: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Run inference.
+
+ Args:
+ text (Tensor): Input text index tensor (B, T_text,).
+ text_lengths (Tensor): Text length tensor (B,).
+ feats (Tensor): Feature tensor (B, aux_channels, T_feats,).
+ feats_lengths (Tensor): Feature length tensor (B,).
+ sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
+ spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
+ lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
+ dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided,
+ skip the prediction of durations (i.e., teacher forcing).
+ noise_scale (float): Noise scale parameter for flow.
+ noise_scale_dur (float): Noise scale parameter for duration predictor.
+ alpha (float): Alpha parameter to control the speed of generated speech.
+ max_len (Optional[int]): Maximum length of acoustic feature sequence.
+ use_teacher_forcing (bool): Whether to use teacher forcing.
+
+ Returns:
+ Tensor: Generated waveform tensor (B, T_wav).
+ Tensor: Monotonic attention weight tensor (B, T_feats, T_text).
+ Tensor: Duration tensor (B, T_text).
+
+ """
+ # encoder
+ x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
+ x_mask = x_mask.to(x.dtype)
+ g = None
+ if self.spks is not None:
+ # (B, global_channels, 1)
+ g = self.global_emb(sids.view(-1)).unsqueeze(-1)
+ if self.spk_embed_dim is not None:
+ # (B, global_channels, 1)
+ g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
+ if g is None:
+ g = g_
+ else:
+ g = g + g_
+ if self.langs is not None:
+ # (B, global_channels, 1)
+ g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
+ if g is None:
+ g = g_
+ else:
+ g = g + g_
+
+ if use_teacher_forcing:
+ # forward posterior encoder
+ z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
+
+ # forward flow
+ z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats)
+
+ # monotonic alignment search
+ s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text)
+ # (B, 1, T_text)
+ neg_x_ent_1 = torch.sum(
+ -0.5 * math.log(2 * math.pi) - logs_p,
+ [1],
+ keepdim=True,
+ )
+ # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
+ neg_x_ent_2 = torch.matmul(
+ -0.5 * (z_p**2).transpose(1, 2),
+ s_p_sq_r,
+ )
+ # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
+ neg_x_ent_3 = torch.matmul(
+ z_p.transpose(1, 2),
+ (m_p * s_p_sq_r),
+ )
+ # (B, 1, T_text)
+ neg_x_ent_4 = torch.sum(
+ -0.5 * (m_p**2) * s_p_sq_r,
+ [1],
+ keepdim=True,
+ )
+ # (B, T_feats, T_text)
+ neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
+ # (B, 1, T_feats, T_text)
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+ # monotonic attention weight: (B, 1, T_feats, T_text)
+ attn = self.maximum_path(
+ neg_x_ent,
+ attn_mask.squeeze(1),
+ ).unsqueeze(1)
+ dur = attn.sum(2) # (B, 1, T_text)
+
+            # forward decoder with the whole sequence (no random segments at inference)
+ wav = self.decoder(z * y_mask, g=g)
+ else:
+ # duration
+ if dur is None:
+ logw = self.duration_predictor(
+ x,
+ x_mask,
+ g=g,
+ inverse=True,
+ noise_scale=noise_scale_dur,
+ )
+ w = torch.exp(logw) * x_mask * alpha
+ dur = torch.ceil(w)
+ y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long()
+ y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device)
+ y_mask = y_mask.to(x.dtype)
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+ attn = self._generate_path(dur, attn_mask)
+
+ # expand the length to match with the feature sequence
+ # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
+ m_p = torch.matmul(
+ attn.squeeze(1),
+ m_p.transpose(1, 2),
+ ).transpose(1, 2)
+ # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
+ logs_p = torch.matmul(
+ attn.squeeze(1),
+ logs_p.transpose(1, 2),
+ ).transpose(1, 2)
+
+ # decoder
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+ z = self.flow(z_p, y_mask, g=g, inverse=True)
+ wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)
+
+ return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)
+
+ def _generate_path(self, dur: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ """Generate path a.k.a. monotonic attention.
+
+ Args:
+ dur (Tensor): Duration tensor (B, 1, T_text).
+ mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text).
+
+ Returns:
+ Tensor: Path tensor (B, 1, T_feats, T_text).
+
+ """
+ b, _, t_y, t_x = mask.shape
+ cum_dur = torch.cumsum(dur, -1)
+ cum_dur_flat = cum_dur.view(b * t_x)
+ path = torch.arange(t_y, dtype=dur.dtype, device=dur.device)
+ path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)
+ # path = path.view(b, t_x, t_y).to(dtype=mask.dtype)
+ path = path.view(b, t_x, t_y).to(dtype=torch.float)
+ # path will be like (t_x = 3, t_y = 5):
+ # [[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.],
+ # [1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.],
+ # [1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]]
+ path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
+ # path = path.to(dtype=mask.dtype)
+ return path.unsqueeze(1).transpose(2, 3) * mask
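+
+
+if __name__ == "__main__":
+    # A minimal standalone sketch of the duration-to-path expansion performed by
+    # `_generate_path` above; the durations below are arbitrary assumptions.
+    dur = torch.tensor([[[2.0, 2.0, 1.0]]])  # (B=1, 1, T_text=3)
+    t_feats = int(dur.sum().item())  # 5 frames in total
+    cum_dur = torch.cumsum(dur, -1).view(-1)  # [2., 4., 5.]
+    path = (torch.arange(t_feats).unsqueeze(0) < cum_dur.unsqueeze(1)).float()
+    # differencing along the text axis turns cumulative masks into one-hot rows
+    path = path - F.pad(path, [0, 0, 1, 0])[:-1]
+    print(path)  # rows are text tokens, columns are feature frames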
diff --git a/egs/ljspeech/TTS/vits/hifigan.py b/egs/ljspeech/TTS/vits/hifigan.py
new file mode 100644
index 000000000..589ac30f6
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/hifigan.py
@@ -0,0 +1,933 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""HiFi-GAN Modules.
+
+This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
+
+"""
+
+import copy
+import logging
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+class HiFiGANGenerator(torch.nn.Module):
+ """HiFiGAN generator module."""
+
+ def __init__(
+ self,
+ in_channels: int = 80,
+ out_channels: int = 1,
+ channels: int = 512,
+ global_channels: int = -1,
+ kernel_size: int = 7,
+ upsample_scales: List[int] = [8, 8, 2, 2],
+ upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
+ resblock_kernel_sizes: List[int] = [3, 7, 11],
+ resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ use_additional_convs: bool = True,
+ bias: bool = True,
+ nonlinear_activation: str = "LeakyReLU",
+ nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
+ use_weight_norm: bool = True,
+ ):
+ """Initialize HiFiGANGenerator module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ channels (int): Number of hidden representation channels.
+ global_channels (int): Number of global conditioning channels.
+ kernel_size (int): Kernel size of initial and final conv layer.
+ upsample_scales (List[int]): List of upsampling scales.
+ upsample_kernel_sizes (List[int]): List of kernel sizes for upsample layers.
+ resblock_kernel_sizes (List[int]): List of kernel sizes for residual blocks.
+ resblock_dilations (List[List[int]]): List of list of dilations for residual
+ blocks.
+ use_additional_convs (bool): Whether to use additional conv layers in
+ residual blocks.
+ bias (bool): Whether to add bias parameter in convolution layers.
+ nonlinear_activation (str): Activation function module name.
+ nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
+ function.
+ use_weight_norm (bool): Whether to use weight norm. If set to true, it will
+ be applied to all of the conv layers.
+
+ """
+ super().__init__()
+
+ # check hyperparameters are valid
+ assert kernel_size % 2 == 1, "Kernel size must be odd number."
+ assert len(upsample_scales) == len(upsample_kernel_sizes)
+ assert len(resblock_dilations) == len(resblock_kernel_sizes)
+
+ # define modules
+ self.upsample_factor = int(np.prod(upsample_scales) * out_channels)
+ self.num_upsamples = len(upsample_kernel_sizes)
+ self.num_blocks = len(resblock_kernel_sizes)
+ self.input_conv = torch.nn.Conv1d(
+ in_channels,
+ channels,
+ kernel_size,
+ 1,
+ padding=(kernel_size - 1) // 2,
+ )
+ self.upsamples = torch.nn.ModuleList()
+ self.blocks = torch.nn.ModuleList()
+ for i in range(len(upsample_kernel_sizes)):
+ assert upsample_kernel_sizes[i] == 2 * upsample_scales[i]
+ self.upsamples += [
+ torch.nn.Sequential(
+ getattr(torch.nn, nonlinear_activation)(
+ **nonlinear_activation_params
+ ),
+ torch.nn.ConvTranspose1d(
+ channels // (2**i),
+ channels // (2 ** (i + 1)),
+ upsample_kernel_sizes[i],
+ upsample_scales[i],
+ padding=upsample_scales[i] // 2 + upsample_scales[i] % 2,
+ output_padding=upsample_scales[i] % 2,
+ ),
+ )
+ ]
+ for j in range(len(resblock_kernel_sizes)):
+ self.blocks += [
+ ResidualBlock(
+ kernel_size=resblock_kernel_sizes[j],
+ channels=channels // (2 ** (i + 1)),
+ dilations=resblock_dilations[j],
+ bias=bias,
+ use_additional_convs=use_additional_convs,
+ nonlinear_activation=nonlinear_activation,
+ nonlinear_activation_params=nonlinear_activation_params,
+ )
+ ]
+ self.output_conv = torch.nn.Sequential(
+ # NOTE(kan-bayashi): follow official implementation but why
+ # using different slope parameter here? (0.1 vs. 0.01)
+ torch.nn.LeakyReLU(),
+ torch.nn.Conv1d(
+ channels // (2 ** (i + 1)),
+ out_channels,
+ kernel_size,
+ 1,
+ padding=(kernel_size - 1) // 2,
+ ),
+ torch.nn.Tanh(),
+ )
+ if global_channels > 0:
+ self.global_conv = torch.nn.Conv1d(global_channels, channels, 1)
+
+ # apply weight norm
+ if use_weight_norm:
+ self.apply_weight_norm()
+
+ # reset parameters
+ self.reset_parameters()
+
+ def forward(
+ self, c: torch.Tensor, g: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ c (Tensor): Input tensor (B, in_channels, T).
+ g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+
+ Returns:
+ Tensor: Output tensor (B, out_channels, T).
+
+ """
+ c = self.input_conv(c)
+ if g is not None:
+ c = c + self.global_conv(g)
+ for i in range(self.num_upsamples):
+ c = self.upsamples[i](c)
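+            # multi-receptive-field fusion: average the outputs of the parallel
+            # residual blocks attached to this upsampling stage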
+ cs = 0.0 # initialize
+ for j in range(self.num_blocks):
+ cs += self.blocks[i * self.num_blocks + j](c)
+ c = cs / self.num_blocks
+ c = self.output_conv(c)
+
+ return c
+
+ def reset_parameters(self):
+ """Reset parameters.
+
+ This initialization follows the official implementation manner.
+ https://github.com/jik876/hifi-gan/blob/master/models.py
+
+ """
+
+ def _reset_parameters(m: torch.nn.Module):
+ if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
+ m.weight.data.normal_(0.0, 0.01)
+ logging.debug(f"Reset parameters in {m}.")
+
+ self.apply(_reset_parameters)
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m: torch.nn.Module):
+ try:
+ logging.debug(f"Weight norm is removed from {m}.")
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m: torch.nn.Module):
+ if isinstance(m, torch.nn.Conv1d) or isinstance(
+ m, torch.nn.ConvTranspose1d
+ ):
+ torch.nn.utils.weight_norm(m)
+ logging.debug(f"Weight norm is applied to {m}.")
+
+ self.apply(_apply_weight_norm)
+
+ def inference(
+ self, c: torch.Tensor, g: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """Perform inference.
+
+ Args:
+ c (torch.Tensor): Input tensor (T, in_channels).
+ g (Optional[Tensor]): Global conditioning tensor (global_channels, 1).
+
+ Returns:
+            Tensor: Output tensor (T * upsample_factor, out_channels).
+
+ """
+ if g is not None:
+ g = g.unsqueeze(0)
+ c = self.forward(c.transpose(1, 0).unsqueeze(0), g=g)
+ return c.squeeze(0).transpose(1, 0)
+
+
+class ResidualBlock(torch.nn.Module):
+ """Residual block module in HiFiGAN."""
+
+ def __init__(
+ self,
+ kernel_size: int = 3,
+ channels: int = 512,
+ dilations: List[int] = [1, 3, 5],
+ bias: bool = True,
+ use_additional_convs: bool = True,
+ nonlinear_activation: str = "LeakyReLU",
+ nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
+ ):
+ """Initialize ResidualBlock module.
+
+ Args:
+ kernel_size (int): Kernel size of dilation convolution layer.
+ channels (int): Number of channels for convolution layer.
+ dilations (List[int]): List of dilation factors.
+ use_additional_convs (bool): Whether to use additional convolution layers.
+ bias (bool): Whether to add bias parameter in convolution layers.
+ nonlinear_activation (str): Activation function module name.
+ nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
+ function.
+
+ """
+ super().__init__()
+ self.use_additional_convs = use_additional_convs
+ self.convs1 = torch.nn.ModuleList()
+ if use_additional_convs:
+ self.convs2 = torch.nn.ModuleList()
+ assert kernel_size % 2 == 1, "Kernel size must be odd number."
+ for dilation in dilations:
+ self.convs1 += [
+ torch.nn.Sequential(
+ getattr(torch.nn, nonlinear_activation)(
+ **nonlinear_activation_params
+ ),
+ torch.nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation,
+ bias=bias,
+ padding=(kernel_size - 1) // 2 * dilation,
+ ),
+ )
+ ]
+ if use_additional_convs:
+ self.convs2 += [
+ torch.nn.Sequential(
+ getattr(torch.nn, nonlinear_activation)(
+ **nonlinear_activation_params
+ ),
+ torch.nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ bias=bias,
+ padding=(kernel_size - 1) // 2,
+ ),
+ )
+ ]
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, channels, T).
+
+ Returns:
+ Tensor: Output tensor (B, channels, T).
+
+ """
+ for idx in range(len(self.convs1)):
+ xt = self.convs1[idx](x)
+ if self.use_additional_convs:
+ xt = self.convs2[idx](xt)
+ x = xt + x
+ return x
+
+
+class HiFiGANPeriodDiscriminator(torch.nn.Module):
+ """HiFiGAN period discriminator module."""
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ out_channels: int = 1,
+ period: int = 3,
+ kernel_sizes: List[int] = [5, 3],
+ channels: int = 32,
+ downsample_scales: List[int] = [3, 3, 3, 3, 1],
+ max_downsample_channels: int = 1024,
+ bias: bool = True,
+ nonlinear_activation: str = "LeakyReLU",
+ nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
+ use_weight_norm: bool = True,
+ use_spectral_norm: bool = False,
+ ):
+ """Initialize HiFiGANPeriodDiscriminator module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ period (int): Period.
+ kernel_sizes (list): Kernel sizes of initial conv layers and the final conv
+ layer.
+ channels (int): Number of initial channels.
+ downsample_scales (List[int]): List of downsampling scales.
+ max_downsample_channels (int): Number of maximum downsampling channels.
+ bias (bool): Whether to add bias parameter in convolution layers.
+ nonlinear_activation (str): Activation function module name.
+ nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
+ function.
+ use_weight_norm (bool): Whether to use weight norm.
+ If set to true, it will be applied to all of the conv layers.
+ use_spectral_norm (bool): Whether to use spectral norm.
+ If set to true, it will be applied to all of the conv layers.
+
+ """
+ super().__init__()
+ assert len(kernel_sizes) == 2
+ assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number."
+ assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number."
+
+ self.period = period
+ self.convs = torch.nn.ModuleList()
+ in_chs = in_channels
+ out_chs = channels
+ for downsample_scale in downsample_scales:
+ self.convs += [
+ torch.nn.Sequential(
+ torch.nn.Conv2d(
+ in_chs,
+ out_chs,
+ (kernel_sizes[0], 1),
+ (downsample_scale, 1),
+ padding=((kernel_sizes[0] - 1) // 2, 0),
+ ),
+ getattr(torch.nn, nonlinear_activation)(
+ **nonlinear_activation_params
+ ),
+ )
+ ]
+ in_chs = out_chs
+ # NOTE(kan-bayashi): Use downsample_scale + 1?
+ out_chs = min(out_chs * 4, max_downsample_channels)
+ self.output_conv = torch.nn.Conv2d(
+ out_chs,
+ out_channels,
+ (kernel_sizes[1] - 1, 1),
+ 1,
+ padding=((kernel_sizes[1] - 1) // 2, 0),
+ )
+
+ if use_weight_norm and use_spectral_norm:
+            raise ValueError("Use either use_weight_norm or use_spectral_norm, not both.")
+
+ # apply weight norm
+ if use_weight_norm:
+ self.apply_weight_norm()
+
+ # apply spectral norm
+ if use_spectral_norm:
+ self.apply_spectral_norm()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+            x (Tensor): Input tensor (B, in_channels, T).
+
+        Returns:
+            List[Tensor]: List of each layer's output tensors.
+
+ """
+ # transform 1d to 2d -> (B, C, T/P, P)
+ b, c, t = x.shape
+ if t % self.period != 0:
+ n_pad = self.period - (t % self.period)
+ x = F.pad(x, (0, n_pad), "reflect")
+ t += n_pad
+ x = x.view(b, c, t // self.period, self.period)
+
+ # forward conv
+ outs = []
+ for layer in self.convs:
+ x = layer(x)
+ outs += [x]
+ x = self.output_conv(x)
+ x = torch.flatten(x, 1, -1)
+ outs += [x]
+
+ return outs
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m: torch.nn.Module):
+ if isinstance(m, torch.nn.Conv2d):
+ torch.nn.utils.weight_norm(m)
+ logging.debug(f"Weight norm is applied to {m}.")
+
+ self.apply(_apply_weight_norm)
+
+ def apply_spectral_norm(self):
+ """Apply spectral normalization module from all of the layers."""
+
+ def _apply_spectral_norm(m: torch.nn.Module):
+ if isinstance(m, torch.nn.Conv2d):
+ torch.nn.utils.spectral_norm(m)
+ logging.debug(f"Spectral norm is applied to {m}.")
+
+ self.apply(_apply_spectral_norm)
+
+
+class HiFiGANMultiPeriodDiscriminator(torch.nn.Module):
+ """HiFiGAN multi-period discriminator module."""
+
+ def __init__(
+ self,
+ periods: List[int] = [2, 3, 5, 7, 11],
+ discriminator_params: Dict[str, Any] = {
+ "in_channels": 1,
+ "out_channels": 1,
+ "kernel_sizes": [5, 3],
+ "channels": 32,
+ "downsample_scales": [3, 3, 3, 3, 1],
+ "max_downsample_channels": 1024,
+ "bias": True,
+ "nonlinear_activation": "LeakyReLU",
+ "nonlinear_activation_params": {"negative_slope": 0.1},
+ "use_weight_norm": True,
+ "use_spectral_norm": False,
+ },
+ ):
+ """Initialize HiFiGANMultiPeriodDiscriminator module.
+
+ Args:
+ periods (List[int]): List of periods.
+ discriminator_params (Dict[str, Any]): Parameters for hifi-gan period
+ discriminator module. The period parameter will be overwritten.
+
+ """
+ super().__init__()
+ self.discriminators = torch.nn.ModuleList()
+ for period in periods:
+ params = copy.deepcopy(discriminator_params)
+ params["period"] = period
+ self.discriminators += [HiFiGANPeriodDiscriminator(**params)]
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T).
+
+ Returns:
+ List: List of list of each discriminator outputs, which consists of each
+ layer output tensors.
+
+ """
+ outs = []
+ for f in self.discriminators:
+ outs += [f(x)]
+
+ return outs
+
+
+class HiFiGANScaleDiscriminator(torch.nn.Module):
+ """HiFi-GAN scale discriminator module."""
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ out_channels: int = 1,
+ kernel_sizes: List[int] = [15, 41, 5, 3],
+ channels: int = 128,
+ max_downsample_channels: int = 1024,
+ max_groups: int = 16,
+        bias: bool = True,
+ downsample_scales: List[int] = [2, 2, 4, 4, 1],
+ nonlinear_activation: str = "LeakyReLU",
+ nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
+ use_weight_norm: bool = True,
+ use_spectral_norm: bool = False,
+ ):
+ """Initilize HiFiGAN scale discriminator module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+            kernel_sizes (List[int]): List of four kernel sizes. The first is used for
+                the initial conv layer, the second for the downsampling conv layers,
+                and the remaining two for the final two output layers.
+            channels (int): Initial number of channels for conv layer.
+            max_downsample_channels (int): Maximum number of channels for downsampling
+                layers.
+            max_groups (int): Maximum number of groups used in the grouped downsampling
+                conv layers.
+ bias (bool): Whether to add bias parameter in convolution layers.
+ downsample_scales (List[int]): List of downsampling scales.
+ nonlinear_activation (str): Activation function module name.
+ nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
+ function.
+ use_weight_norm (bool): Whether to use weight norm. If set to true, it will
+ be applied to all of the conv layers.
+ use_spectral_norm (bool): Whether to use spectral norm. If set to true, it
+ will be applied to all of the conv layers.
+
+ """
+ super().__init__()
+ self.layers = torch.nn.ModuleList()
+
+ # check kernel size is valid
+ assert len(kernel_sizes) == 4
+ for ks in kernel_sizes:
+ assert ks % 2 == 1
+
+ # add first layer
+ self.layers += [
+ torch.nn.Sequential(
+ torch.nn.Conv1d(
+ in_channels,
+ channels,
+ # NOTE(kan-bayashi): Use always the same kernel size
+ kernel_sizes[0],
+ bias=bias,
+ padding=(kernel_sizes[0] - 1) // 2,
+ ),
+ getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+ )
+ ]
+
+ # add downsample layers
+ in_chs = channels
+ out_chs = channels
+ # NOTE(kan-bayashi): Remove hard coding?
+ groups = 4
+ for downsample_scale in downsample_scales:
+ self.layers += [
+ torch.nn.Sequential(
+ torch.nn.Conv1d(
+ in_chs,
+ out_chs,
+ kernel_size=kernel_sizes[1],
+ stride=downsample_scale,
+ padding=(kernel_sizes[1] - 1) // 2,
+ groups=groups,
+ bias=bias,
+ ),
+ getattr(torch.nn, nonlinear_activation)(
+ **nonlinear_activation_params
+ ),
+ )
+ ]
+ in_chs = out_chs
+ # NOTE(kan-bayashi): Remove hard coding?
+ out_chs = min(in_chs * 2, max_downsample_channels)
+ # NOTE(kan-bayashi): Remove hard coding?
+ groups = min(groups * 4, max_groups)
+
+ # add final layers
+ out_chs = min(in_chs * 2, max_downsample_channels)
+ self.layers += [
+ torch.nn.Sequential(
+ torch.nn.Conv1d(
+ in_chs,
+ out_chs,
+ kernel_size=kernel_sizes[2],
+ stride=1,
+ padding=(kernel_sizes[2] - 1) // 2,
+ bias=bias,
+ ),
+ getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+ )
+ ]
+ self.layers += [
+ torch.nn.Conv1d(
+ out_chs,
+ out_channels,
+ kernel_size=kernel_sizes[3],
+ stride=1,
+ padding=(kernel_sizes[3] - 1) // 2,
+ bias=bias,
+ ),
+ ]
+
+ if use_weight_norm and use_spectral_norm:
+            raise ValueError("Use either use_weight_norm or use_spectral_norm, not both.")
+
+ # apply weight norm
+ self.use_weight_norm = use_weight_norm
+ if use_weight_norm:
+ self.apply_weight_norm()
+
+ # apply spectral norm
+ self.use_spectral_norm = use_spectral_norm
+ if use_spectral_norm:
+ self.apply_spectral_norm()
+
+ # backward compatibility
+ self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T).
+
+ Returns:
+ List[Tensor]: List of output tensors of each layer.
+
+ """
+ outs = []
+ for f in self.layers:
+ x = f(x)
+ outs += [x]
+
+ return outs
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m: torch.nn.Module):
+ if isinstance(m, torch.nn.Conv1d):
+ torch.nn.utils.weight_norm(m)
+ logging.debug(f"Weight norm is applied to {m}.")
+
+ self.apply(_apply_weight_norm)
+
+ def apply_spectral_norm(self):
+ """Apply spectral normalization module from all of the layers."""
+
+ def _apply_spectral_norm(m: torch.nn.Module):
+ if isinstance(m, torch.nn.Conv1d):
+ torch.nn.utils.spectral_norm(m)
+ logging.debug(f"Spectral norm is applied to {m}.")
+
+ self.apply(_apply_spectral_norm)
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m):
+ try:
+ logging.debug(f"Weight norm is removed from {m}.")
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def remove_spectral_norm(self):
+ """Remove spectral normalization module from all of the layers."""
+
+ def _remove_spectral_norm(m):
+ try:
+ logging.debug(f"Spectral norm is removed from {m}.")
+ torch.nn.utils.remove_spectral_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_spectral_norm)
+
+ def _load_state_dict_pre_hook(
+ self,
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ ):
+ """Fix the compatibility of weight / spectral normalization issue.
+
+ Some pretrained models are trained with configs that use weight / spectral
+ normalization, but actually, the norm is not applied. This causes the mismatch
+ of the parameters with configs. To solve this issue, when parameter mismatch
+ happens in loading pretrained model, we remove the norm from the current model.
+
+ See also:
+ - https://github.com/espnet/espnet/pull/5240
+ - https://github.com/espnet/espnet/pull/5249
+ - https://github.com/kan-bayashi/ParallelWaveGAN/pull/409
+
+ """
+ current_module_keys = [x for x in state_dict.keys() if x.startswith(prefix)]
+ if self.use_weight_norm and any(
+ [k.endswith("weight") for k in current_module_keys]
+ ):
+ logging.warning(
+ "It seems weight norm is not applied in the pretrained model but the"
+ " current model uses it. To keep the compatibility, we remove the norm"
+ " from the current model. This may cause unexpected behavior due to the"
+ " parameter mismatch in finetuning. To avoid this issue, please change"
+ " the following parameters in config to false:\n"
+ " - discriminator_params.follow_official_norm\n"
+ " - discriminator_params.scale_discriminator_params.use_weight_norm\n"
+ " - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
+ "\n"
+ "See also:\n"
+ " - https://github.com/espnet/espnet/pull/5240\n"
+ " - https://github.com/espnet/espnet/pull/5249"
+ )
+ self.remove_weight_norm()
+ self.use_weight_norm = False
+ for k in current_module_keys:
+ if k.endswith("weight_g") or k.endswith("weight_v"):
+ del state_dict[k]
+
+ if self.use_spectral_norm and any(
+ [k.endswith("weight") for k in current_module_keys]
+ ):
+ logging.warning(
+ "It seems spectral norm is not applied in the pretrained model but the"
+ " current model uses it. To keep the compatibility, we remove the norm"
+ " from the current model. This may cause unexpected behavior due to the"
+ " parameter mismatch in finetuning. To avoid this issue, please change"
+ " the following parameters in config to false:\n"
+ " - discriminator_params.follow_official_norm\n"
+ " - discriminator_params.scale_discriminator_params.use_weight_norm\n"
+ " - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
+ "\n"
+ "See also:\n"
+ " - https://github.com/espnet/espnet/pull/5240\n"
+ " - https://github.com/espnet/espnet/pull/5249"
+ )
+ self.remove_spectral_norm()
+ self.use_spectral_norm = False
+ for k in current_module_keys:
+ if (
+ k.endswith("weight_u")
+ or k.endswith("weight_v")
+ or k.endswith("weight_orig")
+ ):
+ del state_dict[k]
+
+
+class HiFiGANMultiScaleDiscriminator(torch.nn.Module):
+ """HiFi-GAN multi-scale discriminator module."""
+
+ def __init__(
+ self,
+ scales: int = 3,
+ downsample_pooling: str = "AvgPool1d",
+ # follow the official implementation setting
+ downsample_pooling_params: Dict[str, Any] = {
+ "kernel_size": 4,
+ "stride": 2,
+ "padding": 2,
+ },
+ discriminator_params: Dict[str, Any] = {
+ "in_channels": 1,
+ "out_channels": 1,
+ "kernel_sizes": [15, 41, 5, 3],
+ "channels": 128,
+ "max_downsample_channels": 1024,
+ "max_groups": 16,
+ "bias": True,
+ "downsample_scales": [2, 2, 4, 4, 1],
+ "nonlinear_activation": "LeakyReLU",
+ "nonlinear_activation_params": {"negative_slope": 0.1},
+ },
+ follow_official_norm: bool = False,
+ ):
+ """Initilize HiFiGAN multi-scale discriminator module.
+
+ Args:
+ scales (int): Number of multi-scales.
+ downsample_pooling (str): Pooling module name for downsampling of the
+ inputs.
+ downsample_pooling_params (Dict[str, Any]): Parameters for the above pooling
+ module.
+ discriminator_params (Dict[str, Any]): Parameters for hifi-gan scale
+ discriminator module.
+ follow_official_norm (bool): Whether to follow the norm setting of the
+                official implementation. The first discriminator uses spectral norm
+ and the other discriminators use weight norm.
+
+ """
+ super().__init__()
+ self.discriminators = torch.nn.ModuleList()
+
+ # add discriminators
+ for i in range(scales):
+ params = copy.deepcopy(discriminator_params)
+ if follow_official_norm:
+ if i == 0:
+ params["use_weight_norm"] = False
+ params["use_spectral_norm"] = True
+ else:
+ params["use_weight_norm"] = True
+ params["use_spectral_norm"] = False
+ self.discriminators += [HiFiGANScaleDiscriminator(**params)]
+ self.pooling = None
+ if scales > 1:
+ self.pooling = getattr(torch.nn, downsample_pooling)(
+ **downsample_pooling_params
+ )
+
+ def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T).
+
+ Returns:
+ List[List[torch.Tensor]]: List of list of each discriminator outputs,
+                which consists of each layer's output tensors.
+
+ """
+ outs = []
+ for f in self.discriminators:
+ outs += [f(x)]
+ if self.pooling is not None:
+ x = self.pooling(x)
+
+ return outs
+
+
+class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module):
+ """HiFi-GAN multi-scale + multi-period discriminator module."""
+
+ def __init__(
+ self,
+ # Multi-scale discriminator related
+ scales: int = 3,
+ scale_downsample_pooling: str = "AvgPool1d",
+ scale_downsample_pooling_params: Dict[str, Any] = {
+ "kernel_size": 4,
+ "stride": 2,
+ "padding": 2,
+ },
+ scale_discriminator_params: Dict[str, Any] = {
+ "in_channels": 1,
+ "out_channels": 1,
+ "kernel_sizes": [15, 41, 5, 3],
+ "channels": 128,
+ "max_downsample_channels": 1024,
+ "max_groups": 16,
+ "bias": True,
+ "downsample_scales": [2, 2, 4, 4, 1],
+ "nonlinear_activation": "LeakyReLU",
+ "nonlinear_activation_params": {"negative_slope": 0.1},
+ },
+ follow_official_norm: bool = True,
+ # Multi-period discriminator related
+ periods: List[int] = [2, 3, 5, 7, 11],
+ period_discriminator_params: Dict[str, Any] = {
+ "in_channels": 1,
+ "out_channels": 1,
+ "kernel_sizes": [5, 3],
+ "channels": 32,
+ "downsample_scales": [3, 3, 3, 3, 1],
+ "max_downsample_channels": 1024,
+ "bias": True,
+ "nonlinear_activation": "LeakyReLU",
+ "nonlinear_activation_params": {"negative_slope": 0.1},
+ "use_weight_norm": True,
+ "use_spectral_norm": False,
+ },
+ ):
+ """Initilize HiFiGAN multi-scale + multi-period discriminator module.
+
+ Args:
+ scales (int): Number of multi-scales.
+ scale_downsample_pooling (str): Pooling module name for downsampling of the
+ inputs.
+ scale_downsample_pooling_params (dict): Parameters for the above pooling
+ module.
+ scale_discriminator_params (dict): Parameters for hifi-gan scale
+ discriminator module.
+ follow_official_norm (bool): Whether to follow the norm setting of the
+                official implementation. The first discriminator uses spectral norm and
+ the other discriminators use weight norm.
+ periods (list): List of periods.
+ period_discriminator_params (dict): Parameters for hifi-gan period
+ discriminator module. The period parameter will be overwritten.
+
+ """
+ super().__init__()
+ self.msd = HiFiGANMultiScaleDiscriminator(
+ scales=scales,
+ downsample_pooling=scale_downsample_pooling,
+ downsample_pooling_params=scale_downsample_pooling_params,
+ discriminator_params=scale_discriminator_params,
+ follow_official_norm=follow_official_norm,
+ )
+ self.mpd = HiFiGANMultiPeriodDiscriminator(
+ periods=periods,
+ discriminator_params=period_discriminator_params,
+ )
+
+ def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T).
+
+ Returns:
+ List[List[Tensor]]: List of list of each discriminator outputs,
+ which consists of each layer output tensors. Multi scale and
+ multi period ones are concatenated.
+
+ """
+ msd_outs = self.msd(x)
+ mpd_outs = self.mpd(x)
+ return msd_outs + mpd_outs
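+
+
+if __name__ == "__main__":
+    # A minimal smoke-test sketch; the batch size, mel dimension and length
+    # below are arbitrary assumptions.
+    generator = HiFiGANGenerator(in_channels=80, channels=128)
+    discriminator = HiFiGANMultiScaleMultiPeriodDiscriminator()
+    mel = torch.randn(2, 80, 32)
+    wav = generator(mel)  # (2, 1, 32 * 8 * 8 * 2 * 2)
+    outs = discriminator(wav)  # per-discriminator lists of layer outputs
+    print(wav.shape, len(outs))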
diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py
new file mode 100755
index 000000000..91a35e360
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/infer.py
@@ -0,0 +1,233 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script performs model inference on the test set.
+
+Usage:
+./vits/infer.py \
+ --epoch 1000 \
+ --exp-dir ./vits/exp \
+ --max-duration 500
+"""
+
+
+import argparse
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import List
+
+import k2
+import torch
+import torch.nn as nn
+import torchaudio
+
+from train import get_model, get_params
+from tokenizer import Tokenizer
+
+from icefall.checkpoint import load_checkpoint
+from icefall.utils import AttributeDict, setup_logger
+from tts_datamodule import LJSpeechTtsDataModule
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=1000,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="vits/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ return parser
+
+
+def infer_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ tokenizer: Tokenizer,
+) -> None:
+ """Decode dataset.
+ The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ tokenizer:
+ Used to convert text to phonemes.
+ """
+    # Background worker that saves audio files to disk.
+ def _save_worker(
+ batch_size: int,
+ cut_ids: List[str],
+ audio: torch.Tensor,
+ audio_pred: torch.Tensor,
+ audio_lens: List[int],
+ audio_lens_pred: List[int],
+ ):
+ for i in range(batch_size):
+ torchaudio.save(
+ str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
+ audio[i:i + 1, :audio_lens[i]],
+ sample_rate=params.sampling_rate,
+ )
+ torchaudio.save(
+ str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
+ audio_pred[i:i + 1, :audio_lens_pred[i]],
+ sample_rate=params.sampling_rate,
+ )
+
+ device = next(model.parameters()).device
+ num_cuts = 0
+ log_interval = 5
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ futures = []
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ for batch_idx, batch in enumerate(dl):
+ batch_size = len(batch["tokens"])
+
+ tokens = batch["tokens"]
+ tokens = tokenizer.tokens_to_token_ids(tokens)
+ tokens = k2.RaggedTensor(tokens)
+ row_splits = tokens.shape.row_splits(1)
+ tokens_lens = row_splits[1:] - row_splits[:-1]
+ tokens = tokens.to(device)
+ tokens_lens = tokens_lens.to(device)
+ # tensor of shape (B, T)
+ tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+
+ audio = batch["audio"]
+ audio_lens = batch["audio_lens"].tolist()
+ cut_ids = [cut.id for cut in batch["cut"]]
+
+ audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens)
+ audio_pred = audio_pred.detach().cpu()
+ # convert to samples
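+            # (durations are counted in frames; frame_shift is the hop size in
+            # samples, so the product is each generated waveform length in samples)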
+ audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
+
+ futures.append(
+ executor.submit(
+ _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred
+ )
+ )
+
+ num_cuts += batch_size
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    # wait for all background save tasks to complete
+ for f in futures:
+ f.result()
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LJSpeechTtsDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ params.suffix = f"epoch-{params.epoch}"
+
+ params.res_dir = params.exp_dir / "infer" / params.suffix
+ params.save_wav_dir = params.res_dir / "wav"
+ params.save_wav_dir.mkdir(parents=True, exist_ok=True)
+
+ setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
+ logging.info("Infer started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ tokenizer = Tokenizer(params.tokens)
+ params.blank_id = tokenizer.blank_id
+ params.oov_id = tokenizer.oov_id
+ params.vocab_size = tokenizer.vocab_size
+
+ logging.info(f"Device: {device}")
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+
+ model.to(device)
+ model.eval()
+
+ num_param_g = sum([p.numel() for p in model.generator.parameters()])
+ logging.info(f"Number of parameters in generator: {num_param_g}")
+ num_param_d = sum([p.numel() for p in model.discriminator.parameters()])
+ logging.info(f"Number of parameters in discriminator: {num_param_d}")
+ logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
+
+    # we need cut ids to name the saved audio files.
+ args.return_cuts = True
+ ljspeech = LJSpeechTtsDataModule(args)
+
+ test_cuts = ljspeech.test_cuts()
+ test_dl = ljspeech.test_dataloaders(test_cuts)
+
+ infer_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ tokenizer=tokenizer,
+ )
+
+ logging.info(f"Wav files are saved to {params.save_wav_dir}")
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/ljspeech/TTS/vits/loss.py b/egs/ljspeech/TTS/vits/loss.py
new file mode 100644
index 000000000..21aaad6e7
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/loss.py
@@ -0,0 +1,336 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""HiFiGAN-related loss modules.
+
+This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
+
+"""
+
+from typing import List, Tuple, Union
+
+import torch
+import torch.distributions as D
+import torch.nn.functional as F
+
+from lhotse.features.kaldi import Wav2LogFilterBank
+
+
+class GeneratorAdversarialLoss(torch.nn.Module):
+ """Generator adversarial loss module."""
+
+ def __init__(
+ self,
+ average_by_discriminators: bool = True,
+ loss_type: str = "mse",
+ ):
+ """Initialize GeneratorAversarialLoss module.
+
+ Args:
+ average_by_discriminators (bool): Whether to average the loss by
+ the number of discriminators.
+ loss_type (str): Loss type, "mse" or "hinge".
+
+ """
+ super().__init__()
+ self.average_by_discriminators = average_by_discriminators
+ assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
+ if loss_type == "mse":
+ self.criterion = self._mse_loss
+ else:
+ self.criterion = self._hinge_loss
+
+ def forward(
+ self,
+ outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
+ ) -> torch.Tensor:
+ """Calcualate generator adversarial loss.
+
+ Args:
+ outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
+ outputs, list of discriminator outputs, or list of list of discriminator
+                outputs.
+
+ Returns:
+ Tensor: Generator adversarial loss value.
+
+ """
+ if isinstance(outputs, (tuple, list)):
+ adv_loss = 0.0
+ for i, outputs_ in enumerate(outputs):
+ if isinstance(outputs_, (tuple, list)):
+ # NOTE(kan-bayashi): case including feature maps
+ outputs_ = outputs_[-1]
+ adv_loss += self.criterion(outputs_)
+ if self.average_by_discriminators:
+ adv_loss /= i + 1
+ else:
+ adv_loss = self.criterion(outputs)
+
+ return adv_loss
+
+ def _mse_loss(self, x):
+ return F.mse_loss(x, x.new_ones(x.size()))
+
+ def _hinge_loss(self, x):
+ return -x.mean()
+
+
+class DiscriminatorAdversarialLoss(torch.nn.Module):
+ """Discriminator adversarial loss module."""
+
+ def __init__(
+ self,
+ average_by_discriminators: bool = True,
+ loss_type: str = "mse",
+ ):
+ """Initialize DiscriminatorAversarialLoss module.
+
+ Args:
+ average_by_discriminators (bool): Whether to average the loss by
+ the number of discriminators.
+ loss_type (str): Loss type, "mse" or "hinge".
+
+ """
+ super().__init__()
+ self.average_by_discriminators = average_by_discriminators
+ assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
+ if loss_type == "mse":
+ self.fake_criterion = self._mse_fake_loss
+ self.real_criterion = self._mse_real_loss
+ else:
+ self.fake_criterion = self._hinge_fake_loss
+ self.real_criterion = self._hinge_real_loss
+
+ def forward(
+ self,
+ outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
+ outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Calcualate discriminator adversarial loss.
+
+ Args:
+ outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
+ outputs, list of discriminator outputs, or list of list of discriminator
+ outputs calculated from generator.
+ outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
+ outputs, list of discriminator outputs, or list of list of discriminator
+ outputs calculated from groundtruth.
+
+ Returns:
+ Tensor: Discriminator real loss value.
+ Tensor: Discriminator fake loss value.
+
+ """
+ if isinstance(outputs, (tuple, list)):
+ real_loss = 0.0
+ fake_loss = 0.0
+ for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
+ if isinstance(outputs_hat_, (tuple, list)):
+ # NOTE(kan-bayashi): case including feature maps
+ outputs_hat_ = outputs_hat_[-1]
+ outputs_ = outputs_[-1]
+ real_loss += self.real_criterion(outputs_)
+ fake_loss += self.fake_criterion(outputs_hat_)
+ if self.average_by_discriminators:
+ fake_loss /= i + 1
+ real_loss /= i + 1
+ else:
+ real_loss = self.real_criterion(outputs)
+ fake_loss = self.fake_criterion(outputs_hat)
+
+ return real_loss, fake_loss
+
+ def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
+ return F.mse_loss(x, x.new_ones(x.size()))
+
+ def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
+ return F.mse_loss(x, x.new_zeros(x.size()))
+
+ def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
+ return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))
+
+ def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
+ return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
+
+
+class FeatureMatchLoss(torch.nn.Module):
+ """Feature matching loss module."""
+
+ def __init__(
+ self,
+ average_by_layers: bool = True,
+ average_by_discriminators: bool = True,
+ include_final_outputs: bool = False,
+ ):
+ """Initialize FeatureMatchLoss module.
+
+ Args:
+ average_by_layers (bool): Whether to average the loss by the number
+ of layers.
+ average_by_discriminators (bool): Whether to average the loss by
+ the number of discriminators.
+ include_final_outputs (bool): Whether to include the final output of
+ each discriminator for loss calculation.
+
+ """
+ super().__init__()
+ self.average_by_layers = average_by_layers
+ self.average_by_discriminators = average_by_discriminators
+ self.include_final_outputs = include_final_outputs
+
+ def forward(
+ self,
+ feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
+ feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
+ ) -> torch.Tensor:
+ """Calculate feature matching loss.
+
+ Args:
+ feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
+                discriminator outputs or list of discriminator outputs calculated
+ from generator's outputs.
+ feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
+                discriminator outputs or list of discriminator outputs calculated
+                from groundtruth.
+
+ Returns:
+ Tensor: Feature matching loss value.
+
+ """
+ feat_match_loss = 0.0
+ for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
+ feat_match_loss_ = 0.0
+ if not self.include_final_outputs:
+ feats_hat_ = feats_hat_[:-1]
+ feats_ = feats_[:-1]
+ for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
+ feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
+ if self.average_by_layers:
+ feat_match_loss_ /= j + 1
+ feat_match_loss += feat_match_loss_
+ if self.average_by_discriminators:
+ feat_match_loss /= i + 1
+
+ return feat_match_loss
+
+
+class MelSpectrogramLoss(torch.nn.Module):
+ """Mel-spectrogram loss."""
+
+ def __init__(
+ self,
+ sampling_rate: int = 22050,
+ frame_length: int = 1024, # in samples
+ frame_shift: int = 256, # in samples
+ n_mels: int = 80,
+ use_fft_mag: bool = True,
+ ):
+ super().__init__()
+ self.wav_to_mel = Wav2LogFilterBank(
+ sampling_rate=sampling_rate,
+ frame_length=frame_length / sampling_rate, # in second
+ frame_shift=frame_shift / sampling_rate, # in second
+ use_fft_mag=use_fft_mag,
+ num_filters=n_mels,
+ )
+
+ def forward(
+ self,
+ y_hat: torch.Tensor,
+ y: torch.Tensor,
+ return_mel: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
+ """Calculate Mel-spectrogram loss.
+
+ Args:
+ y_hat (Tensor): Generated waveform tensor (B, 1, T).
+ y (Tensor): Groundtruth waveform tensor (B, 1, T).
+            return_mel (bool): Whether to also return the computed mel-spectrograms
+                of the generated and groundtruth waveforms.
+
+        Returns:
+            Tensor: Mel-spectrogram loss value. If return_mel is True, a tuple of
+                the loss and the pair of mel-spectrograms (mel_hat, mel) is returned.
+
+ """
+ mel_hat = self.wav_to_mel(y_hat.squeeze(1))
+ mel = self.wav_to_mel(y.squeeze(1))
+ mel_loss = F.l1_loss(mel_hat, mel)
+
+ if return_mel:
+ return mel_loss, (mel_hat, mel)
+
+ return mel_loss
+
+
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py
+
+"""VITS-related loss modules.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+
+class KLDivergenceLoss(torch.nn.Module):
+ """KL divergence loss."""
+
+ def forward(
+ self,
+ z_p: torch.Tensor,
+ logs_q: torch.Tensor,
+ m_p: torch.Tensor,
+ logs_p: torch.Tensor,
+ z_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ """Calculate KL divergence loss.
+
+ Args:
+ z_p (Tensor): Flow hidden representation (B, H, T_feats).
+ logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
+ m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
+ logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
+ z_mask (Tensor): Mask tensor (B, 1, T_feats).
+
+ Returns:
+ Tensor: KL divergence loss.
+
+ """
+ z_p = z_p.float()
+ logs_q = logs_q.float()
+ m_p = m_p.float()
+ logs_p = logs_p.float()
+ z_mask = z_mask.float()
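+        # single-sample estimate of KL(q || p): the negative entropy of q
+        # (-logs_q - 0.5) plus the negative log-density of z_p under the prior;
+        # the 0.5 * log(2 * pi) constants cancel out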
+ kl = logs_p - logs_q - 0.5
+ kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+ kl = torch.sum(kl * z_mask)
+ loss = kl / torch.sum(z_mask)
+
+ return loss
+
+
+class KLDivergenceLossWithoutFlow(torch.nn.Module):
+ """KL divergence loss without flow."""
+
+ def forward(
+ self,
+ m_q: torch.Tensor,
+ logs_q: torch.Tensor,
+ m_p: torch.Tensor,
+ logs_p: torch.Tensor,
+ ) -> torch.Tensor:
+ """Calculate KL divergence loss without flow.
+
+ Args:
+ m_q (Tensor): Posterior encoder projected mean (B, H, T_feats).
+ logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
+ m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
+ logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
+ """
+ posterior_norm = D.Normal(m_q, torch.exp(logs_q))
+ prior_norm = D.Normal(m_p, torch.exp(logs_p))
+ loss = D.kl_divergence(posterior_norm, prior_norm).mean()
+ return loss
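+
+
+if __name__ == "__main__":
+    # A minimal smoke-test sketch of the loss modules; the hidden size, number
+    # of frames and waveform length below are arbitrary assumptions.
+    B, H, T = 2, 192, 50
+    z_p = torch.randn(B, H, T)
+    m_p = torch.randn(B, H, T)
+    logs_p = torch.zeros(B, H, T)
+    logs_q = torch.zeros(B, H, T)
+    z_mask = torch.ones(B, 1, T)
+    kl = KLDivergenceLoss()(z_p, logs_q, m_p, logs_p, z_mask)
+    mel = MelSpectrogramLoss()
+    mel_loss = mel(torch.randn(B, 1, 22050), torch.randn(B, 1, 22050))
+    print(kl.item(), mel_loss.item())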
diff --git a/egs/ljspeech/TTS/vits/monotonic_align/__init__.py b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py
new file mode 100644
index 000000000..2b35654f5
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py
@@ -0,0 +1,81 @@
+# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/__init__.py
+
+"""Maximum path calculation module.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+import warnings
+
+import numpy as np
+import torch
+from numba import njit, prange
+
+try:
+ from .core import maximum_path_c
+
+ is_cython_avalable = True
+except ImportError:
+ is_cython_avalable = False
+ warnings.warn(
+ "Cython version is not available. Fallback to 'EXPERIMETAL' numba version. "
+ "If you want to use the cython version, please build it as follows: "
+ "`cd espnet2/gan_tts/vits/monotonic_align; python setup.py build_ext --inplace`"
+ )
+
+
+def maximum_path(neg_x_ent: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
+ """Calculate maximum path.
+
+ Args:
+ neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text).
+ attn_mask (Tensor): Attention mask (B, T_feats, T_text).
+
+ Returns:
+ Tensor: Maximum path tensor (B, T_feats, T_text).
+
+ """
+ device, dtype = neg_x_ent.device, neg_x_ent.dtype
+ neg_x_ent = neg_x_ent.cpu().numpy().astype(np.float32)
+ path = np.zeros(neg_x_ent.shape, dtype=np.int32)
+ t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32)
+ t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32)
+ if is_cython_avalable:
+ maximum_path_c(path, neg_x_ent, t_t_max, t_s_max)
+ else:
+ maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max)
+
+ return torch.from_numpy(path).to(device=device, dtype=dtype)
+
+
+@njit
+def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf):
+ """Calculate a single maximum path with numba."""
+ index = t_x - 1
+ for y in range(t_y):
+ for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
+ if x == y:
+ v_cur = max_neg_val
+ else:
+ v_cur = value[y - 1, x]
+ if x == 0:
+ if y == 0:
+ v_prev = 0.0
+ else:
+ v_prev = max_neg_val
+ else:
+ v_prev = value[y - 1, x - 1]
+ value[y, x] += max(v_prev, v_cur)
+
+ for y in range(t_y - 1, -1, -1):
+ path[y, index] = 1
+ if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
+ index = index - 1
+
+
+@njit(parallel=True)
+def maximum_path_numba(paths, values, t_ys, t_xs):
+ """Calculate batch maximum path with numba."""
+ for i in prange(paths.shape[0]):
+ maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i])
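+
+
+if __name__ == "__main__":
+    # A minimal sanity-check sketch; the shapes below are arbitrary assumptions.
+    # Every feature frame should be assigned to exactly one text token along a
+    # monotonic path.
+    B, T_feats, T_text = 1, 6, 3
+    neg_x_ent = torch.randn(B, T_feats, T_text)
+    attn_mask = torch.ones(B, T_feats, T_text)
+    path = maximum_path(neg_x_ent, attn_mask)
+    print(path)
+    assert torch.all(path.sum(2) == 1)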
diff --git a/egs/ljspeech/TTS/vits/monotonic_align/core.pyx b/egs/ljspeech/TTS/vits/monotonic_align/core.pyx
new file mode 100644
index 000000000..c02c2d02e
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/monotonic_align/core.pyx
@@ -0,0 +1,51 @@
+# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/core.pyx
+
+"""Maximum path calculation module with cython optimization.
+
+This code is copied from https://github.com/jaywalnut310/vits with the code format modified.
+
+"""
+
+cimport cython
+
+from cython.parallel import prange
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
+ cdef int x
+ cdef int y
+ cdef float v_prev
+ cdef float v_cur
+ cdef float tmp
+ cdef int index = t_x - 1
+
+ for y in range(t_y):
+ for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
+ if x == y:
+ v_cur = max_neg_val
+ else:
+ v_cur = value[y - 1, x]
+ if x == 0:
+ if y == 0:
+ v_prev = 0.0
+ else:
+ v_prev = max_neg_val
+ else:
+ v_prev = value[y - 1, x - 1]
+ value[y, x] += max(v_prev, v_cur)
+
+ for y in range(t_y - 1, -1, -1):
+ path[y, index] = 1
+ if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
+ index = index - 1
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
+ cdef int b = paths.shape[0]
+ cdef int i
+ for i in prange(b, nogil=True):
+ maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
diff --git a/egs/ljspeech/TTS/vits/monotonic_align/setup.py b/egs/ljspeech/TTS/vits/monotonic_align/setup.py
new file mode 100644
index 000000000..33d75e176
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/monotonic_align/setup.py
@@ -0,0 +1,31 @@
+# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/setup.py
+"""Setup cython code."""
+
+from Cython.Build import cythonize
+from setuptools import Extension, setup
+from setuptools.command.build_ext import build_ext as _build_ext
+
+
+class build_ext(_build_ext):
+ """Overwrite build_ext."""
+
+ def finalize_options(self):
+ """Prevent numpy from thinking it is still in its setup process."""
+ _build_ext.finalize_options(self)
+ __builtins__.__NUMPY_SETUP__ = False
+ import numpy
+
+ self.include_dirs.append(numpy.get_include())
+
+
+exts = [
+ Extension(
+ name="core",
+ sources=["core.pyx"],
+ )
+]
+setup(
+ name="monotonic_align",
+ ext_modules=cythonize(exts, language_level=3),
+ cmdclass={"build_ext": build_ext},
+)
diff --git a/egs/ljspeech/TTS/vits/posterior_encoder.py b/egs/ljspeech/TTS/vits/posterior_encoder.py
new file mode 100644
index 000000000..6b8a5be52
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/posterior_encoder.py
@@ -0,0 +1,117 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Posterior encoder module in VITS.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+from typing import Optional, Tuple
+
+import torch
+
+from icefall.utils import make_pad_mask
+from wavenet import WaveNet, Conv1d
+
+
+class PosteriorEncoder(torch.nn.Module):
+ """Posterior encoder module in VITS.
+
+ This is a module of posterior encoder described in `Conditional Variational
+ Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
+
+ .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
+ Text-to-Speech`: https://arxiv.org/abs/2006.04558
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 513,
+ out_channels: int = 192,
+ hidden_channels: int = 192,
+ kernel_size: int = 5,
+ layers: int = 16,
+ stacks: int = 1,
+ base_dilation: int = 1,
+ global_channels: int = -1,
+ dropout_rate: float = 0.0,
+ bias: bool = True,
+ use_weight_norm: bool = True,
+ ):
+ """Initilialize PosteriorEncoder module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ hidden_channels (int): Number of hidden channels.
+ kernel_size (int): Kernel size in WaveNet.
+ layers (int): Number of layers of WaveNet.
+ stacks (int): Number of repeat stacking of WaveNet.
+ base_dilation (int): Base dilation factor.
+ global_channels (int): Number of global conditioning channels.
+ dropout_rate (float): Dropout rate.
+ bias (bool): Whether to use bias parameters in conv.
+ use_weight_norm (bool): Whether to apply weight norm.
+
+ """
+ super().__init__()
+
+ # define modules
+ self.input_conv = Conv1d(in_channels, hidden_channels, 1)
+ self.encoder = WaveNet(
+ in_channels=-1,
+ out_channels=-1,
+ kernel_size=kernel_size,
+ layers=layers,
+ stacks=stacks,
+ base_dilation=base_dilation,
+ residual_channels=hidden_channels,
+ aux_channels=-1,
+ gate_channels=hidden_channels * 2,
+ skip_channels=hidden_channels,
+ global_channels=global_channels,
+ dropout_rate=dropout_rate,
+ bias=bias,
+ use_weight_norm=use_weight_norm,
+ use_first_conv=False,
+ use_last_conv=False,
+ scale_residual=False,
+ scale_skip_connect=True,
+ )
+ self.proj = Conv1d(hidden_channels, out_channels * 2, 1)
+
+ def forward(
+ self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, in_channels, T_feats).
+ x_lengths (Tensor): Length tensor (B,).
+ g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+
+ Returns:
+ Tensor: Encoded hidden representation tensor (B, out_channels, T_feats).
+ Tensor: Projected mean tensor (B, out_channels, T_feats).
+ Tensor: Projected scale tensor (B, out_channels, T_feats).
+ Tensor: Mask tensor for input tensor (B, 1, T_feats).
+
+ """
+ x_mask = (
+ (~make_pad_mask(x_lengths))
+ .unsqueeze(1)
+ .to(
+ dtype=x.dtype,
+ device=x.device,
+ )
+ )
+ x = self.input_conv(x) * x_mask
+ x = self.encoder(x, x_mask, g=g)
+ stats = self.proj(x) * x_mask
+ m, logs = stats.split(stats.size(1) // 2, dim=1)
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
+
+ return z, m, logs, x_mask
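+
+
+if __name__ == "__main__":
+    # A minimal smoke-test sketch; the feature dimension and lengths below are
+    # arbitrary assumptions.
+    encoder = PosteriorEncoder(
+        in_channels=513, out_channels=192, hidden_channels=192
+    )
+    feats = torch.randn(2, 513, 50)
+    feats_lengths = torch.tensor([50, 40])
+    z, m, logs, mask = encoder(feats, feats_lengths)
+    print(z.shape, m.shape, logs.shape, mask.shape)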
diff --git a/egs/ljspeech/TTS/vits/residual_coupling.py b/egs/ljspeech/TTS/vits/residual_coupling.py
new file mode 100644
index 000000000..2d6807cb7
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/residual_coupling.py
@@ -0,0 +1,229 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Residual affine coupling modules in VITS.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from flow import FlipFlow
+from wavenet import WaveNet
+
+
+class ResidualAffineCouplingBlock(torch.nn.Module):
+ """Residual affine coupling block module.
+
+    This is a module of the residual affine coupling block, which is used as "Flow" in
+ `Conditional Variational Autoencoder with Adversarial Learning for End-to-End
+ Text-to-Speech`_.
+
+ .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
+ Text-to-Speech`: https://arxiv.org/abs/2006.04558
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 192,
+ hidden_channels: int = 192,
+ flows: int = 4,
+ kernel_size: int = 5,
+ base_dilation: int = 1,
+ layers: int = 4,
+ global_channels: int = -1,
+ dropout_rate: float = 0.0,
+ use_weight_norm: bool = True,
+ bias: bool = True,
+ use_only_mean: bool = True,
+ ):
+ """Initilize ResidualAffineCouplingBlock module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ hidden_channels (int): Number of hidden channels.
+ flows (int): Number of flows.
+ kernel_size (int): Kernel size for WaveNet.
+ base_dilation (int): Base dilation factor for WaveNet.
+ layers (int): Number of layers of WaveNet.
+ global_channels (int): Number of global channels.
+ dropout_rate (float): Dropout rate.
+ use_weight_norm (bool): Whether to use weight normalization in WaveNet.
+            bias (bool): Whether to use bias parameters in WaveNet.
+ use_only_mean (bool): Whether to estimate only mean.
+
+ """
+ super().__init__()
+
+ self.flows = torch.nn.ModuleList()
+ for i in range(flows):
+ self.flows += [
+ ResidualAffineCouplingLayer(
+ in_channels=in_channels,
+ hidden_channels=hidden_channels,
+ kernel_size=kernel_size,
+ base_dilation=base_dilation,
+ layers=layers,
+ stacks=1,
+ global_channels=global_channels,
+ dropout_rate=dropout_rate,
+ use_weight_norm=use_weight_norm,
+ bias=bias,
+ use_only_mean=use_only_mean,
+ )
+ ]
+ self.flows += [FlipFlow()]
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ g: Optional[torch.Tensor] = None,
+ inverse: bool = False,
+ ) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, in_channels, T).
+            x_mask (Tensor): Mask tensor (B, 1, T).
+            g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+            inverse (bool): Whether to run the flow in the inverse direction.
+
+ Returns:
+ Tensor: Output tensor (B, in_channels, T).
+
+ """
+ if not inverse:
+ for flow in self.flows:
+ x, _ = flow(x, x_mask, g=g, inverse=inverse)
+ else:
+ for flow in reversed(self.flows):
+ x = flow(x, x_mask, g=g, inverse=inverse)
+ return x
+
+
+class ResidualAffineCouplingLayer(torch.nn.Module):
+ """Residual affine coupling layer."""
+
+ def __init__(
+ self,
+ in_channels: int = 192,
+ hidden_channels: int = 192,
+ kernel_size: int = 5,
+ base_dilation: int = 1,
+ layers: int = 5,
+ stacks: int = 1,
+ global_channels: int = -1,
+ dropout_rate: float = 0.0,
+ use_weight_norm: bool = True,
+ bias: bool = True,
+ use_only_mean: bool = True,
+ ):
+ """Initialzie ResidualAffineCouplingLayer module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ hidden_channels (int): Number of hidden channels.
+ kernel_size (int): Kernel size for WaveNet.
+ base_dilation (int): Base dilation factor for WaveNet.
+ layers (int): Number of layers of WaveNet.
+ stacks (int): Number of stacks of WaveNet.
+ global_channels (int): Number of global channels.
+ dropout_rate (float): Dropout rate.
+ use_weight_norm (bool): Whether to use weight normalization in WaveNet.
+            bias (bool): Whether to use bias parameters in WaveNet.
+ use_only_mean (bool): Whether to estimate only mean.
+
+ """
+ assert in_channels % 2 == 0, "in_channels should be divisible by 2"
+ super().__init__()
+ self.half_channels = in_channels // 2
+ self.use_only_mean = use_only_mean
+
+ # define modules
+ self.input_conv = torch.nn.Conv1d(
+ self.half_channels,
+ hidden_channels,
+ 1,
+ )
+ self.encoder = WaveNet(
+ in_channels=-1,
+ out_channels=-1,
+ kernel_size=kernel_size,
+ layers=layers,
+ stacks=stacks,
+ base_dilation=base_dilation,
+ residual_channels=hidden_channels,
+ aux_channels=-1,
+ gate_channels=hidden_channels * 2,
+ skip_channels=hidden_channels,
+ global_channels=global_channels,
+ dropout_rate=dropout_rate,
+ bias=bias,
+ use_weight_norm=use_weight_norm,
+ use_first_conv=False,
+ use_last_conv=False,
+ scale_residual=False,
+ scale_skip_connect=True,
+ )
+ if use_only_mean:
+ self.proj = torch.nn.Conv1d(
+ hidden_channels,
+ self.half_channels,
+ 1,
+ )
+ else:
+ self.proj = torch.nn.Conv1d(
+ hidden_channels,
+ self.half_channels * 2,
+ 1,
+ )
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ g: Optional[torch.Tensor] = None,
+ inverse: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, in_channels, T).
+            x_mask (Tensor): Mask tensor (B, 1, T).
+            g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+            inverse (bool): Whether to invert the flow.
+
+ Returns:
+ Tensor: Output tensor (B, in_channels, T).
+ Tensor: Log-determinant tensor for NLL (B,) if not inverse.
+
+ """
+ xa, xb = x.split(x.size(1) // 2, dim=1)
+ h = self.input_conv(xa) * x_mask
+ h = self.encoder(h, x_mask, g=g)
+ stats = self.proj(h) * x_mask
+ if not self.use_only_mean:
+ m, logs = stats.split(stats.size(1) // 2, dim=1)
+ else:
+ m = stats
+ logs = torch.zeros_like(m)
+
+ if not inverse:
+ xb = m + xb * torch.exp(logs) * x_mask
+ x = torch.cat([xa, xb], 1)
+ logdet = torch.sum(logs, [1, 2])
+ return x, logdet
+ else:
+ xb = (xb - m) * torch.exp(-logs) * x_mask
+ x = torch.cat([xa, xb], 1)
+ return x
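+
+
+def _test_residual_affine_coupling_layer():
+    # A minimal sketch added for illustration (not part of the original recipe):
+    # with the default settings and no global conditioning, running the coupling
+    # layer forward and then in inverse mode should recover its input, and the
+    # log-determinant has shape (B,).
+    layer = ResidualAffineCouplingLayer()
+    x = torch.randn(2, 192, 8)
+    x_mask = torch.ones(2, 1, 8)
+
+    y, logdet = layer(x, x_mask, inverse=False)
+    x_hat = layer(y, x_mask, inverse=True)
+
+    assert logdet.shape == (2,)
+    assert torch.allclose(x, x_hat, atol=1e-5), (x - x_hat).abs().max()
+
+
+if __name__ == "__main__":
+    _test_residual_affine_coupling_layer()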
diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py
new file mode 100755
index 000000000..8acca7c02
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/test_onnx.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script tests the ONNX model exported by vits/export-onnx.py.
+
+Use the onnx model to generate a wav:
+./vits/test_onnx.py \
+ --model-filename vits/exp/vits-epoch-1000.onnx \
+ --tokens data/tokens.txt
+"""
+
+
+import argparse
+import logging
+import onnxruntime as ort
+import torch
+import torchaudio
+
+from tokenizer import Tokenizer
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--model-filename",
+ type=str,
+ required=True,
+ help="Path to the onnx model.",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ return parser
+
+
+class OnnxModel:
+ def __init__(self, model_filename: str):
+ session_opts = ort.SessionOptions()
+ session_opts.inter_op_num_threads = 1
+ session_opts.intra_op_num_threads = 4
+
+ self.session_opts = session_opts
+
+ self.model = ort.InferenceSession(
+ model_filename,
+ sess_options=self.session_opts,
+ providers=["CPUExecutionProvider"],
+ )
+ logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
+
+ def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ tokens:
+ A 1-D tensor of shape (1, T)
+ Returns:
+ A tensor of shape (1, T')
+ """
+ noise_scale = torch.tensor([0.667], dtype=torch.float32)
+ noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
+ alpha = torch.tensor([1.0], dtype=torch.float32)
+
+ out = self.model.run(
+ [
+ self.model.get_outputs()[0].name,
+ ],
+ {
+ self.model.get_inputs()[0].name: tokens.numpy(),
+ self.model.get_inputs()[1].name: tokens_lens.numpy(),
+ self.model.get_inputs()[2].name: noise_scale.numpy(),
+ self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
+ self.model.get_inputs()[4].name: alpha.numpy(),
+ },
+ )[0]
+ return torch.from_numpy(out)
+
+
+def main():
+ args = get_parser().parse_args()
+
+ tokenizer = Tokenizer(args.tokens)
+
+ logging.info("About to create onnx model")
+ model = OnnxModel(args.model_filename)
+
+ text = "I went there to see the land, the people and how their system works, end quote."
+ tokens = tokenizer.texts_to_token_ids([text])
+ tokens = torch.tensor(tokens) # (1, T)
+    tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)  # (1,)
+ audio = model(tokens, tokens_lens) # (1, T')
+
+    torchaudio.save("test_onnx.wav", audio, sample_rate=22050)
+ logging.info("Saved to test_onnx.wav")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py
new file mode 100644
index 000000000..9f337e45b
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/text_encoder.py
@@ -0,0 +1,662 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Text encoder module in VITS.
+
+This code is based on
+ - https://github.com/jaywalnut310/vits
+ - https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py
+ - https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/conformer.py
+"""
+
+import copy
+import math
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor, nn
+
+from icefall.utils import is_jit_tracing, make_pad_mask
+
+
+class TextEncoder(torch.nn.Module):
+ """Text encoder module in VITS.
+
+ This is a module of text encoder described in `Conditional Variational Autoencoder
+ with Adversarial Learning for End-to-End Text-to-Speech`.
+ """
+
+ def __init__(
+ self,
+ vocabs: int,
+ d_model: int = 192,
+ num_heads: int = 2,
+ dim_feedforward: int = 768,
+ cnn_module_kernel: int = 5,
+ num_layers: int = 6,
+ dropout: float = 0.1,
+ ):
+ """Initialize TextEncoder module.
+
+ Args:
+ vocabs (int): Vocabulary size.
+ d_model (int): attention dimension
+ num_heads (int): number of attention heads
+          dim_feedforward (int): feedforward dimension
+ cnn_module_kernel (int): convolution kernel size
+ num_layers (int): number of encoder layers
+ dropout (float): dropout rate
+ """
+ super().__init__()
+ self.d_model = d_model
+
+ # define modules
+ self.emb = torch.nn.Embedding(vocabs, d_model)
+ torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5)
+
+ # We use conformer as text encoder
+ self.encoder = Transformer(
+ d_model=d_model,
+ num_heads=num_heads,
+ dim_feedforward=dim_feedforward,
+ cnn_module_kernel=cnn_module_kernel,
+ num_layers=num_layers,
+ dropout=dropout,
+ )
+
+ self.proj = torch.nn.Conv1d(d_model, d_model * 2, 1)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_lengths: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input index tensor (B, T_text).
+ x_lengths (Tensor): Length tensor (B,).
+
+ Returns:
+ Tensor: Encoded hidden representation (B, attention_dim, T_text).
+ Tensor: Projected mean tensor (B, attention_dim, T_text).
+ Tensor: Projected scale tensor (B, attention_dim, T_text).
+ Tensor: Mask tensor for input tensor (B, 1, T_text).
+
+ """
+ # (B, T_text, embed_dim)
+ x = self.emb(x) * math.sqrt(self.d_model)
+
+ assert x.size(1) == x_lengths.max().item()
+
+ # (B, T_text)
+ pad_mask = make_pad_mask(x_lengths)
+
+        # the encoder assumes channel-last input (B, T_text, embed_dim)
+ x = self.encoder(x, key_padding_mask=pad_mask)
+
+        # convert to channel-first (B, embed_dim, T_text)
+ x = x.transpose(1, 2)
+ non_pad_mask = (~pad_mask).unsqueeze(1)
+ stats = self.proj(x) * non_pad_mask
+ m, logs = stats.split(stats.size(1) // 2, dim=1)
+
+ return x, m, logs, non_pad_mask
+
+
+class Transformer(nn.Module):
+ """
+ Args:
+ d_model (int): attention dimension
+ num_heads (int): number of attention heads
+      dim_feedforward (int): feedforward dimension
+ cnn_module_kernel (int): convolution kernel size
+ num_layers (int): number of encoder layers
+ dropout (float): dropout rate
+ """
+
+ def __init__(
+ self,
+ d_model: int = 192,
+ num_heads: int = 2,
+ dim_feedforward: int = 768,
+ cnn_module_kernel: int = 5,
+ num_layers: int = 6,
+ dropout: float = 0.1,
+ ) -> None:
+ super().__init__()
+
+ self.num_layers = num_layers
+ self.d_model = d_model
+
+ self.encoder_pos = RelPositionalEncoding(d_model, dropout)
+
+ encoder_layer = TransformerEncoderLayer(
+ d_model=d_model,
+ num_heads=num_heads,
+ dim_feedforward=dim_feedforward,
+ cnn_module_kernel=cnn_module_kernel,
+ dropout=dropout,
+ )
+ self.encoder = TransformerEncoder(encoder_layer, num_layers)
+ self.after_norm = nn.LayerNorm(d_model)
+
+    def forward(
+        self, x: Tensor, key_padding_mask: Tensor
+    ) -> Tensor:
+        """
+        Args:
+          x:
+            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
+          key_padding_mask:
+            A boolean tensor of shape (batch_size, seq_len); True marks padded
+            positions that the attention should ignore.
+        """
+ x, pos_emb = self.encoder_pos(x)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+
+ x = self.encoder(
+ x, pos_emb, key_padding_mask=key_padding_mask
+ ) # (T, N, C)
+
+ x = self.after_norm(x)
+
+        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+ return x
+
+
+class TransformerEncoderLayer(nn.Module):
+ """
+ TransformerEncoderLayer is made up of self-attn and feedforward.
+
+ Args:
+ d_model: the number of expected features in the input.
+ num_heads: the number of heads in the multi-head attention models.
+ dim_feedforward: the dimension of the feed-forward network model.
+ dropout: the dropout value (default=0.1).
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ num_heads: int,
+ dim_feedforward: int,
+ cnn_module_kernel: int,
+ dropout: float = 0.1,
+ ) -> None:
+ super(TransformerEncoderLayer, self).__init__()
+
+ self.feed_forward_macaron = nn.Sequential(
+ nn.Linear(d_model, dim_feedforward),
+ Swish(),
+ nn.Dropout(dropout),
+ nn.Linear(dim_feedforward, d_model),
+ )
+
+ self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout)
+
+ self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+
+ self.feed_forward = nn.Sequential(
+ nn.Linear(d_model, dim_feedforward),
+ Swish(),
+ nn.Dropout(dropout),
+ nn.Linear(dim_feedforward, d_model),
+ )
+
+ self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
+ self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
+ self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
+ self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
+ self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
+
+ self.ff_scale = 0.5
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Pass the input through the transformer encoder layer.
+
+ Args:
+ src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
+ pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
+ key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
+ """
+ # macaron style feed-forward module
+ src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src)))
+
+ # multi-head self-attention module
+ src_attn = self.self_attn(
+ self.norm_mha(src),
+ pos_emb=pos_emb,
+ key_padding_mask=key_padding_mask,
+ )
+ src = src + self.dropout(src_attn)
+
+ # convolution module
+ src = src + self.dropout(self.conv_module(self.norm_conv(src)))
+
+ # feed-forward module
+ src = src + self.dropout(self.feed_forward(self.norm_ff(src)))
+
+ src = self.norm_final(src)
+
+ return src
+
+
+class TransformerEncoder(nn.Module):
+ r"""TransformerEncoder is a stack of N encoder layers
+
+ Args:
+ encoder_layer: an instance of the TransformerEncoderLayer class.
+ num_layers: the number of sub-encoder-layers in the encoder.
+ """
+
+ def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+ super().__init__()
+
+ self.layers = nn.ModuleList(
+ [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+ )
+ self.num_layers = num_layers
+
+ def forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Pass the input through the encoder layers in turn.
+
+ Args:
+ src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
+ pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
+ key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
+ """
+ output = src
+
+ for layer_index, mod in enumerate(self.layers):
+ output = mod(
+ output,
+ pos_emb,
+ key_padding_mask=key_padding_mask,
+ )
+
+ return output
+
+
+class RelPositionalEncoding(torch.nn.Module):
+ """Relative positional encoding module.
+
+ See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+ Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
+
+ Args:
+ d_model: Embedding dimension.
+ dropout_rate: Dropout rate.
+ max_len: Maximum input length.
+
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+ """Construct an PositionalEncoding object."""
+ super(RelPositionalEncoding, self).__init__()
+
+ self.d_model = d_model
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+ def extend_pe(self, x: Tensor) -> None:
+ """Reset the positional encodings."""
+ x_size = x.size(1)
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(1) >= x_size * 2 - 1:
+ # Note: TorchScript doesn't implement operator== for torch.Device
+ if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ # Suppose `i` means to the position of query vector and `j` means the
+        # position of key vector. We use positive relative positions when keys
+        # are to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_positive = torch.zeros(x_size, self.d_model)
+        pe_negative = torch.zeros(x_size, self.d_model)
+        position = torch.arange(0, x_size, dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
+ """Add positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
+ """
+ self.extend_pe(x)
+ x = x * self.xscale
+ pos_emb = self.pe[
+ :,
+ self.pe.size(1) // 2
+ - x.size(1)
+ + 1 : self.pe.size(1) // 2 # noqa E203
+ + x.size(1),
+ ]
+ return self.dropout(x), self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttention(nn.Module):
+ r"""Multi-Head Attention layer with relative position encoding
+
+ See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+
+ Args:
+ embed_dim: total dimension of the model.
+ num_heads: parallel attention heads.
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ ) -> None:
+ super(RelPositionMultiheadAttention, self).__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert (
+ self.head_dim * num_heads == self.embed_dim
+ ), "embed_dim must be divisible by num_heads"
+
+ self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+
+ # linear transformation for positional encoding.
+ self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+
+ self._reset_parameters()
+
+ def _reset_parameters(self) -> None:
+ nn.init.xavier_uniform_(self.in_proj.weight)
+ nn.init.constant_(self.in_proj.bias, 0.0)
+ nn.init.constant_(self.out_proj.bias, 0.0)
+
+ nn.init.xavier_uniform_(self.pos_bias_u)
+ nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def rel_shift(self, x: Tensor) -> Tensor:
+ """Compute relative positional encoding.
+
+ Args:
+ x: Input tensor (batch, head, seq_len, 2*seq_len-1).
+
+ Returns:
+ Tensor: tensor of shape (batch, head, seq_len, seq_len)
+ """
+ (batch_size, num_heads, seq_len, n) = x.shape
+
+ if not is_jit_tracing():
+ assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1"
+
+ if is_jit_tracing():
+ rows = torch.arange(start=seq_len - 1, end=-1, step=-1)
+ cols = torch.arange(seq_len)
+ rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+ indexes = rows + cols
+
+ x = x.reshape(-1, n)
+ x = torch.gather(x, dim=1, index=indexes)
+ x = x.reshape(batch_size, num_heads, seq_len, seq_len)
+ return x
+ else:
+ # Note: TorchScript requires explicit arg for stride()
+ batch_stride = x.stride(0)
+ head_stride = x.stride(1)
+ time_stride = x.stride(2)
+ n_stride = x.stride(3)
+ return x.as_strided(
+ (batch_size, num_heads, seq_len, seq_len),
+ (batch_stride, head_stride, time_stride - n_stride, n_stride),
+ storage_offset=n_stride * (seq_len - 1),
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ pos_emb: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Args:
+ x: Input tensor of shape (seq_len, batch_size, embed_dim)
+ pos_emb: Positional embedding tensor, (1, 2*seq_len-1, pos_dim)
+ key_padding_mask: if provided, specified padding elements in the key will
+ be ignored by the attention. This is an binary mask. When the value is True,
+ the corresponding value on the attention layer will be filled with -inf.
+ Its shape is (batch_size, seq_len).
+
+ Outputs:
+ A tensor of shape (seq_len, batch_size, embed_dim).
+ """
+ seq_len, batch_size, _ = x.shape
+ scaling = float(self.head_dim) ** -0.5
+
+ q, k, v = self.in_proj(x).chunk(3, dim=-1)
+
+ q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
+ k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
+ v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
+
+ q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim)
+
+ p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim)
+        # (1, 2*seq_len-1, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1)
+ p = p.permute(0, 2, 3, 1)
+
+ # (batch_size, num_head, seq_len, head_dim)
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+ k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len)
+ matrix_ac = torch.matmul(q_with_bias_u, k) # (batch_size, num_head, seq_len, seq_len)
+
+ # compute matrix b and matrix d
+ matrix_bd = torch.matmul(q_with_bias_v, p) # (batch_size, num_head, seq_len, 2*seq_len-1)
+ matrix_bd = self.rel_shift(matrix_bd) # (batch_size, num_head, seq_len, seq_len)
+
+ # (batch_size, num_head, seq_len, seq_len)
+ attn_output_weights = (matrix_ac + matrix_bd) * scaling
+ attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len)
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape == (batch_size, seq_len)
+ attn_output_weights = attn_output_weights.view(
+ batch_size, self.num_heads, seq_len, seq_len
+ )
+ attn_output_weights = attn_output_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ float("-inf"),
+ )
+ attn_output_weights = attn_output_weights.view(
+ batch_size * self.num_heads, seq_len, seq_len
+ )
+
+ attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
+ attn_output_weights = nn.functional.dropout(
+ attn_output_weights, p=self.dropout, training=self.training
+ )
+
+ # (batch_size * num_head, seq_len, head_dim)
+ attn_output = torch.bmm(attn_output_weights, v)
+ assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim)
+
+ attn_output = (
+ attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim)
+ )
+ # (seq_len, batch_size, embed_dim)
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Conformer model.
+ Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
+
+ Args:
+ channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers.
+ bias (bool): Whether to use bias in conv layers (default=True).
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ bias: bool = True,
+ ) -> None:
+ """Construct an ConvolutionModule object."""
+ super(ConvolutionModule, self).__init__()
+ # kernerl_size should be a odd number for 'SAME' padding
+ assert (kernel_size - 1) % 2 == 0
+
+ self.pointwise_conv1 = nn.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+
+ padding = (kernel_size - 1) // 2
+ self.depthwise_conv = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=padding,
+ groups=channels,
+ bias=bias,
+ )
+ self.norm = nn.LayerNorm(channels)
+ self.pointwise_conv2 = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.activation = Swish()
+
+ def forward(
+ self,
+ x: Tensor,
+ src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+ """Compute convolution module.
+
+ Args:
+ x: Input tensor (#time, batch, channels).
+ src_key_padding_mask: the mask for the src keys per batch (optional).
+
+ Returns:
+ Tensor: Output tensor (#time, batch, channels).
+
+ """
+ # exchange the temporal dimension and the feature dimension
+ x = x.permute(1, 2, 0) # (#batch, channels, time).
+
+ # GLU mechanism
+ x = self.pointwise_conv1(x) # (batch, 2*channels, time)
+ x = nn.functional.glu(x, dim=1) # (batch, channels, time)
+
+ # 1D Depthwise Conv
+ if src_key_padding_mask is not None:
+ x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+ x = self.depthwise_conv(x)
+ # x is (batch, channels, time)
+ x = x.permute(0, 2, 1)
+ x = self.norm(x)
+ x = x.permute(0, 2, 1)
+
+ x = self.activation(x)
+
+ x = self.pointwise_conv2(x) # (batch, channel, time)
+
+ return x.permute(2, 0, 1)
+
+
+class Swish(nn.Module):
+ """Construct an Swish object."""
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swich activation function."""
+ return x * torch.sigmoid(x)
+
+
+def _test_text_encoder():
+ vocabs = 500
+ d_model = 192
+ batch_size = 5
+ seq_len = 100
+
+    model = TextEncoder(vocabs=vocabs, d_model=d_model)
+    x, m, logs, mask = model(
+ x=torch.randint(low=0, high=vocabs, size=(batch_size, seq_len)),
+ x_lengths=torch.full((batch_size,), seq_len),
+ )
+ print(x.shape, m.shape, logs.shape, mask.shape)
+
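+
+def _test_rel_position_multihead_attention():
+    # A minimal sketch added for illustration (not called by the recipe),
+    # checking the shapes flowing through RelPositionalEncoding and
+    # RelPositionMultiheadAttention defined above.
+    d_model, num_heads, batch_size, seq_len = 192, 2, 3, 50
+
+    pos_enc = RelPositionalEncoding(d_model, dropout_rate=0.1)
+    attn = RelPositionMultiheadAttention(d_model, num_heads)
+
+    x = torch.randn(batch_size, seq_len, d_model)
+    x, pos_emb = pos_enc(x)
+    assert pos_emb.shape == (1, 2 * seq_len - 1, d_model)
+
+    x = x.permute(1, 0, 2)  # (batch, time, dim) -> (time, batch, dim)
+    key_padding_mask = make_pad_mask(torch.full((batch_size,), seq_len))
+    out = attn(x, pos_emb, key_padding_mask=key_padding_mask)
+    assert out.shape == (seq_len, batch_size, d_model)
+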
+
+if __name__ == "__main__":
+ _test_text_encoder()
diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py
new file mode 100644
index 000000000..0678b26fe
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/tokenizer.py
@@ -0,0 +1,106 @@
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List
+
+import g2p_en
+import tacotron_cleaner.cleaners
+from utils import intersperse
+
+
+class Tokenizer(object):
+ def __init__(self, tokens: str):
+ """
+ Args:
+ tokens: the file that maps tokens to ids
+ """
+ # Parse token file
+ self.token2id: Dict[str, int] = {}
+ with open(tokens, "r", encoding="utf-8") as f:
+ for line in f.readlines():
+ info = line.rstrip().split()
+ if len(info) == 1:
+ # case of space
+ token = " "
+ id = int(info[0])
+ else:
+ token, id = info[0], int(info[1])
+ self.token2id[token] = id
+
+        self.blank_id = self.token2id["<blk>"]
+        self.oov_id = self.token2id["<unk>"]
+ self.vocab_size = len(self.token2id)
+
+ self.g2p = g2p_en.G2p()
+
+ def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
+ """
+ Args:
+ texts:
+ A list of transcripts.
+ intersperse_blank:
+ Whether to intersperse blanks in the token sequence.
+
+ Returns:
+ Return a list of token id list [utterance][token_id]
+ """
+ token_ids_list = []
+
+ for text in texts:
+ # Text normalization
+ text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
+ # Convert to phonemes
+ tokens = self.g2p(text)
+ token_ids = []
+ for t in tokens:
+ if t in self.token2id:
+ token_ids.append(self.token2id[t])
+ else:
+ token_ids.append(self.oov_id)
+
+ if intersperse_blank:
+ token_ids = intersperse(token_ids, self.blank_id)
+
+ token_ids_list.append(token_ids)
+
+ return token_ids_list
+
+ def tokens_to_token_ids(self, tokens_list: List[str], intersperse_blank: bool = True):
+ """
+ Args:
+ tokens_list:
+ A list of token list, each corresponding to one utterance.
+ intersperse_blank:
+ Whether to intersperse blanks in the token sequence.
+
+ Returns:
+ Return a list of token id list [utterance][token_id]
+ """
+ token_ids_list = []
+
+ for tokens in tokens_list:
+ token_ids = []
+ for t in tokens:
+ if t in self.token2id:
+ token_ids.append(self.token2id[t])
+ else:
+ token_ids.append(self.oov_id)
+
+ if intersperse_blank:
+ token_ids = intersperse(token_ids, self.blank_id)
+ token_ids_list.append(token_ids)
+
+ return token_ids_list
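+
+
+def _test_tokenizer():
+    # A minimal, self-contained sketch added for illustration (not part of the
+    # original recipe); the token inventory below is made up.  It shows how the
+    # token file is parsed and how blanks are interspersed between token ids.
+    import tempfile
+
+    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
+        # one "token id" pair per line; a line with a single field is the space token
+        f.write("<blk> 0\n<unk> 1\nHH 2\nAH0 3\n4\n")
+        tokens_file = f.name
+
+    tokenizer = Tokenizer(tokens_file)
+    assert tokenizer.blank_id == 0 and tokenizer.oov_id == 1
+
+    # "HH AH0" plus an out-of-vocabulary token, with blanks interspersed
+    ids = tokenizer.tokens_to_token_ids([["HH", "AH0", "OOV"]])
+    assert ids[0][1::2] == [2, 3, 1]  # original ids sit at odd positions
+    assert ids[0][0::2] == [0, 0, 0, 0]  # blanks at even positions
+
+
+if __name__ == "__main__":
+    _test_tokenizer()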
diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py
new file mode 100755
index 000000000..eb43a4cc9
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/train.py
@@ -0,0 +1,893 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import numpy as np
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from lhotse.cut import Cut
+from lhotse.utils import fix_random_seed
+from torch.optim import Optimizer
+from torch.cuda.amp import GradScaler, autocast
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import AttributeDict, setup_logger, str2bool
+
+from tokenizer import Tokenizer
+from tts_datamodule import LJSpeechTtsDataModule
+from utils import MetricsTracker, plot_feature, save_checkpoint
+from vits import VITS
+
+LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=1000,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="vits/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ parser.add_argument(
+ "--lr", type=float, default=2.0e-4, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=20,
+ help="""Save checkpoint after processing this number of epochs"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.cur_epoch % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
+ Since it will take around 1000 epochs, we suggest using a large
+ save_every_n to save disk space.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+        - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+ """
+ params = AttributeDict(
+ {
+ # training params
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": -1, # 0
+ "log_interval": 50,
+ "valid_interval": 200,
+ "env_info": get_env_info(),
+ "sampling_rate": 22050,
+ "frame_shift": 256,
+ "frame_length": 1024,
+ "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length
+ "n_mels": 80,
+ "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss
+ "lambda_mel": 45.0, # loss scaling coefficient for Mel loss
+ "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss
+ "lambda_dur": 1.0, # loss scaling coefficient for duration loss
+ "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss
+ }
+ )
+
+ return params
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict, model: nn.Module
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(filename, model=model)
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ return saved_params
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ mel_loss_params = {
+ "n_mels": params.n_mels,
+ "frame_length": params.frame_length,
+ "frame_shift": params.frame_shift,
+ }
+ model = VITS(
+ vocab_size=params.vocab_size,
+ feature_dim=params.feature_dim,
+ sampling_rate=params.sampling_rate,
+ mel_loss_params=mel_loss_params,
+ lambda_adv=params.lambda_adv,
+ lambda_mel=params.lambda_mel,
+ lambda_feat_match=params.lambda_feat_match,
+ lambda_dur=params.lambda_dur,
+ lambda_kl=params.lambda_kl,
+ )
+ return model
+
+
+def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
+ """Parse batch data"""
+ audio = batch["audio"].to(device)
+ features = batch["features"].to(device)
+ audio_lens = batch["audio_lens"].to(device)
+ features_lens = batch["features_lens"].to(device)
+ tokens = batch["tokens"]
+
+ tokens = tokenizer.tokens_to_token_ids(tokens)
+ tokens = k2.RaggedTensor(tokens)
+ row_splits = tokens.shape.row_splits(1)
+ tokens_lens = row_splits[1:] - row_splits[:-1]
+ tokens = tokens.to(device)
+ tokens_lens = tokens_lens.to(device)
+ # a tensor of shape (B, T)
+ tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+
+ return audio, audio_lens, features, features_lens, tokens, tokens_lens
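+
+
+def _example_token_padding():
+    # A minimal sketch (illustration only, never called during training) of the
+    # ragged-to-padded conversion performed in `prepare_input` above: the
+    # per-utterance token counts come from the ragged row splits, and padding
+    # uses the blank id.
+    blank_id = 0
+    tokens = k2.RaggedTensor([[5, 7], [9, 11, 13]])
+    row_splits = tokens.shape.row_splits(1)  # tensor([0, 2, 5])
+    tokens_lens = row_splits[1:] - row_splits[:-1]  # tensor([2, 3])
+    padded = tokens.pad(mode="constant", padding_value=blank_id)
+    # padded is [[5, 7, blank_id], [9, 11, 13]], a regular (B, T_max) tensor
+    return padded, tokens_lens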
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ tokenizer: Tokenizer,
+ optimizer_g: Optimizer,
+ optimizer_d: Optimizer,
+ scheduler_g: LRSchedulerType,
+ scheduler_d: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ tokenizer:
+ Used to convert text to phonemes.
+ optimizer_g:
+ The optimizer for generator.
+ optimizer_d:
+ The optimizer for discriminator.
+ scheduler_g:
+ The learning rate scheduler for generator, we call step() every epoch.
+ scheduler_d:
+ The learning rate scheduler for discriminator, we call step() every epoch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+    # used to summarize the stats over iterations in one epoch
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ params=params,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ scheduler_g=scheduler_g,
+ scheduler_d=scheduler_d,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+
+ batch_size = len(batch["tokens"])
+ audio, audio_lens, features, features_lens, tokens, tokens_lens = \
+ prepare_input(batch, tokenizer, device)
+
+ loss_info = MetricsTracker()
+ loss_info['samples'] = batch_size
+
+ try:
+ with autocast(enabled=params.use_fp16):
+ # forward discriminator
+ loss_d, stats_d = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ forward_generator=False,
+ )
+ for k, v in stats_d.items():
+ loss_info[k] = v * batch_size
+ # update discriminator
+ optimizer_d.zero_grad()
+ scaler.scale(loss_d).backward()
+ scaler.step(optimizer_d)
+
+ with autocast(enabled=params.use_fp16):
+ # forward generator
+ loss_g, stats_g = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ forward_generator=True,
+ return_sample=params.batch_idx_train % params.log_interval == 0,
+ )
+ for k, v in stats_g.items():
+ if "returned_sample" not in k:
+ loss_info[k] = v * batch_size
+ # update generator
+ optimizer_g.zero_grad()
+ scaler.scale(loss_g).backward()
+ scaler.step(optimizer_g)
+ scaler.update()
+
+ # summary stats
+ tot_loss = tot_loss + loss_info
+ except: # noqa
+ save_bad_model()
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if params.batch_idx_train % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if params.batch_idx_train % params.log_interval == 0:
+ cur_lr_g = max(scheduler_g.get_last_lr())
+ cur_lr_d = max(scheduler_d.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+ f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
+ f"loss[{loss_info}], tot_loss[{tot_loss}], "
+ f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate_g", cur_lr_g, params.batch_idx_train
+ )
+ tb_writer.add_scalar(
+ "train/learning_rate_d", cur_lr_d, params.batch_idx_train
+ )
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(
+ tb_writer, "train/tot_", params.batch_idx_train
+ )
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+ if "returned_sample" in stats_g:
+ speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
+ tb_writer.add_audio(
+ "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate
+ )
+ tb_writer.add_audio(
+ "train/speech_", speech_, params.batch_idx_train, params.sampling_rate
+ )
+ tb_writer.add_image(
+ "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC'
+ )
+ tb_writer.add_image(
+ "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC'
+ )
+
+ if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info, (speech_hat, speech) = compute_validation_loss(
+ params=params,
+ model=model,
+ tokenizer=tokenizer,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+ tb_writer.add_audio(
+ "train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate
+ )
+ tb_writer.add_audio(
+ "train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate
+ )
+
+ loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ tokenizer: Tokenizer,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+ rank: int = 0,
+) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
+ """Run the validation process."""
+ model.eval()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+    # used to summarize the stats over iterations
+ tot_loss = MetricsTracker()
+ returned_sample = None
+
+ with torch.no_grad():
+ for batch_idx, batch in enumerate(valid_dl):
+ batch_size = len(batch["tokens"])
+ audio, audio_lens, features, features_lens, tokens, tokens_lens = \
+ prepare_input(batch, tokenizer, device)
+
+ loss_info = MetricsTracker()
+ loss_info['samples'] = batch_size
+
+ # forward discriminator
+ loss_d, stats_d = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ forward_generator=False,
+ )
+ assert loss_d.requires_grad is False
+ for k, v in stats_d.items():
+ loss_info[k] = v * batch_size
+
+ # forward generator
+ loss_g, stats_g = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ forward_generator=True,
+ )
+ assert loss_g.requires_grad is False
+ for k, v in stats_g.items():
+ loss_info[k] = v * batch_size
+
+ # summary stats
+ tot_loss = tot_loss + loss_info
+
+ # infer for first batch:
+ if batch_idx == 0 and rank == 0:
+ inner_model = model.module if isinstance(model, DDP) else model
+ audio_pred, _, duration = inner_model.inference(
+ text=tokens[0, :tokens_lens[0].item()]
+ )
+ audio_pred = audio_pred.data.cpu().numpy()
+ audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
+ assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred))
+ audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy()
+ returned_sample = (audio_pred, audio_gt)
+
+ if world_size > 1:
+ tot_loss.reduce(device)
+
+ loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss, returned_sample
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ tokenizer: Tokenizer,
+ optimizer_g: torch.optim.Optimizer,
+ optimizer_d: torch.optim.Optimizer,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ audio, audio_lens, features, features_lens, tokens, tokens_lens = \
+ prepare_input(batch, tokenizer, device)
+ try:
+ # for discriminator
+ with autocast(enabled=params.use_fp16):
+ loss_d, stats_d = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ forward_generator=False,
+ )
+ optimizer_d.zero_grad()
+ loss_d.backward()
+ # for generator
+ with autocast(enabled=params.use_fp16):
+ loss_g, stats_g = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ forward_generator=True,
+ )
+ optimizer_g.zero_grad()
+ loss_g.backward()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ tokenizer = Tokenizer(params.tokens)
+ params.blank_id = tokenizer.blank_id
+ params.oov_id = tokenizer.oov_id
+ params.vocab_size = tokenizer.vocab_size
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+ generator = model.generator
+ discriminator = model.discriminator
+
+ num_param_g = sum([p.numel() for p in generator.parameters()])
+ logging.info(f"Number of parameters in generator: {num_param_g}")
+ num_param_d = sum([p.numel() for p in discriminator.parameters()])
+ logging.info(f"Number of parameters in discriminator: {num_param_d}")
+ logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer_g = torch.optim.AdamW(
+ generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
+ )
+ optimizer_d = torch.optim.AdamW(
+ discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
+ )
+
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)
+
+ if checkpoints is not None:
+ # load state_dict for optimizers
+ if "optimizer_g" in checkpoints:
+ logging.info("Loading optimizer_g state dict")
+ optimizer_g.load_state_dict(checkpoints["optimizer_g"])
+ if "optimizer_d" in checkpoints:
+ logging.info("Loading optimizer_d state dict")
+ optimizer_d.load_state_dict(checkpoints["optimizer_d"])
+
+ # load state_dict for schedulers
+ if "scheduler_g" in checkpoints:
+ logging.info("Loading scheduler_g state dict")
+ scheduler_g.load_state_dict(checkpoints["scheduler_g"])
+ if "scheduler_d" in checkpoints:
+ logging.info("Loading scheduler_d state dict")
+ scheduler_d.load_state_dict(checkpoints["scheduler_d"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ ljspeech = LJSpeechTtsDataModule(args)
+
+ train_cuts = ljspeech.train_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+ train_dl = ljspeech.train_dataloaders(train_cuts)
+
+ valid_cuts = ljspeech.valid_cuts()
+ valid_dl = ljspeech.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ tokenizer=tokenizer,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ logging.info(f"Start epoch {epoch}")
+
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ params.cur_epoch = epoch
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ tokenizer=tokenizer,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ scheduler_g=scheduler_g,
+ scheduler_d=scheduler_d,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint(
+ filename=filename,
+ params=params,
+ model=model,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ scheduler_g=scheduler_g,
+ scheduler_d=scheduler_d,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ if rank == 0:
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+ # step per epoch
+ scheduler_g.step()
+ scheduler_d.step()
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def main():
+ parser = get_parser()
+ LJSpeechTtsDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/ljspeech/TTS/vits/transform.py b/egs/ljspeech/TTS/vits/transform.py
new file mode 100644
index 000000000..c20d13130
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/transform.py
@@ -0,0 +1,218 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py
+
+"""Flow-related transformation.
+
+This code is derived from https://github.com/bayesiains/nflows.
+
+"""
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+DEFAULT_MIN_BIN_WIDTH = 1e-3
+DEFAULT_MIN_BIN_HEIGHT = 1e-3
+DEFAULT_MIN_DERIVATIVE = 1e-3
+
+
+# TODO(kan-bayashi): Documentation and type hint
+def piecewise_rational_quadratic_transform(
+ inputs,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=False,
+ tails=None,
+ tail_bound=1.0,
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+ if tails is None:
+ spline_fn = rational_quadratic_spline
+ spline_kwargs = {}
+ else:
+ spline_fn = unconstrained_rational_quadratic_spline
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
+
+ outputs, logabsdet = spline_fn(
+ inputs=inputs,
+ unnormalized_widths=unnormalized_widths,
+ unnormalized_heights=unnormalized_heights,
+ unnormalized_derivatives=unnormalized_derivatives,
+ inverse=inverse,
+ min_bin_width=min_bin_width,
+ min_bin_height=min_bin_height,
+ min_derivative=min_derivative,
+ **spline_kwargs
+ )
+ return outputs, logabsdet
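+
+
+def _test_piecewise_rational_quadratic_transform():
+    # A minimal sketch added for illustration (not part of the original code).
+    # It assumes the shape convention used by the VITS flows: widths and heights
+    # have num_bins entries, derivatives num_bins - 1 (they get padded when
+    # tails == "linear").  The transform should be invertible, and the
+    # log-determinants of the two directions should cancel.
+    torch.manual_seed(0)
+    num_bins, n = 10, 100
+    inputs = torch.rand(n) * 1.6 - 0.8  # well inside (-tail_bound, tail_bound)
+    widths = torch.randn(n, num_bins)
+    heights = torch.randn(n, num_bins)
+    derivatives = torch.randn(n, num_bins - 1)
+
+    y, logdet = piecewise_rational_quadratic_transform(
+        inputs, widths, heights, derivatives, inverse=False, tails="linear"
+    )
+    x_hat, logdet_inv = piecewise_rational_quadratic_transform(
+        y, widths, heights, derivatives, inverse=True, tails="linear"
+    )
+    assert torch.allclose(inputs, x_hat, atol=1e-4)
+    assert torch.allclose(logdet + logdet_inv, torch.zeros_like(logdet), atol=1e-4)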
+
+
+# TODO(kan-bayashi): Documentation and type hint
+def unconstrained_rational_quadratic_spline(
+ inputs,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=False,
+ tails="linear",
+ tail_bound=1.0,
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
+ outside_interval_mask = ~inside_interval_mask
+
+ outputs = torch.zeros_like(inputs)
+ logabsdet = torch.zeros_like(inputs)
+
+ if tails == "linear":
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
+ constant = np.log(np.exp(1 - min_derivative) - 1)
+ unnormalized_derivatives[..., 0] = constant
+ unnormalized_derivatives[..., -1] = constant
+
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
+ logabsdet[outside_interval_mask] = 0
+ else:
+ raise RuntimeError("{} tails are not implemented.".format(tails))
+
+ (
+ outputs[inside_interval_mask],
+ logabsdet[inside_interval_mask],
+ ) = rational_quadratic_spline(
+ inputs=inputs[inside_interval_mask],
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
+ inverse=inverse,
+ left=-tail_bound,
+ right=tail_bound,
+ bottom=-tail_bound,
+ top=tail_bound,
+ min_bin_width=min_bin_width,
+ min_bin_height=min_bin_height,
+ min_derivative=min_derivative,
+ )
+
+ return outputs, logabsdet
+
+
+# TODO(kan-bayashi): Documentation and type hint
+def rational_quadratic_spline(
+ inputs,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=False,
+ left=0.0,
+ right=1.0,
+ bottom=0.0,
+ top=1.0,
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+ if torch.min(inputs) < left or torch.max(inputs) > right:
+ raise ValueError("Input to a transform is not within its domain")
+
+ num_bins = unnormalized_widths.shape[-1]
+
+ if min_bin_width * num_bins > 1.0:
+ raise ValueError("Minimal bin width too large for the number of bins")
+ if min_bin_height * num_bins > 1.0:
+ raise ValueError("Minimal bin height too large for the number of bins")
+
+ widths = F.softmax(unnormalized_widths, dim=-1)
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
+ cumwidths = torch.cumsum(widths, dim=-1)
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
+ cumwidths = (right - left) * cumwidths + left
+ cumwidths[..., 0] = left
+ cumwidths[..., -1] = right
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
+
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
+
+ heights = F.softmax(unnormalized_heights, dim=-1)
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
+ cumheights = torch.cumsum(heights, dim=-1)
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
+ cumheights = (top - bottom) * cumheights + bottom
+ cumheights[..., 0] = bottom
+ cumheights[..., -1] = top
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
+
+ if inverse:
+ bin_idx = _searchsorted(cumheights, inputs)[..., None]
+ else:
+ bin_idx = _searchsorted(cumwidths, inputs)[..., None]
+
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
+
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
+ delta = heights / widths
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
+
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
+
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
+
+ if inverse:
+ a = (inputs - input_cumheights) * (
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
+ ) + input_heights * (input_delta - input_derivatives)
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
+ )
+ c = -input_delta * (inputs - input_cumheights)
+
+ discriminant = b.pow(2) - 4 * a * c
+ assert (discriminant >= 0).all()
+
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
+ outputs = root * input_bin_widths + input_cumwidths
+
+ theta_one_minus_theta = root * (1 - root)
+ denominator = input_delta + (
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
+ * theta_one_minus_theta
+ )
+ derivative_numerator = input_delta.pow(2) * (
+ input_derivatives_plus_one * root.pow(2)
+ + 2 * input_delta * theta_one_minus_theta
+ + input_derivatives * (1 - root).pow(2)
+ )
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+
+ return outputs, -logabsdet
+ else:
+ theta = (inputs - input_cumwidths) / input_bin_widths
+ theta_one_minus_theta = theta * (1 - theta)
+
+ numerator = input_heights * (
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
+ )
+ denominator = input_delta + (
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
+ * theta_one_minus_theta
+ )
+ outputs = input_cumheights + numerator / denominator
+
+ derivative_numerator = input_delta.pow(2) * (
+ input_derivatives_plus_one * theta.pow(2)
+ + 2 * input_delta * theta_one_minus_theta
+ + input_derivatives * (1 - theta).pow(2)
+ )
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+
+ return outputs, logabsdet
+
+
+def _searchsorted(bin_locations, inputs, eps=1e-6):
+ bin_locations[..., -1] += eps
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
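+
+
+if __name__ == "__main__":
+    # Minimal self-check sketch (illustrative only; the shapes below are
+    # assumptions, not values taken from the VITS training code): a
+    # forward/inverse round trip through the spline should recover the input
+    # and the log-determinants should cancel.
+    num_bins = 10
+    x = torch.rand(4, 100) * 2 - 1  # values in [-1, 1)
+    w = torch.randn(4, 100, num_bins)
+    h = torch.randn(4, 100, num_bins)
+    d = torch.randn(4, 100, num_bins - 1)
+    y, logdet = piecewise_rational_quadratic_transform(
+        x, w, h, d, inverse=False, tails="linear", tail_bound=1.0
+    )
+    x_rec, inv_logdet = piecewise_rational_quadratic_transform(
+        y, w, h, d, inverse=True, tails="linear", tail_bound=1.0
+    )
+    assert torch.allclose(x, x_rec, atol=1e-3)
+    assert torch.allclose(logdet + inv_logdet, torch.zeros_like(logdet), atol=1e-3)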
diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py
new file mode 100644
index 000000000..0fcbb92c1
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/tts_datamodule.py
@@ -0,0 +1,325 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
+from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ SpeechSynthesisDataset,
+ PrecomputedFeatures,
+ SimpleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
+ AudioSamples,
+ OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class LJSpeechTtsDataModule:
+ """
+    DataModule for TTS experiments.
+    It assumes there is always one train and one valid dataloader,
+    but there can be multiple test dataloaders (e.g., one per
+    evaluation split).
+
+    It contains all the common data pipeline modules used in TTS
+    experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - on-the-fly feature extraction
+
+    This class should be derived for specific corpora used in TTS tasks.
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="TTS data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/spectrogram"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, each batch will have the "
+ "field: batch['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ logging.info("About to create train dataset")
+ train = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ feature_input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ sampling_rate = 22050
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+                frame_length=1024 / sampling_rate,  # (in seconds)
+                frame_shift=256 / sampling_rate,  # (in seconds)
+ use_fft_mag=True,
+ )
+ train = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ sampling_rate = 22050
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+                frame_length=1024 / sampling_rate,  # (in seconds)
+                frame_shift=256 / sampling_rate,  # (in seconds)
+ use_fft_mag=True,
+ )
+ validate = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ feature_input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create valid dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.info("About to create test dataset")
+ if self.args.on_the_fly_feats:
+ sampling_rate = 22050
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+                frame_length=1024 / sampling_rate,  # (in seconds)
+                frame_shift=256 / sampling_rate,  # (in seconds)
+ use_fft_mag=True,
+ )
+ test = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ test = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ feature_input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ test_sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=test_sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info("About to get train cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
+ )
+
+ @lru_cache()
+ def valid_cuts(self) -> CutSet:
+ logging.info("About to get validation cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_cuts(self) -> CutSet:
+ logging.info("About to get test cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
+ )
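+
+
+if __name__ == "__main__":
+    # A minimal usage sketch (illustrative; it assumes the LJSpeech manifests
+    # under --manifest-dir have already been generated by prepare.sh):
+    parser = argparse.ArgumentParser()
+    LJSpeechTtsDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    ljspeech = LJSpeechTtsDataModule(args)
+    train_cuts = ljspeech.train_cuts()
+    train_dl = ljspeech.train_dataloaders(train_cuts)
+    for batch in train_dl:
+        # Each batch is a dict produced by SpeechSynthesisDataset,
+        # e.g. with keys like "features", "features_lens" and "tokens".
+        print(batch.keys())
+        break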
diff --git a/egs/ljspeech/TTS/vits/utils.py b/egs/ljspeech/TTS/vits/utils.py
new file mode 100644
index 000000000..2a3dae900
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/utils.py
@@ -0,0 +1,265 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, List, Optional, Tuple, Union
+import collections
+import logging
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from lhotse.dataset.sampling.base import CutSampler
+from pathlib import Path
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.utils.tensorboard import SummaryWriter
+
+
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
+def get_random_segments(
+ x: torch.Tensor,
+ x_lengths: torch.Tensor,
+ segment_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Get random segments.
+
+ Args:
+ x (Tensor): Input tensor (B, C, T).
+ x_lengths (Tensor): Length tensor (B,).
+ segment_size (int): Segment size.
+
+ Returns:
+ Tensor: Segmented tensor (B, C, segment_size).
+ Tensor: Start index tensor (B,).
+
+ """
+ b, c, t = x.size()
+ max_start_idx = x_lengths - segment_size
+ max_start_idx[max_start_idx < 0] = 0
+ start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to(
+ dtype=torch.long,
+ )
+ segments = get_segments(x, start_idxs, segment_size)
+
+ return segments, start_idxs
+
+
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
+def get_segments(
+ x: torch.Tensor,
+ start_idxs: torch.Tensor,
+ segment_size: int,
+) -> torch.Tensor:
+ """Get segments.
+
+ Args:
+ x (Tensor): Input tensor (B, C, T).
+ start_idxs (Tensor): Start index tensor (B,).
+ segment_size (int): Segment size.
+
+ Returns:
+ Tensor: Segmented tensor (B, C, segment_size).
+
+ """
+ b, c, t = x.size()
+ segments = x.new_zeros(b, c, segment_size)
+ for i, start_idx in enumerate(start_idxs):
+ segments[i] = x[i, :, start_idx : start_idx + segment_size]
+ return segments
+
+
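+# A small shape example (the values below are illustrative):
+#
+#   x = torch.randn(2, 80, 100)          # (B, C, T)
+#   x_lengths = torch.tensor([100, 60])
+#   segs, starts = get_random_segments(x, x_lengths, segment_size=32)
+#   # segs: (2, 80, 32); get_segments(x, starts, 32) returns the same slices.
+
+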
+# from https://github.com/jaywalnut310/vits/blob/main/commons.py
+def intersperse(sequence, item=0):
+ result = [item] * (len(sequence) * 2 + 1)
+ result[1::2] = sequence
+ return result
+
+
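+# For example, intersperse([5, 2, 7], item=0) returns [0, 5, 0, 2, 0, 7, 0];
+# in VITS this is typically used to insert a blank token between phone ids.
+
+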
+# from https://github.com/jaywalnut310/vits/blob/main/utils.py
+MATPLOTLIB_FLAG = False
+
+
+def plot_feature(spectrogram):
+ global MATPLOTLIB_FLAG
+ if not MATPLOTLIB_FLAG:
+ import matplotlib
+ matplotlib.use("Agg")
+ MATPLOTLIB_FLAG = True
+ mpl_logger = logging.getLogger('matplotlib')
+ mpl_logger.setLevel(logging.WARNING)
+ import matplotlib.pylab as plt
+ import numpy as np
+
+ fig, ax = plt.subplots(figsize=(10, 2))
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+ interpolation='none')
+ plt.colorbar(im, ax=ax)
+ plt.xlabel("Frames")
+ plt.ylabel("Channels")
+ plt.tight_layout()
+
+ fig.canvas.draw()
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+ plt.close()
+ return data
+
+
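+# The returned array is uint8 with shape (H, W, 3), so it can be logged
+# directly, e.g. (sketch):
+#
+#   tb_writer.add_image("train/mel", plot_feature(mel), step, dataformats="HWC")
+
+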
+class MetricsTracker(collections.defaultdict):
+ def __init__(self):
+ # Passing the type 'int' to the base-class constructor
+ # makes undefined items default to int() which is zero.
+        # This class plays the role of a metrics tracker.
+ # It can record many metrics, including but not limited to loss.
+ super(MetricsTracker, self).__init__(int)
+
+ def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
+ ans = MetricsTracker()
+ for k, v in self.items():
+ ans[k] = v
+ for k, v in other.items():
+ ans[k] = ans[k] + v
+ return ans
+
+ def __mul__(self, alpha: float) -> "MetricsTracker":
+ ans = MetricsTracker()
+ for k, v in self.items():
+ ans[k] = v * alpha
+ return ans
+
+ def __str__(self) -> str:
+ ans = ""
+ for k, v in self.norm_items():
+ norm_value = "%.4g" % v
+ ans += str(k) + "=" + str(norm_value) + ", "
+ samples = "%.2f" % self["samples"]
+ ans += "over " + str(samples) + " samples."
+ return ans
+
+ def norm_items(self) -> List[Tuple[str, float]]:
+ """
+ Returns a list of pairs, like:
+ [('loss_1', 0.1), ('loss_2', 0.07)]
+ """
+ samples = self["samples"] if "samples" in self else 1
+ ans = []
+ for k, v in self.items():
+ if k == "samples":
+ continue
+ norm_value = float(v) / samples
+ ans.append((k, norm_value))
+ return ans
+
+ def reduce(self, device):
+ """
+        Reduce using torch.distributed, which ensures that all
+        processes get the total.
+ """
+ keys = sorted(self.keys())
+ s = torch.tensor([float(self[k]) for k in keys], device=device)
+ dist.all_reduce(s, op=dist.ReduceOp.SUM)
+ for k, v in zip(keys, s.cpu().tolist()):
+ self[k] = v
+
+ def write_summary(
+ self,
+ tb_writer: SummaryWriter,
+ prefix: str,
+ batch_idx: int,
+ ) -> None:
+ """Add logging information to a TensorBoard writer.
+
+ Args:
+ tb_writer: a TensorBoard writer
+ prefix: a prefix for the name of the loss, e.g. "train/valid_",
+ or "train/current_"
+ batch_idx: The current batch index, used as the x-axis of the plot.
+ """
+ for k, v in self.norm_items():
+ tb_writer.add_scalar(prefix + k, v, batch_idx)
+
+
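+# A small illustration of how MetricsTracker aggregates values (the numbers
+# below are made up):
+#
+#   m1 = MetricsTracker(); m1["samples"] = 2; m1["loss"] = 1.0
+#   m2 = MetricsTracker(); m2["samples"] = 4; m2["loss"] = 3.0
+#   tot = m1 + m2             # samples == 6, loss == 4.0
+#   tot.norm_items()          # [("loss", 0.6666...)]
+#   str(tot)                  # "loss=0.6667, over 6.00 samples."
+
+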
+# checkpoint saving and loading
+LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
+
+
+def save_checkpoint(
+ filename: Path,
+ model: Union[nn.Module, DDP],
+ params: Optional[Dict[str, Any]] = None,
+ optimizer_g: Optional[Optimizer] = None,
+ optimizer_d: Optional[Optimizer] = None,
+ scheduler_g: Optional[LRSchedulerType] = None,
+ scheduler_d: Optional[LRSchedulerType] = None,
+ scaler: Optional[GradScaler] = None,
+ sampler: Optional[CutSampler] = None,
+ rank: int = 0,
+) -> None:
+ """Save training information to a file.
+
+ Args:
+ filename:
+ The checkpoint filename.
+ model:
+ The model to be saved. We only save its `state_dict()`.
+ params:
+ User defined parameters, e.g., epoch, loss.
+ optimizer_g:
+ The optimizer for generator used in the training.
+ Its `state_dict` will be saved.
+ optimizer_d:
+ The optimizer for discriminator used in the training.
+ Its `state_dict` will be saved.
+ scheduler_g:
+ The learning rate scheduler for generator used in the training.
+ Its `state_dict` will be saved.
+ scheduler_d:
+ The learning rate scheduler for discriminator used in the training.
+ Its `state_dict` will be saved.
+      scaler:
+        The GradScaler to be saved. We only save its `state_dict()`.
+      sampler:
+        The sampler used in the training loop. We only save its `state_dict()`.
+ rank:
+ Used in DDP. We save checkpoint only for the node whose rank is 0.
+ Returns:
+ Return None.
+ """
+ if rank != 0:
+ return
+
+ logging.info(f"Saving checkpoint to {filename}")
+
+ if isinstance(model, DDP):
+ model = model.module
+
+ checkpoint = {
+ "model": model.state_dict(),
+ "optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None,
+ "optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None,
+ "scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None,
+ "scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None,
+ "grad_scaler": scaler.state_dict() if scaler is not None else None,
+ "sampler": sampler.state_dict() if sampler is not None else None,
+ }
+
+ if params:
+ for k, v in params.items():
+ assert k not in checkpoint
+ checkpoint[k] = v
+
+ torch.save(checkpoint, filename)
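+
+
+# The saved dict can be restored later with torch.load (sketch; it assumes the
+# corresponding model, optimizers and schedulers have been re-created first):
+#
+#   checkpoint = torch.load(filename, map_location="cpu")
+#   model.load_state_dict(checkpoint["model"])
+#   optimizer_g.load_state_dict(checkpoint["optimizer_g"])
+#   optimizer_d.load_state_dict(checkpoint["optimizer_d"])
+#   scheduler_g.load_state_dict(checkpoint["scheduler_g"])
+#   scheduler_d.load_state_dict(checkpoint["scheduler_d"])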
diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py
new file mode 100644
index 000000000..d5e20a578
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/vits.py
@@ -0,0 +1,610 @@
+# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""VITS module for GAN-TTS task."""
+
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch.cuda.amp import autocast
+
+from hifigan import (
+ HiFiGANMultiPeriodDiscriminator,
+ HiFiGANMultiScaleDiscriminator,
+ HiFiGANMultiScaleMultiPeriodDiscriminator,
+ HiFiGANPeriodDiscriminator,
+ HiFiGANScaleDiscriminator,
+)
+from loss import (
+ DiscriminatorAdversarialLoss,
+ FeatureMatchLoss,
+ GeneratorAdversarialLoss,
+ KLDivergenceLoss,
+ MelSpectrogramLoss,
+)
+from utils import get_segments
+from generator import VITSGenerator
+
+
+AVAILABLE_GENERATERS = {
+ "vits_generator": VITSGenerator,
+}
+AVAILABLE_DISCRIMINATORS = {
+ "hifigan_period_discriminator": HiFiGANPeriodDiscriminator,
+ "hifigan_scale_discriminator": HiFiGANScaleDiscriminator,
+ "hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator,
+ "hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator,
+ "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA
+}
+
+
+class VITS(nn.Module):
+ """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`
+ """
+
+ def __init__(
+ self,
+ # generator related
+ vocab_size: int,
+ feature_dim: int = 513,
+ sampling_rate: int = 22050,
+ generator_type: str = "vits_generator",
+ generator_params: Dict[str, Any] = {
+ "hidden_channels": 192,
+ "spks": None,
+ "langs": None,
+ "spk_embed_dim": None,
+ "global_channels": -1,
+ "segment_size": 32,
+ "text_encoder_attention_heads": 2,
+ "text_encoder_ffn_expand": 4,
+ "text_encoder_cnn_module_kernel": 5,
+ "text_encoder_blocks": 6,
+ "text_encoder_dropout_rate": 0.1,
+ "decoder_kernel_size": 7,
+ "decoder_channels": 512,
+ "decoder_upsample_scales": [8, 8, 2, 2],
+ "decoder_upsample_kernel_sizes": [16, 16, 4, 4],
+ "decoder_resblock_kernel_sizes": [3, 7, 11],
+ "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ "use_weight_norm_in_decoder": True,
+ "posterior_encoder_kernel_size": 5,
+ "posterior_encoder_layers": 16,
+ "posterior_encoder_stacks": 1,
+ "posterior_encoder_base_dilation": 1,
+ "posterior_encoder_dropout_rate": 0.0,
+ "use_weight_norm_in_posterior_encoder": True,
+ "flow_flows": 4,
+ "flow_kernel_size": 5,
+ "flow_base_dilation": 1,
+ "flow_layers": 4,
+ "flow_dropout_rate": 0.0,
+ "use_weight_norm_in_flow": True,
+ "use_only_mean_in_flow": True,
+ "stochastic_duration_predictor_kernel_size": 3,
+ "stochastic_duration_predictor_dropout_rate": 0.5,
+ "stochastic_duration_predictor_flows": 4,
+ "stochastic_duration_predictor_dds_conv_layers": 3,
+ },
+ # discriminator related
+ discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator",
+ discriminator_params: Dict[str, Any] = {
+ "scales": 1,
+ "scale_downsample_pooling": "AvgPool1d",
+ "scale_downsample_pooling_params": {
+ "kernel_size": 4,
+ "stride": 2,
+ "padding": 2,
+ },
+ "scale_discriminator_params": {
+ "in_channels": 1,
+ "out_channels": 1,
+ "kernel_sizes": [15, 41, 5, 3],
+ "channels": 128,
+ "max_downsample_channels": 1024,
+ "max_groups": 16,
+ "bias": True,
+ "downsample_scales": [2, 2, 4, 4, 1],
+ "nonlinear_activation": "LeakyReLU",
+ "nonlinear_activation_params": {"negative_slope": 0.1},
+ "use_weight_norm": True,
+ "use_spectral_norm": False,
+ },
+ "follow_official_norm": False,
+ "periods": [2, 3, 5, 7, 11],
+ "period_discriminator_params": {
+ "in_channels": 1,
+ "out_channels": 1,
+ "kernel_sizes": [5, 3],
+ "channels": 32,
+ "downsample_scales": [3, 3, 3, 3, 1],
+ "max_downsample_channels": 1024,
+ "bias": True,
+ "nonlinear_activation": "LeakyReLU",
+ "nonlinear_activation_params": {"negative_slope": 0.1},
+ "use_weight_norm": True,
+ "use_spectral_norm": False,
+ },
+ },
+ # loss related
+ generator_adv_loss_params: Dict[str, Any] = {
+ "average_by_discriminators": False,
+ "loss_type": "mse",
+ },
+ discriminator_adv_loss_params: Dict[str, Any] = {
+ "average_by_discriminators": False,
+ "loss_type": "mse",
+ },
+ feat_match_loss_params: Dict[str, Any] = {
+ "average_by_discriminators": False,
+ "average_by_layers": False,
+ "include_final_outputs": True,
+ },
+ mel_loss_params: Dict[str, Any] = {
+ "frame_shift": 256,
+ "frame_length": 1024,
+ "n_mels": 80,
+ },
+ lambda_adv: float = 1.0,
+ lambda_mel: float = 45.0,
+ lambda_feat_match: float = 2.0,
+ lambda_dur: float = 1.0,
+ lambda_kl: float = 1.0,
+ cache_generator_outputs: bool = True,
+ ):
+ """Initialize VITS module.
+
+ Args:
+            vocab_size (int): Input vocabulary size.
+            feature_dim (int): Acoustic feature dimension. The actual output channels
+                will be 1 since VITS is an end-to-end text-to-wave model, but
+                feature_dim is kept to indicate the acoustic feature dimension.
+            sampling_rate (int): Sampling rate. It is not used for training, but it
+                is referred to when saving waveforms during inference.
+ generator_type (str): Generator type.
+ generator_params (Dict[str, Any]): Parameter dict for generator.
+ discriminator_type (str): Discriminator type.
+ discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
+ generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator
+ adversarial loss.
+ discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for
+ discriminator adversarial loss.
+ feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss.
+ mel_loss_params (Dict[str, Any]): Parameter dict for mel loss.
+ lambda_adv (float): Loss scaling coefficient for adversarial loss.
+ lambda_mel (float): Loss scaling coefficient for mel spectrogram loss.
+ lambda_feat_match (float): Loss scaling coefficient for feat match loss.
+ lambda_dur (float): Loss scaling coefficient for duration loss.
+ lambda_kl (float): Loss scaling coefficient for KL divergence loss.
+ cache_generator_outputs (bool): Whether to cache generator outputs.
+
+ """
+ super().__init__()
+
+ # define modules
+ generator_class = AVAILABLE_GENERATERS[generator_type]
+ if generator_type == "vits_generator":
+            # NOTE(kan-bayashi): Update parameters for compatibility.
+            # The vocabulary size and feature dimension are decided automatically
+            # from the input data: vocabs is the number of vocabulary entries and
+            # aux_channels is the input acoustic feature dimension.
+ generator_params.update(vocabs=vocab_size, aux_channels=feature_dim)
+ self.generator = generator_class(
+ **generator_params,
+ )
+ discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
+ self.discriminator = discriminator_class(
+ **discriminator_params,
+ )
+ self.generator_adv_loss = GeneratorAdversarialLoss(
+ **generator_adv_loss_params,
+ )
+ self.discriminator_adv_loss = DiscriminatorAdversarialLoss(
+ **discriminator_adv_loss_params,
+ )
+ self.feat_match_loss = FeatureMatchLoss(
+ **feat_match_loss_params,
+ )
+ mel_loss_params.update(sampling_rate=sampling_rate)
+ self.mel_loss = MelSpectrogramLoss(
+ **mel_loss_params,
+ )
+ self.kl_loss = KLDivergenceLoss()
+
+ # coefficients
+ self.lambda_adv = lambda_adv
+ self.lambda_mel = lambda_mel
+ self.lambda_kl = lambda_kl
+ self.lambda_feat_match = lambda_feat_match
+ self.lambda_dur = lambda_dur
+
+ # cache
+ self.cache_generator_outputs = cache_generator_outputs
+ self._cache = None
+
+ # store sampling rate for saving wav file
+ # (not used for the training)
+ self.sampling_rate = sampling_rate
+
+ # store parameters for test compatibility
+ self.spks = self.generator.spks
+ self.langs = self.generator.langs
+ self.spk_embed_dim = self.generator.spk_embed_dim
+
+ def forward(
+ self,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ feats: torch.Tensor,
+ feats_lengths: torch.Tensor,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ return_sample: bool = False,
+ sids: Optional[torch.Tensor] = None,
+ spembs: Optional[torch.Tensor] = None,
+ lids: Optional[torch.Tensor] = None,
+ forward_generator: bool = True,
+ ) -> Tuple[torch.Tensor, Dict[str, Any]]:
+ """Perform generator forward.
+
+ Args:
+ text (Tensor): Text index tensor (B, T_text).
+ text_lengths (Tensor): Text length tensor (B,).
+ feats (Tensor): Feature tensor (B, T_feats, aux_channels).
+ feats_lengths (Tensor): Feature length tensor (B,).
+ speech (Tensor): Speech waveform tensor (B, T_wav).
+ speech_lengths (Tensor): Speech length tensor (B,).
+ sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
+ spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
+ lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
+ forward_generator (bool): Whether to forward generator.
+
+ Returns:
+ - loss (Tensor): Loss scalar tensor.
+ - stats (Dict[str, float]): Statistics to be monitored.
+ """
+ if forward_generator:
+ return self._forward_generator(
+ text=text,
+ text_lengths=text_lengths,
+ feats=feats,
+ feats_lengths=feats_lengths,
+ speech=speech,
+ speech_lengths=speech_lengths,
+ return_sample=return_sample,
+ sids=sids,
+ spembs=spembs,
+ lids=lids,
+ )
+ else:
+ return self._forward_discrminator(
+ text=text,
+ text_lengths=text_lengths,
+ feats=feats,
+ feats_lengths=feats_lengths,
+ speech=speech,
+ speech_lengths=speech_lengths,
+ sids=sids,
+ spembs=spembs,
+ lids=lids,
+ )
+
+ def _forward_generator(
+ self,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ feats: torch.Tensor,
+ feats_lengths: torch.Tensor,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ return_sample: bool = False,
+ sids: Optional[torch.Tensor] = None,
+ spembs: Optional[torch.Tensor] = None,
+ lids: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Dict[str, Any]]:
+ """Perform generator forward.
+
+ Args:
+ text (Tensor): Text index tensor (B, T_text).
+ text_lengths (Tensor): Text length tensor (B,).
+ feats (Tensor): Feature tensor (B, T_feats, aux_channels).
+ feats_lengths (Tensor): Feature length tensor (B,).
+ speech (Tensor): Speech waveform tensor (B, T_wav).
+ speech_lengths (Tensor): Speech length tensor (B,).
+ sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
+ spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
+ lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
+
+ Returns:
+ * loss (Tensor): Loss scalar tensor.
+ * stats (Dict[str, float]): Statistics to be monitored.
+ """
+ # setup
+ feats = feats.transpose(1, 2)
+ speech = speech.unsqueeze(1)
+
+ # calculate generator outputs
+ reuse_cache = True
+ if not self.cache_generator_outputs or self._cache is None:
+ reuse_cache = False
+ outs = self.generator(
+ text=text,
+ text_lengths=text_lengths,
+ feats=feats,
+ feats_lengths=feats_lengths,
+ sids=sids,
+ spembs=spembs,
+ lids=lids,
+ )
+ else:
+ outs = self._cache
+
+ # store cache
+ if self.training and self.cache_generator_outputs and not reuse_cache:
+ self._cache = outs
+
+ # parse outputs
+ speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
+ _, z_p, m_p, logs_p, _, logs_q = outs_
+ speech_ = get_segments(
+ x=speech,
+ start_idxs=start_idxs * self.generator.upsample_factor,
+ segment_size=self.generator.segment_size * self.generator.upsample_factor,
+ )
+
+ # calculate discriminator outputs
+ p_hat = self.discriminator(speech_hat_)
+ with torch.no_grad():
+ # do not store discriminator gradient in generator turn
+ p = self.discriminator(speech_)
+
+ # calculate losses
+ with autocast(enabled=False):
+ if not return_sample:
+ mel_loss = self.mel_loss(speech_hat_, speech_)
+ else:
+ mel_loss, (mel_hat_, mel_) = self.mel_loss(
+ speech_hat_, speech_, return_mel=True
+ )
+ kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
+ dur_loss = torch.sum(dur_nll.float())
+ adv_loss = self.generator_adv_loss(p_hat)
+ feat_match_loss = self.feat_match_loss(p_hat, p)
+
+ mel_loss = mel_loss * self.lambda_mel
+ kl_loss = kl_loss * self.lambda_kl
+ dur_loss = dur_loss * self.lambda_dur
+ adv_loss = adv_loss * self.lambda_adv
+ feat_match_loss = feat_match_loss * self.lambda_feat_match
+ loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss
+
+ stats = dict(
+ generator_loss=loss.item(),
+ generator_mel_loss=mel_loss.item(),
+ generator_kl_loss=kl_loss.item(),
+ generator_dur_loss=dur_loss.item(),
+ generator_adv_loss=adv_loss.item(),
+ generator_feat_match_loss=feat_match_loss.item(),
+ )
+
+ if return_sample:
+ stats["returned_sample"] = (
+ speech_hat_[0].data.cpu().numpy(),
+ speech_[0].data.cpu().numpy(),
+ mel_hat_[0].data.cpu().numpy(),
+ mel_[0].data.cpu().numpy(),
+ )
+
+ # reset cache
+ if reuse_cache or not self.training:
+ self._cache = None
+
+ return loss, stats
+
+ def _forward_discrminator(
+ self,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ feats: torch.Tensor,
+ feats_lengths: torch.Tensor,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ sids: Optional[torch.Tensor] = None,
+ spembs: Optional[torch.Tensor] = None,
+ lids: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Dict[str, Any]]:
+ """Perform discriminator forward.
+
+ Args:
+ text (Tensor): Text index tensor (B, T_text).
+ text_lengths (Tensor): Text length tensor (B,).
+ feats (Tensor): Feature tensor (B, T_feats, aux_channels).
+ feats_lengths (Tensor): Feature length tensor (B,).
+ speech (Tensor): Speech waveform tensor (B, T_wav).
+ speech_lengths (Tensor): Speech length tensor (B,).
+ sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
+ spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
+ lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
+
+ Returns:
+ * loss (Tensor): Loss scalar tensor.
+ * stats (Dict[str, float]): Statistics to be monitored.
+ """
+ # setup
+ feats = feats.transpose(1, 2)
+ speech = speech.unsqueeze(1)
+
+ # calculate generator outputs
+ reuse_cache = True
+ if not self.cache_generator_outputs or self._cache is None:
+ reuse_cache = False
+ outs = self.generator(
+ text=text,
+ text_lengths=text_lengths,
+ feats=feats,
+ feats_lengths=feats_lengths,
+ sids=sids,
+ spembs=spembs,
+ lids=lids,
+ )
+ else:
+ outs = self._cache
+
+ # store cache
+ if self.cache_generator_outputs and not reuse_cache:
+ self._cache = outs
+
+ # parse outputs
+ speech_hat_, _, _, start_idxs, *_ = outs
+ speech_ = get_segments(
+ x=speech,
+ start_idxs=start_idxs * self.generator.upsample_factor,
+ segment_size=self.generator.segment_size * self.generator.upsample_factor,
+ )
+
+ # calculate discriminator outputs
+ p_hat = self.discriminator(speech_hat_.detach())
+ p = self.discriminator(speech_)
+
+ # calculate losses
+ with autocast(enabled=False):
+ real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
+ loss = real_loss + fake_loss
+
+ stats = dict(
+ discriminator_loss=loss.item(),
+ discriminator_real_loss=real_loss.item(),
+ discriminator_fake_loss=fake_loss.item(),
+ )
+
+ # reset cache
+ if reuse_cache or not self.training:
+ self._cache = None
+
+ return loss, stats
+
+ def inference(
+ self,
+ text: torch.Tensor,
+ feats: Optional[torch.Tensor] = None,
+ sids: Optional[torch.Tensor] = None,
+ spembs: Optional[torch.Tensor] = None,
+ lids: Optional[torch.Tensor] = None,
+ durations: Optional[torch.Tensor] = None,
+ noise_scale: float = 0.667,
+ noise_scale_dur: float = 0.8,
+ alpha: float = 1.0,
+ max_len: Optional[int] = None,
+ use_teacher_forcing: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Run inference for single sample.
+
+ Args:
+ text (Tensor): Input text index tensor (T_text,).
+ feats (Tensor): Feature tensor (T_feats, aux_channels).
+ sids (Tensor): Speaker index tensor (1,).
+ spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,).
+ lids (Tensor): Language index tensor (1,).
+ durations (Tensor): Ground-truth duration tensor (T_text,).
+ noise_scale (float): Noise scale value for flow.
+ noise_scale_dur (float): Noise scale value for duration predictor.
+ alpha (float): Alpha parameter to control the speed of generated speech.
+ max_len (Optional[int]): Maximum length.
+ use_teacher_forcing (bool): Whether to use teacher forcing.
+
+ Returns:
+ * wav (Tensor): Generated waveform tensor (T_wav,).
+ * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
+ * duration (Tensor): Predicted duration tensor (T_text,).
+ """
+ # setup
+ text = text[None]
+ text_lengths = torch.tensor(
+ [text.size(1)],
+ dtype=torch.long,
+ device=text.device,
+ )
+ if sids is not None:
+ sids = sids.view(1)
+ if lids is not None:
+ lids = lids.view(1)
+ if durations is not None:
+ durations = durations.view(1, 1, -1)
+
+ # inference
+ if use_teacher_forcing:
+ assert feats is not None
+ feats = feats[None].transpose(1, 2)
+ feats_lengths = torch.tensor(
+ [feats.size(2)],
+ dtype=torch.long,
+ device=feats.device,
+ )
+ wav, att_w, dur = self.generator.inference(
+ text=text,
+ text_lengths=text_lengths,
+ feats=feats,
+ feats_lengths=feats_lengths,
+ sids=sids,
+ spembs=spembs,
+ lids=lids,
+ max_len=max_len,
+ use_teacher_forcing=use_teacher_forcing,
+ )
+ else:
+ wav, att_w, dur = self.generator.inference(
+ text=text,
+ text_lengths=text_lengths,
+ sids=sids,
+ spembs=spembs,
+ lids=lids,
+ dur=durations,
+ noise_scale=noise_scale,
+ noise_scale_dur=noise_scale_dur,
+ alpha=alpha,
+ max_len=max_len,
+ )
+ return wav.view(-1), att_w[0], dur[0]
+
+ def inference_batch(
+ self,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ sids: Optional[torch.Tensor] = None,
+ durations: Optional[torch.Tensor] = None,
+ noise_scale: float = 0.667,
+ noise_scale_dur: float = 0.8,
+ alpha: float = 1.0,
+ max_len: Optional[int] = None,
+ use_teacher_forcing: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Run inference for one batch.
+
+ Args:
+ text (Tensor): Input text index tensor (B, T_text).
+            text_lengths (Tensor): Input text length tensor (B,).
+ sids (Tensor): Speaker index tensor (B,).
+ noise_scale (float): Noise scale value for flow.
+ noise_scale_dur (float): Noise scale value for duration predictor.
+ alpha (float): Alpha parameter to control the speed of generated speech.
+ max_len (Optional[int]): Maximum length.
+
+ Returns:
+ * wav (Tensor): Generated waveform tensor (B, T_wav).
+ * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text).
+ * duration (Tensor): Predicted duration tensor (B, T_text).
+ """
+ # inference
+ wav, att_w, dur = self.generator.inference(
+ text=text,
+ text_lengths=text_lengths,
+ sids=sids,
+ noise_scale=noise_scale,
+ noise_scale_dur=noise_scale_dur,
+ alpha=alpha,
+ max_len=max_len,
+ )
+ return wav, att_w, dur
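+
+
+# Sketch of how the two training turns and inference are typically driven
+# (illustrative; tensor shapes follow the docstrings above, and the waveform
+# hop size is 256 samples per frame with the default decoder config):
+#
+#   model = VITS(vocab_size=num_tokens)
+#   loss_g, stats_g = model(
+#       text, text_lengths, feats, feats_lengths, speech, speech_lengths,
+#       forward_generator=True,
+#   )  # generator turn: mel + KL + duration + adversarial + feature matching
+#   loss_d, stats_d = model(
+#       text, text_lengths, feats, feats_lengths, speech, speech_lengths,
+#       forward_generator=False,
+#   )  # discriminator turn: real/fake adversarial loss
+#   wav, att_w, dur = model.inference(text=token_ids)  # single-utterance TTS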
diff --git a/egs/ljspeech/TTS/vits/wavenet.py b/egs/ljspeech/TTS/vits/wavenet.py
new file mode 100644
index 000000000..fbe1be52b
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/wavenet.py
@@ -0,0 +1,349 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""WaveNet modules.
+
+This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
+
+"""
+
+import math
+import logging
+
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+
+
+class WaveNet(torch.nn.Module):
+ """WaveNet with global conditioning."""
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ out_channels: int = 1,
+ kernel_size: int = 3,
+ layers: int = 30,
+ stacks: int = 3,
+ base_dilation: int = 2,
+ residual_channels: int = 64,
+ aux_channels: int = -1,
+ gate_channels: int = 128,
+ skip_channels: int = 64,
+ global_channels: int = -1,
+ dropout_rate: float = 0.0,
+ bias: bool = True,
+ use_weight_norm: bool = True,
+ use_first_conv: bool = False,
+ use_last_conv: bool = False,
+ scale_residual: bool = False,
+ scale_skip_connect: bool = False,
+ ):
+ """Initialize WaveNet module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ kernel_size (int): Kernel size of dilated convolution.
+ layers (int): Number of residual block layers.
+            stacks (int): Number of stacks, i.e., dilation cycles.
+ base_dilation (int): Base dilation factor.
+ residual_channels (int): Number of channels in residual conv.
+ gate_channels (int): Number of channels in gated conv.
+ skip_channels (int): Number of channels in skip conv.
+ aux_channels (int): Number of channels for local conditioning feature.
+ global_channels (int): Number of channels for global conditioning feature.
+ dropout_rate (float): Dropout rate. 0.0 means no dropout applied.
+ bias (bool): Whether to use bias parameter in conv layer.
+ use_weight_norm (bool): Whether to use weight norm. If set to true, it will
+ be applied to all of the conv layers.
+ use_first_conv (bool): Whether to use the first conv layers.
+ use_last_conv (bool): Whether to use the last conv layers.
+ scale_residual (bool): Whether to scale the residual outputs.
+ scale_skip_connect (bool): Whether to scale the skip connection outputs.
+
+ """
+ super().__init__()
+ self.layers = layers
+ self.stacks = stacks
+ self.kernel_size = kernel_size
+ self.base_dilation = base_dilation
+ self.use_first_conv = use_first_conv
+ self.use_last_conv = use_last_conv
+ self.scale_skip_connect = scale_skip_connect
+
+ # check the number of layers and stacks
+ assert layers % stacks == 0
+ layers_per_stack = layers // stacks
+
+ # define first convolution
+ if self.use_first_conv:
+ self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)
+
+ # define residual blocks
+ self.conv_layers = torch.nn.ModuleList()
+ for layer in range(layers):
+ dilation = base_dilation ** (layer % layers_per_stack)
+ conv = ResidualBlock(
+ kernel_size=kernel_size,
+ residual_channels=residual_channels,
+ gate_channels=gate_channels,
+ skip_channels=skip_channels,
+ aux_channels=aux_channels,
+ global_channels=global_channels,
+ dilation=dilation,
+ dropout_rate=dropout_rate,
+ bias=bias,
+ scale_residual=scale_residual,
+ )
+ self.conv_layers += [conv]
+
+ # define output layers
+ if self.use_last_conv:
+ self.last_conv = torch.nn.Sequential(
+ torch.nn.ReLU(inplace=True),
+ Conv1d1x1(skip_channels, skip_channels, bias=True),
+ torch.nn.ReLU(inplace=True),
+ Conv1d1x1(skip_channels, out_channels, bias=True),
+ )
+
+ # apply weight norm
+ if use_weight_norm:
+ self.apply_weight_norm()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: Optional[torch.Tensor] = None,
+ c: Optional[torch.Tensor] = None,
+ g: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T) if use_first_conv else
+ (B, residual_channels, T).
+ x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
+ c (Optional[Tensor]): Local conditioning features (B, aux_channels, T).
+ g (Optional[Tensor]): Global conditioning features (B, global_channels, 1).
+
+ Returns:
+ Tensor: Output tensor (B, out_channels, T) if use_last_conv else
+ (B, residual_channels, T).
+
+ """
+ # encode to hidden representation
+ if self.use_first_conv:
+ x = self.first_conv(x)
+
+ # residual block
+ skips = 0.0
+ for f in self.conv_layers:
+ x, h = f(x, x_mask=x_mask, c=c, g=g)
+ skips = skips + h
+ x = skips
+ if self.scale_skip_connect:
+ x = x * math.sqrt(1.0 / len(self.conv_layers))
+
+ # apply final layers
+ if self.use_last_conv:
+ x = self.last_conv(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m: torch.nn.Module):
+ try:
+ logging.debug(f"Weight norm is removed from {m}.")
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m: torch.nn.Module):
+ if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
+ torch.nn.utils.weight_norm(m)
+ logging.debug(f"Weight norm is applied to {m}.")
+
+ self.apply(_apply_weight_norm)
+
+ @staticmethod
+ def _get_receptive_field_size(
+ layers: int,
+ stacks: int,
+ kernel_size: int,
+ base_dilation: int,
+ ) -> int:
+ assert layers % stacks == 0
+ layers_per_cycle = layers // stacks
+ dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)]
+ return (kernel_size - 1) * sum(dilations) + 1
+
+ @property
+ def receptive_field_size(self) -> int:
+ """Return receptive field size."""
+ return self._get_receptive_field_size(
+ self.layers, self.stacks, self.kernel_size, self.base_dilation
+ )
+
+
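+# For example, with the defaults above (layers=30, stacks=3, kernel_size=3,
+# base_dilation=2), each stack uses dilations 1, 2, 4, ..., 512, so
+# receptive_field_size = (3 - 1) * 3 * (1 + 2 + ... + 512) + 1 = 6139 samples.
+
+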
+class Conv1d(torch.nn.Conv1d):
+ """Conv1d module with customized initialization."""
+
+ def __init__(self, *args, **kwargs):
+ """Initialize Conv1d module."""
+ super().__init__(*args, **kwargs)
+
+ def reset_parameters(self):
+ """Reset parameters."""
+ torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
+ if self.bias is not None:
+ torch.nn.init.constant_(self.bias, 0.0)
+
+
+class Conv1d1x1(Conv1d):
+ """1x1 Conv1d with customized initialization."""
+
+ def __init__(self, in_channels: int, out_channels: int, bias: bool):
+ """Initialize 1x1 Conv1d module."""
+ super().__init__(
+ in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
+ )
+
+
+class ResidualBlock(torch.nn.Module):
+ """Residual block module in WaveNet."""
+
+ def __init__(
+ self,
+ kernel_size: int = 3,
+ residual_channels: int = 64,
+ gate_channels: int = 128,
+ skip_channels: int = 64,
+ aux_channels: int = 80,
+ global_channels: int = -1,
+ dropout_rate: float = 0.0,
+ dilation: int = 1,
+ bias: bool = True,
+ scale_residual: bool = False,
+ ):
+ """Initialize ResidualBlock module.
+
+ Args:
+ kernel_size (int): Kernel size of dilation convolution layer.
+ residual_channels (int): Number of channels for residual connection.
+ skip_channels (int): Number of channels for skip connection.
+ aux_channels (int): Number of local conditioning channels.
+            dropout_rate (float): Dropout probability.
+ dilation (int): Dilation factor.
+ bias (bool): Whether to add bias parameter in convolution layers.
+ scale_residual (bool): Whether to scale the residual outputs.
+
+ """
+ super().__init__()
+ self.dropout_rate = dropout_rate
+ self.residual_channels = residual_channels
+ self.skip_channels = skip_channels
+ self.scale_residual = scale_residual
+
+ # check
+        assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
+ assert gate_channels % 2 == 0
+
+ # dilation conv
+ padding = (kernel_size - 1) // 2 * dilation
+ self.conv = Conv1d(
+ residual_channels,
+ gate_channels,
+ kernel_size,
+ padding=padding,
+ dilation=dilation,
+ bias=bias,
+ )
+
+ # local conditioning
+ if aux_channels > 0:
+ self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
+ else:
+ self.conv1x1_aux = None
+
+ # global conditioning
+ if global_channels > 0:
+ self.conv1x1_glo = Conv1d1x1(global_channels, gate_channels, bias=False)
+ else:
+ self.conv1x1_glo = None
+
+ # conv output is split into two groups
+ gate_out_channels = gate_channels // 2
+
+        # NOTE(kan-bayashi): concat two convs into a single conv for efficiency
+ # (integrate res 1x1 + skip 1x1 convs)
+ self.conv1x1_out = Conv1d1x1(
+ gate_out_channels, residual_channels + skip_channels, bias=bias
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: Optional[torch.Tensor] = None,
+ c: Optional[torch.Tensor] = None,
+ g: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, residual_channels, T).
+ x_mask Optional[torch.Tensor]: Mask tensor (B, 1, T).
+ c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T).
+ g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+
+ Returns:
+ Tensor: Output tensor for residual connection (B, residual_channels, T).
+ Tensor: Output tensor for skip connection (B, skip_channels, T).
+
+ """
+ residual = x
+ x = F.dropout(x, p=self.dropout_rate, training=self.training)
+ x = self.conv(x)
+
+        # split into two parts for gated activation
+ splitdim = 1
+ xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
+
+ # local conditioning
+ if c is not None:
+ c = self.conv1x1_aux(c)
+ ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
+ xa, xb = xa + ca, xb + cb
+
+ # global conditioning
+ if g is not None:
+ g = self.conv1x1_glo(g)
+ ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim)
+ xa, xb = xa + ga, xb + gb
+
+ x = torch.tanh(xa) * torch.sigmoid(xb)
+
+ # residual + skip 1x1 conv
+ x = self.conv1x1_out(x)
+ if x_mask is not None:
+ x = x * x_mask
+
+ # split integrated conv results
+ x, s = x.split([self.residual_channels, self.skip_channels], dim=1)
+
+ # for residual connection
+ x = x + residual
+ if self.scale_residual:
+ x = x * math.sqrt(0.5)
+
+ return x, s
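+
+
+# Shape sketch for a single residual block (illustrative values):
+#
+#   block = ResidualBlock(residual_channels=64, gate_channels=128,
+#                         skip_channels=64, aux_channels=80)
+#   x = torch.randn(2, 64, 100)  # (B, residual_channels, T)
+#   c = torch.randn(2, 80, 100)  # (B, aux_channels, T)
+#   res, skip = block(x, c=c)
+#   # res: (2, 64, 100) feeds the next block; skip: (2, 64, 100) is summed
+#   # over all blocks inside WaveNet.forward.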
diff --git a/pyproject.toml b/pyproject.toml
index c40143fb9..435256416 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,4 +14,5 @@ exclude = '''
| icefall\/diagnostics\.py
| icefall\/profiler\.py
| egs\/librispeech\/ASR\/zipformer
+ | egs\/ljspeech\/TTS\/vits
'''
From f08af2fa2217e226394b6f03442952d104bd984e Mon Sep 17 00:00:00 2001
From: LoganLiu66 <2319277867@qq.com>
Date: Mon, 4 Dec 2023 22:29:42 +0800
Subject: [PATCH 009/123] fix initial states (#1398)
Co-authored-by: liujiawang02
---
.../pruned_transducer_stateless7_streaming/decode_stream.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py
index 0d7e86fcf..2c4b144fc 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py
@@ -82,12 +82,12 @@ class DecodeStream(object):
self.pad_length = 7
if params.decoding_method == "greedy_search":
- self.hyp = [params.blank_id] * params.context_size
+ self.hyp = [-1] * (params.context_size - 1) + [params.blank_id]
elif params.decoding_method == "modified_beam_search":
self.hyps = HypothesisList()
self.hyps.add(
Hypothesis(
- ys=[params.blank_id] * params.context_size,
+ ys=[-1] * (params.context_size - 1) + [params.blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
From 735fb9a73dea7d27e95056add6598ae7a282d6f9 Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Wed, 6 Dec 2023 09:59:19 +0800
Subject: [PATCH 010/123] A TTS recipe VITS on VCTK dataset (#1380)
* init
* isort formatted
* minor updates
* Create shared
* Update prepare_tokens_vctk.py
* Update prepare_tokens_vctk.py
* Update prepare_tokens_vctk.py
* Update prepare.sh
* updated
* Update train.py
* Update train.py
* Update tts_datamodule.py
* Update train.py
* Update train.py
* Update train.py
* Update train.py
* Update train.py
* Update train.py
* fixed formatting issue
* Update infer.py
* removed redundant files
* Create monotonic_align
* removed redundant files
* created symlinks
* Update prepare.sh
* minor adjustments
* Create requirements_tts.txt
* Update requirements_tts.txt
added version constraints
* Update infer.py
* Update infer.py
* Update infer.py
* updated docs
* Update export-onnx.py
* Update export-onnx.py
* Update test_onnx.py
* updated requirements.txt
* Update test_onnx.py
* Update test_onnx.py
* docs updated
* docs fixed
* minor updates
---
docs/source/recipes/TTS/index.rst | 1 +
docs/source/recipes/TTS/ljspeech/vits.rst | 12 +-
docs/source/recipes/TTS/vctk/vits.rst | 125 +++
egs/ljspeech/TTS/prepare.sh | 16 +-
egs/ljspeech/TTS/vits/duration_predictor.py | 1 -
egs/ljspeech/TTS/vits/export-onnx.py | 8 +-
egs/ljspeech/TTS/vits/flow.py | 1 -
egs/ljspeech/TTS/vits/generator.py | 5 +-
egs/ljspeech/TTS/vits/infer.py | 29 +-
egs/ljspeech/TTS/vits/loss.py | 1 -
egs/ljspeech/TTS/vits/posterior_encoder.py | 2 +-
egs/ljspeech/TTS/vits/residual_coupling.py | 1 -
egs/ljspeech/TTS/vits/test_onnx.py | 2 +-
egs/ljspeech/TTS/vits/text_encoder.py | 48 +-
egs/ljspeech/TTS/vits/tokenizer.py | 4 +-
egs/ljspeech/TTS/vits/train.py | 93 +-
egs/ljspeech/TTS/vits/tts_datamodule.py | 2 +-
egs/ljspeech/TTS/vits/utils.py | 14 +-
egs/ljspeech/TTS/vits/vits.py | 9 +-
egs/ljspeech/TTS/vits/wavenet.py | 3 +-
.../TTS/local/compute_spectrogram_vctk.py | 107 ++
.../TTS/local/display_manifest_statistics.py | 83 ++
egs/vctk/TTS/local/prepare_token_file.py | 104 ++
egs/vctk/TTS/local/prepare_tokens_vctk.py | 61 +
egs/vctk/TTS/local/validate_manifest.py | 70 ++
egs/vctk/TTS/prepare.sh | 131 +++
egs/vctk/TTS/shared | 1 +
egs/vctk/TTS/vits/duration_predictor.py | 1 +
egs/vctk/TTS/vits/export-onnx.py | 284 +++++
egs/vctk/TTS/vits/flow.py | 1 +
egs/vctk/TTS/vits/generator.py | 1 +
egs/vctk/TTS/vits/hifigan.py | 1 +
egs/vctk/TTS/vits/infer.py | 272 +++++
egs/vctk/TTS/vits/loss.py | 1 +
egs/vctk/TTS/vits/monotonic_align | 1 +
egs/vctk/TTS/vits/posterior_encoder.py | 1 +
egs/vctk/TTS/vits/residual_coupling.py | 1 +
egs/vctk/TTS/vits/test_onnx.py | 138 +++
egs/vctk/TTS/vits/text_encoder.py | 1 +
egs/vctk/TTS/vits/tokenizer.py | 1 +
egs/vctk/TTS/vits/train.py | 1000 +++++++++++++++++
egs/vctk/TTS/vits/transform.py | 1 +
egs/vctk/TTS/vits/tts_datamodule.py | 338 ++++++
egs/vctk/TTS/vits/utils.py | 1 +
egs/vctk/TTS/vits/vits.py | 1 +
egs/vctk/TTS/vits/wavenet.py | 1 +
requirements-tts.txt | 6 +
requirements.txt | 2 +
48 files changed, 2904 insertions(+), 84 deletions(-)
create mode 100644 docs/source/recipes/TTS/vctk/vits.rst
create mode 100755 egs/vctk/TTS/local/compute_spectrogram_vctk.py
create mode 100755 egs/vctk/TTS/local/display_manifest_statistics.py
create mode 100755 egs/vctk/TTS/local/prepare_token_file.py
create mode 100755 egs/vctk/TTS/local/prepare_tokens_vctk.py
create mode 100755 egs/vctk/TTS/local/validate_manifest.py
create mode 100755 egs/vctk/TTS/prepare.sh
create mode 120000 egs/vctk/TTS/shared
create mode 120000 egs/vctk/TTS/vits/duration_predictor.py
create mode 100755 egs/vctk/TTS/vits/export-onnx.py
create mode 120000 egs/vctk/TTS/vits/flow.py
create mode 120000 egs/vctk/TTS/vits/generator.py
create mode 120000 egs/vctk/TTS/vits/hifigan.py
create mode 100755 egs/vctk/TTS/vits/infer.py
create mode 120000 egs/vctk/TTS/vits/loss.py
create mode 120000 egs/vctk/TTS/vits/monotonic_align
create mode 120000 egs/vctk/TTS/vits/posterior_encoder.py
create mode 120000 egs/vctk/TTS/vits/residual_coupling.py
create mode 100755 egs/vctk/TTS/vits/test_onnx.py
create mode 120000 egs/vctk/TTS/vits/text_encoder.py
create mode 120000 egs/vctk/TTS/vits/tokenizer.py
create mode 100755 egs/vctk/TTS/vits/train.py
create mode 120000 egs/vctk/TTS/vits/transform.py
create mode 100644 egs/vctk/TTS/vits/tts_datamodule.py
create mode 120000 egs/vctk/TTS/vits/utils.py
create mode 120000 egs/vctk/TTS/vits/vits.py
create mode 120000 egs/vctk/TTS/vits/wavenet.py
create mode 100644 requirements-tts.txt
diff --git a/docs/source/recipes/TTS/index.rst b/docs/source/recipes/TTS/index.rst
index aa891c072..80d67a2f3 100644
--- a/docs/source/recipes/TTS/index.rst
+++ b/docs/source/recipes/TTS/index.rst
@@ -5,3 +5,4 @@ TTS
:maxdepth: 2
ljspeech/vits
+ vctk/vits
\ No newline at end of file
diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst
index 385fd3c70..d08aa0f47 100644
--- a/docs/source/recipes/TTS/ljspeech/vits.rst
+++ b/docs/source/recipes/TTS/ljspeech/vits.rst
@@ -4,6 +4,10 @@ VITS
This tutorial shows you how to train an VITS model
with the `LJSpeech `_ dataset.
+.. note::
+
+ TTS related recipes require packages in ``requirements-tts.txt``.
+
.. note::
The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_
@@ -27,6 +31,12 @@ To run stage 1 to stage 5, use
Build Monotonic Alignment Search
--------------------------------
+.. code-block:: bash
+
+ $ ./prepare.sh --stage -1 --stop_stage -1
+
+or
+
.. code-block:: bash
$ cd vits/monotonic_align
@@ -74,7 +84,7 @@ training part first. It will save the ground-truth and generated wavs to the dir
$ ./vits/infer.py \
--epoch 1000 \
--exp-dir vits/exp \
- --tokens data/tokens.txt
+ --tokens data/tokens.txt \
--max-duration 500
.. note::
diff --git a/docs/source/recipes/TTS/vctk/vits.rst b/docs/source/recipes/TTS/vctk/vits.rst
new file mode 100644
index 000000000..34024a5ea
--- /dev/null
+++ b/docs/source/recipes/TTS/vctk/vits.rst
@@ -0,0 +1,125 @@
+VITS
+===============
+
+This tutorial shows you how to train a VITS model
+with the `VCTK `_ dataset.
+
+.. note::
+
+ TTS related recipes require packages in ``requirements-tts.txt``.
+
+.. note::
+
+ The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_
+
+
+Data preparation
+----------------
+
+.. code-block:: bash
+
+ $ cd egs/vctk/TTS
+ $ ./prepare.sh
+
+To run stage 1 to stage 6, use
+
+.. code-block:: bash
+
+ $ ./prepare.sh --stage 1 --stop_stage 6
+
+
+Build Monotonic Alignment Search
+--------------------------------
+
+To build the monotonic alignment search, use the following commands:
+
+.. code-block:: bash
+
+ $ ./prepare.sh --stage -1 --stop_stage -1
+
+or
+
+.. code-block:: bash
+
+ $ cd vits/monotonic_align
+ $ python setup.py build_ext --inplace
+ $ cd ../../
+
+
+Training
+--------
+
+.. code-block:: bash
+
+ $ export CUDA_VISIBLE_DEVICES="0,1,2,3"
+ $ ./vits/train.py \
+ --world-size 4 \
+ --num-epochs 1000 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir vits/exp \
+     --tokens data/tokens.txt \
+ --max-duration 350
+
+.. note::
+
+ You can adjust the hyper-parameters to control the size of the VITS model and
+ the training configurations. For more details, please run ``./vits/train.py --help``.
+
+.. note::
+
+ The training can take a long time (usually a couple of days).
+
+Training logs, checkpoints and tensorboard logs are saved in ``vits/exp``.
+
+
+Inference
+---------
+
+The inference part uses checkpoints saved by the training part, so you have to run the
+training part first. It will save the ground-truth and generated wavs to the directory
+``vits/exp/infer/epoch-*/wav``, e.g., ``vits/exp/infer/epoch-1000/wav``.
+
+.. code-block:: bash
+
+ $ export CUDA_VISIBLE_DEVICES="0"
+ $ ./vits/infer.py \
+ --epoch 1000 \
+ --exp-dir vits/exp \
+ --tokens data/tokens.txt \
+ --max-duration 500
+
+.. note::
+
+ For more details, please run ``./vits/infer.py --help``.
+
+
+Export models
+-------------
+
+Currently we only support exporting the model to ONNX. It will generate two files in the given ``exp-dir``:
+``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
+
+.. code-block:: bash
+
+ $ ./vits/export-onnx.py \
+ --epoch 1000 \
+ --exp-dir vits/exp \
+ --tokens data/tokens.txt
+
+You can test the exported ONNX model with:
+
+.. code-block:: bash
+
+ $ ./vits/test_onnx.py \
+ --model-filename vits/exp/vits-epoch-1000.onnx \
+ --tokens data/tokens.txt
+
+
+Download pretrained models
+--------------------------
+
+If you don't want to train from scratch, you can download the pretrained models
+by visiting the following link:
+
+ - ``_
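For a quick sanity check of the exported multi-speaker model outside of ``./vits/test_onnx.py``, something along these lines can be used. This is only a sketch: the input/output names come from ``vits/export-onnx.py`` further down in this patch, while the token IDs, speaker ID and file path are made up for illustration.

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("vits/exp/vits-epoch-1000.onnx")

# Made-up token IDs; real ones come from tokenizer.py (tokens_to_token_ids).
tokens = np.array([[1, 5, 9, 11, 3]], dtype=np.int64)
tokens_lens = np.array([tokens.shape[1]], dtype=np.int64)

(audio,) = session.run(
    ["audio"],
    {
        "tokens": tokens,
        "tokens_lens": tokens_lens,
        "noise_scale": np.array([0.667], dtype=np.float32),
        "noise_scale_dur": np.array([0.8], dtype=np.float32),
        "speaker": np.array([10], dtype=np.int64),
        "alpha": np.array([1.0], dtype=np.float32),
    },
)
# audio has shape (1, T_wav); the recipe uses a 22050 Hz sampling rate.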
diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh
index 8ee40896e..ed0a07f5e 100755
--- a/egs/ljspeech/TTS/prepare.sh
+++ b/egs/ljspeech/TTS/prepare.sh
@@ -5,8 +5,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
-nj=1
-stage=-1
+stage=0
stop_stage=100
dl_dir=$PWD/download
@@ -25,6 +24,17 @@ log() {
log "dl_dir: $dl_dir"
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+ log "Stage -1: build monotonic_align lib"
+ if [ ! -d vits/monotonic_align/build ]; then
+ cd vits/monotonic_align
+ python setup.py build_ext --inplace
+ cd ../../
+ else
+ log "monotonic_align lib already built"
+ fi
+fi
+
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
@@ -113,5 +123,3 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
--tokens data/tokens.txt
fi
fi
-
-
diff --git a/egs/ljspeech/TTS/vits/duration_predictor.py b/egs/ljspeech/TTS/vits/duration_predictor.py
index c29a28479..1a8190014 100644
--- a/egs/ljspeech/TTS/vits/duration_predictor.py
+++ b/egs/ljspeech/TTS/vits/duration_predictor.py
@@ -14,7 +14,6 @@ from typing import Optional
import torch
import torch.nn.functional as F
-
from flow import (
ConvFlow,
DilatedDepthSeparableConv,
diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py
index 154de4bf4..2068adeea 100755
--- a/egs/ljspeech/TTS/vits/export-onnx.py
+++ b/egs/ljspeech/TTS/vits/export-onnx.py
@@ -180,7 +180,13 @@ def export_model_onnx(
model_filename,
verbose=False,
opset_version=opset_version,
- input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"],
+ input_names=[
+ "tokens",
+ "tokens_lens",
+ "noise_scale",
+ "noise_scale_dur",
+ "alpha",
+ ],
output_names=["audio"],
dynamic_axes={
"tokens": {0: "N", 1: "T"},
diff --git a/egs/ljspeech/TTS/vits/flow.py b/egs/ljspeech/TTS/vits/flow.py
index 206bd5e3e..2b84f6434 100644
--- a/egs/ljspeech/TTS/vits/flow.py
+++ b/egs/ljspeech/TTS/vits/flow.py
@@ -13,7 +13,6 @@ import math
from typing import Optional, Tuple, Union
import torch
-
from transform import piecewise_rational_quadratic_transform
diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py
index efb0e254c..66c8cedb1 100644
--- a/egs/ljspeech/TTS/vits/generator.py
+++ b/egs/ljspeech/TTS/vits/generator.py
@@ -16,9 +16,6 @@ from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
-
-from icefall.utils import make_pad_mask
-
from duration_predictor import StochasticDurationPredictor
from hifigan import HiFiGANGenerator
from posterior_encoder import PosteriorEncoder
@@ -26,6 +23,8 @@ from residual_coupling import ResidualAffineCouplingBlock
from text_encoder import TextEncoder
from utils import get_random_segments
+from icefall.utils import make_pad_mask
+
class VITSGenerator(torch.nn.Module):
"""Generator module in VITS, `Conditional Variational Autoencoder
diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py
index 91a35e360..cf0d20ae2 100755
--- a/egs/ljspeech/TTS/vits/infer.py
+++ b/egs/ljspeech/TTS/vits/infer.py
@@ -36,13 +36,12 @@ import k2
import torch
import torch.nn as nn
import torchaudio
-
-from train import get_model, get_params
from tokenizer import Tokenizer
+from train import get_model, get_params
+from tts_datamodule import LJSpeechTtsDataModule
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger
-from tts_datamodule import LJSpeechTtsDataModule
def get_parser():
@@ -107,12 +106,12 @@ def infer_dataset(
for i in range(batch_size):
torchaudio.save(
str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
- audio[i:i + 1, :audio_lens[i]],
+ audio[i : i + 1, : audio_lens[i]],
sample_rate=params.sampling_rate,
)
torchaudio.save(
str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
- audio_pred[i:i + 1, :audio_lens_pred[i]],
+ audio_pred[i : i + 1, : audio_lens_pred[i]],
sample_rate=params.sampling_rate,
)
@@ -144,14 +143,24 @@ def infer_dataset(
audio_lens = batch["audio_lens"].tolist()
cut_ids = [cut.id for cut in batch["cut"]]
- audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens)
+ audio_pred, _, durations = model.inference_batch(
+ text=tokens, text_lengths=tokens_lens
+ )
audio_pred = audio_pred.detach().cpu()
# convert to samples
- audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
+ audio_lens_pred = (
+ (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
+ )
futures.append(
executor.submit(
- _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred
+ _save_worker,
+ batch_size,
+ cut_ids,
+ audio,
+ audio_pred,
+ audio_lens,
+ audio_lens_pred,
)
)
@@ -160,7 +169,9 @@ def infer_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
- logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ logging.info(
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
+ )
# return results
for f in futures:
f.result()
diff --git a/egs/ljspeech/TTS/vits/loss.py b/egs/ljspeech/TTS/vits/loss.py
index 21aaad6e7..2f4dc9bc0 100644
--- a/egs/ljspeech/TTS/vits/loss.py
+++ b/egs/ljspeech/TTS/vits/loss.py
@@ -14,7 +14,6 @@ from typing import List, Tuple, Union
import torch
import torch.distributions as D
import torch.nn.functional as F
-
from lhotse.features.kaldi import Wav2LogFilterBank
diff --git a/egs/ljspeech/TTS/vits/posterior_encoder.py b/egs/ljspeech/TTS/vits/posterior_encoder.py
index 6b8a5be52..1104fb864 100644
--- a/egs/ljspeech/TTS/vits/posterior_encoder.py
+++ b/egs/ljspeech/TTS/vits/posterior_encoder.py
@@ -12,9 +12,9 @@ This code is based on https://github.com/jaywalnut310/vits.
from typing import Optional, Tuple
import torch
+from wavenet import Conv1d, WaveNet
from icefall.utils import make_pad_mask
-from wavenet import WaveNet, Conv1d
class PosteriorEncoder(torch.nn.Module):
diff --git a/egs/ljspeech/TTS/vits/residual_coupling.py b/egs/ljspeech/TTS/vits/residual_coupling.py
index 2d6807cb7..f9a2a3786 100644
--- a/egs/ljspeech/TTS/vits/residual_coupling.py
+++ b/egs/ljspeech/TTS/vits/residual_coupling.py
@@ -12,7 +12,6 @@ This code is based on https://github.com/jaywalnut310/vits.
from typing import Optional, Tuple, Union
import torch
-
from flow import FlipFlow
from wavenet import WaveNet
diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py
index 8acca7c02..686fee2a0 100755
--- a/egs/ljspeech/TTS/vits/test_onnx.py
+++ b/egs/ljspeech/TTS/vits/test_onnx.py
@@ -28,10 +28,10 @@ Use the onnx model to generate a wav:
import argparse
import logging
+
import onnxruntime as ort
import torch
import torchaudio
-
from tokenizer import Tokenizer
diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py
index 9f337e45b..fcbae7103 100644
--- a/egs/ljspeech/TTS/vits/text_encoder.py
+++ b/egs/ljspeech/TTS/vits/text_encoder.py
@@ -169,9 +169,7 @@ class Transformer(nn.Module):
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
- x = self.encoder(
- x, pos_emb, key_padding_mask=key_padding_mask
- ) # (T, N, C)
+ x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C)
x = self.after_norm(x)
@@ -207,7 +205,9 @@ class TransformerEncoderLayer(nn.Module):
nn.Linear(dim_feedforward, d_model),
)
- self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout)
+ self.self_attn = RelPositionMultiheadAttention(
+ d_model, num_heads, dropout=dropout
+ )
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
@@ -242,7 +242,9 @@ class TransformerEncoderLayer(nn.Module):
key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
"""
# macaron style feed-forward module
- src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src)))
+ src = src + self.ff_scale * self.dropout(
+ self.feed_forward_macaron(self.norm_ff_macaron(src))
+ )
# multi-head self-attention module
src_attn = self.self_attn(
@@ -490,11 +492,17 @@ class RelPositionMultiheadAttention(nn.Module):
q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
- v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
+ v = (
+ v.contiguous()
+ .view(seq_len, batch_size * self.num_heads, self.head_dim)
+ .transpose(0, 1)
+ )
q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim)
- p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim)
+ p = self.linear_pos(pos_emb).view(
+ pos_emb.size(0), -1, self.num_heads, self.head_dim
+ )
# (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1)
p = p.permute(0, 2, 3, 1)
@@ -506,15 +514,23 @@ class RelPositionMultiheadAttention(nn.Module):
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len)
- matrix_ac = torch.matmul(q_with_bias_u, k) # (batch_size, num_head, seq_len, seq_len)
+ matrix_ac = torch.matmul(
+ q_with_bias_u, k
+ ) # (batch_size, num_head, seq_len, seq_len)
# compute matrix b and matrix d
- matrix_bd = torch.matmul(q_with_bias_v, p) # (batch_size, num_head, seq_len, 2*seq_len-1)
- matrix_bd = self.rel_shift(matrix_bd) # (batch_size, num_head, seq_len, seq_len)
+ matrix_bd = torch.matmul(
+ q_with_bias_v, p
+ ) # (batch_size, num_head, seq_len, 2*seq_len-1)
+ matrix_bd = self.rel_shift(
+ matrix_bd
+ ) # (batch_size, num_head, seq_len, seq_len)
# (batch_size, num_head, seq_len, seq_len)
attn_output_weights = (matrix_ac + matrix_bd) * scaling
- attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len)
+ attn_output_weights = attn_output_weights.view(
+ batch_size * self.num_heads, seq_len, seq_len
+ )
if key_padding_mask is not None:
assert key_padding_mask.shape == (batch_size, seq_len)
@@ -536,10 +552,16 @@ class RelPositionMultiheadAttention(nn.Module):
# (batch_size * num_head, seq_len, head_dim)
attn_output = torch.bmm(attn_output_weights, v)
- assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim)
+ assert attn_output.shape == (
+ batch_size * self.num_heads,
+ seq_len,
+ self.head_dim,
+ )
attn_output = (
- attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim)
+ attn_output.transpose(0, 1)
+ .contiguous()
+ .view(seq_len, batch_size, self.embed_dim)
)
# (seq_len, batch_size, embed_dim)
attn_output = self.out_proj(attn_output)
diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py
index 0678b26fe..70f1240b4 100644
--- a/egs/ljspeech/TTS/vits/tokenizer.py
+++ b/egs/ljspeech/TTS/vits/tokenizer.py
@@ -78,7 +78,9 @@ class Tokenizer(object):
return token_ids_list
- def tokens_to_token_ids(self, tokens_list: List[str], intersperse_blank: bool = True):
+ def tokens_to_token_ids(
+ self, tokens_list: List[str], intersperse_blank: bool = True
+ ):
"""
Args:
tokens_list:
diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py
index eb43a4cc9..71c4224fa 100755
--- a/egs/ljspeech/TTS/vits/train.py
+++ b/egs/ljspeech/TTS/vits/train.py
@@ -18,21 +18,25 @@
import argparse
import logging
-import numpy as np
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import k2
+import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
-from torch.optim import Optimizer
+from tokenizer import Tokenizer
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
+from tts_datamodule import LJSpeechTtsDataModule
+from utils import MetricsTracker, plot_feature, save_checkpoint
+from vits import VITS
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint
@@ -41,11 +45,6 @@ from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, setup_logger, str2bool
-from tokenizer import Tokenizer
-from tts_datamodule import LJSpeechTtsDataModule
-from utils import MetricsTracker, plot_feature, save_checkpoint
-from vits import VITS
-
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
@@ -385,11 +384,12 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["tokens"])
- audio, audio_lens, features, features_lens, tokens, tokens_lens = \
- prepare_input(batch, tokenizer, device)
+ audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
+ batch, tokenizer, device
+ )
loss_info = MetricsTracker()
- loss_info['samples'] = batch_size
+ loss_info["samples"] = batch_size
try:
with autocast(enabled=params.use_fp16):
@@ -446,7 +446,9 @@ def train_one_epoch(
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
- if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
+ if cur_grad_scale < 8.0 or (
+ cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
+ ):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
@@ -482,9 +484,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
- tot_loss.write_summary(
- tb_writer, "train/tot_", params.batch_idx_train
- )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
@@ -492,19 +492,34 @@ def train_one_epoch(
if "returned_sample" in stats_g:
speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
tb_writer.add_audio(
- "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate
+ "train/speech_hat_",
+ speech_hat_,
+ params.batch_idx_train,
+ params.sampling_rate,
)
tb_writer.add_audio(
- "train/speech_", speech_, params.batch_idx_train, params.sampling_rate
+ "train/speech_",
+ speech_,
+ params.batch_idx_train,
+ params.sampling_rate,
)
tb_writer.add_image(
- "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC'
+ "train/mel_hat_",
+ plot_feature(mel_hat_),
+ params.batch_idx_train,
+ dataformats="HWC",
)
tb_writer.add_image(
- "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC'
+ "train/mel_",
+ plot_feature(mel_),
+ params.batch_idx_train,
+ dataformats="HWC",
)
- if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
+ if (
+ params.batch_idx_train % params.valid_interval == 0
+ and not params.print_diagnostics
+ ):
logging.info("Computing validation loss")
valid_info, (speech_hat, speech) = compute_validation_loss(
params=params,
@@ -523,10 +538,16 @@ def train_one_epoch(
tb_writer, "train/valid_", params.batch_idx_train
)
tb_writer.add_audio(
- "train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate
+ "train/valdi_speech_hat",
+ speech_hat,
+ params.batch_idx_train,
+ params.sampling_rate,
)
tb_writer.add_audio(
- "train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate
+ "train/valdi_speech",
+ speech,
+ params.batch_idx_train,
+ params.sampling_rate,
)
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
@@ -555,11 +576,17 @@ def compute_validation_loss(
with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
batch_size = len(batch["tokens"])
- audio, audio_lens, features, features_lens, tokens, tokens_lens = \
- prepare_input(batch, tokenizer, device)
+ (
+ audio,
+ audio_lens,
+ features,
+ features_lens,
+ tokens,
+ tokens_lens,
+ ) = prepare_input(batch, tokenizer, device)
loss_info = MetricsTracker()
- loss_info['samples'] = batch_size
+ loss_info["samples"] = batch_size
# forward discriminator
loss_d, stats_d = model(
@@ -596,12 +623,17 @@ def compute_validation_loss(
if batch_idx == 0 and rank == 0:
inner_model = model.module if isinstance(model, DDP) else model
audio_pred, _, duration = inner_model.inference(
- text=tokens[0, :tokens_lens[0].item()]
+ text=tokens[0, : tokens_lens[0].item()]
)
audio_pred = audio_pred.data.cpu().numpy()
- audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
- assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred))
- audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy()
+ audio_len_pred = (
+ (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
+ )
+ assert audio_len_pred == len(audio_pred), (
+ audio_len_pred,
+ len(audio_pred),
+ )
+ audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
returned_sample = (audio_pred, audio_gt)
if world_size > 1:
@@ -632,8 +664,9 @@ def scan_pessimistic_batches_for_oom(
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
- audio, audio_lens, features, features_lens, tokens, tokens_lens = \
- prepare_input(batch, tokenizer, device)
+ audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
+ batch, tokenizer, device
+ )
try:
# for discriminator
with autocast(enabled=params.use_fp16):
diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py
index 0fcbb92c1..81bb9ed13 100644
--- a/egs/ljspeech/TTS/vits/tts_datamodule.py
+++ b/egs/ljspeech/TTS/vits/tts_datamodule.py
@@ -29,10 +29,10 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
- SpeechSynthesisDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
+ SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
diff --git a/egs/ljspeech/TTS/vits/utils.py b/egs/ljspeech/TTS/vits/utils.py
index 2a3dae900..6a067f596 100644
--- a/egs/ljspeech/TTS/vits/utils.py
+++ b/egs/ljspeech/TTS/vits/utils.py
@@ -14,15 +14,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple, Union
import collections
import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
import torch
-import torch.nn as nn
import torch.distributed as dist
+import torch.nn as nn
from lhotse.dataset.sampling.base import CutSampler
-from pathlib import Path
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
@@ -97,23 +97,23 @@ def plot_feature(spectrogram):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
+
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
- mpl_logger = logging.getLogger('matplotlib')
+ mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10, 2))
- im = ax.imshow(spectrogram, aspect="auto", origin="lower",
- interpolation='none')
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
- data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py
index d5e20a578..b4f0c21e6 100644
--- a/egs/ljspeech/TTS/vits/vits.py
+++ b/egs/ljspeech/TTS/vits/vits.py
@@ -9,8 +9,7 @@ from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
-from torch.cuda.amp import autocast
-
+from generator import VITSGenerator
from hifigan import (
HiFiGANMultiPeriodDiscriminator,
HiFiGANMultiScaleDiscriminator,
@@ -25,9 +24,8 @@ from loss import (
KLDivergenceLoss,
MelSpectrogramLoss,
)
+from torch.cuda.amp import autocast
from utils import get_segments
-from generator import VITSGenerator
-
AVAILABLE_GENERATERS = {
"vits_generator": VITSGenerator,
@@ -42,8 +40,7 @@ AVAILABLE_DISCRIMINATORS = {
class VITS(nn.Module):
- """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`
- """
+ """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`"""
def __init__(
self,
diff --git a/egs/ljspeech/TTS/vits/wavenet.py b/egs/ljspeech/TTS/vits/wavenet.py
index fbe1be52b..5db461d5c 100644
--- a/egs/ljspeech/TTS/vits/wavenet.py
+++ b/egs/ljspeech/TTS/vits/wavenet.py
@@ -9,9 +9,8 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
"""
-import math
import logging
-
+import math
from typing import Optional, Tuple
import torch
diff --git a/egs/vctk/TTS/local/compute_spectrogram_vctk.py b/egs/vctk/TTS/local/compute_spectrogram_vctk.py
new file mode 100755
index 000000000..440ac1245
--- /dev/null
+++ b/egs/vctk/TTS/local/compute_spectrogram_vctk.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes spectrogram features of the VCTK dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated spectrogram features are saved in data/spectrogram.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import (
+ CutSet,
+ LilcomChunkyWriter,
+ Spectrogram,
+ SpectrogramConfig,
+ load_manifest,
+)
+from lhotse.audio import RecordingSet
+from lhotse.supervision import SupervisionSet
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_spectrogram_vctk():
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/spectrogram")
+ num_jobs = min(32, os.cpu_count())
+
+ sampling_rate = 22050
+    frame_length = 1024 / sampling_rate  # (in seconds)
+    frame_shift = 256 / sampling_rate  # (in seconds)
+ use_fft_mag = True
+
+ prefix = "vctk"
+ suffix = "jsonl.gz"
+ partition = "all"
+
+ recordings = load_manifest(
+ src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet
+ ).resample(sampling_rate=sampling_rate)
+ supervisions = load_manifest(
+ src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet
+ )
+
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ use_fft_mag=use_fft_mag,
+ )
+ extractor = Spectrogram(config)
+
+ with get_executor() as ex: # Initialize the executor only once.
+ cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+ if (output_dir / cuts_filename).is_file():
+ logging.info(f"{partition} already exists - skipping.")
+ return
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=recordings, supervisions=supervisions
+ )
+
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+ # when an executor is specified, make more partitions
+ num_jobs=num_jobs if ex is None else 80,
+ executor=ex,
+ storage_type=LilcomChunkyWriter,
+ )
+ cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ compute_spectrogram_vctk()
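At ``sampling_rate = 22050`` the values above correspond to a window of 1024 / 22050 ≈ 46.4 ms and a hop of 256 / 22050 ≈ 11.6 ms. To spot-check the features written by this script, a minimal sketch (not part of the recipe, standard lhotse API assumed):

from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/spectrogram/vctk_cuts_all.jsonl.gz")
cut = next(iter(cuts))
feats = cut.load_features()  # numpy array of shape (num_frames, num_freq_bins)
print(cut.id, cut.duration, feats.shape)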
diff --git a/egs/vctk/TTS/local/display_manifest_statistics.py b/egs/vctk/TTS/local/display_manifest_statistics.py
new file mode 100755
index 000000000..0472e2cea
--- /dev/null
+++ b/egs/vctk/TTS/local/display_manifest_statistics.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file displays duration statistics of utterances in a manifest.
+You can use the displayed values to choose the minimum/maximum duration
+to remove short and long utterances during training.
+
+See the function `remove_short_and_long_utt()` in vits/train.py
+for usage.
+"""
+
+
+from lhotse import load_manifest_lazy
+
+
+def main():
+ path = "./data/spectrogram/vctk_cuts_all.jsonl.gz"
+ cuts = load_manifest_lazy(path)
+ cuts.describe()
+
+
+if __name__ == "__main__":
+ main()
+
+"""
+Cut statistics:
+╒═══════════════════════════╤══════════╕
+│ Cuts count: │ 43873 │
+├───────────────────────────┼──────────┤
+│ Total duration (hh:mm:ss) │ 41:02:18 │
+├───────────────────────────┼──────────┤
+│ mean │ 3.4 │
+├───────────────────────────┼──────────┤
+│ std │ 1.2 │
+├───────────────────────────┼──────────┤
+│ min │ 1.2 │
+├───────────────────────────┼──────────┤
+│ 25% │ 2.6 │
+├───────────────────────────┼──────────┤
+│ 50% │ 3.1 │
+├───────────────────────────┼──────────┤
+│ 75% │ 3.8 │
+├───────────────────────────┼──────────┤
+│ 99% │ 8.0 │
+├───────────────────────────┼──────────┤
+│ 99.5% │ 9.1 │
+├───────────────────────────┼──────────┤
+│ 99.9% │ 12.1 │
+├───────────────────────────┼──────────┤
+│ max │ 16.6 │
+├───────────────────────────┼──────────┤
+│ Recordings available: │ 43873 │
+├───────────────────────────┼──────────┤
+│ Features available: │ 43873 │
+├───────────────────────────┼──────────┤
+│ Supervisions available: │ 43873 │
+╘═══════════════════════════╧══════════╛
+SUPERVISION custom fields:
+Speech duration statistics:
+╒══════════════════════════════╤══════════╤══════════════════════╕
+│ Total speech duration │ 41:02:18 │ 100.00% of recording │
+├──────────────────────────────┼──────────┼──────────────────────┤
+│ Total speaking time duration │ 41:02:18 │ 100.00% of recording │
+├──────────────────────────────┼──────────┼──────────────────────┤
+│ Total silence duration │ 00:00:01 │ 0.00% of recording │
+╘══════════════════════════════╧══════════╧══════════════════════╛
+"""
diff --git a/egs/vctk/TTS/local/prepare_token_file.py b/egs/vctk/TTS/local/prepare_token_file.py
new file mode 100755
index 000000000..c6636c3ad
--- /dev/null
+++ b/egs/vctk/TTS/local/prepare_token_file.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file reads the texts in the given manifest and generates the file that maps tokens to IDs.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict
+
+from lhotse import load_manifest
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--manifest-file",
+ type=Path,
+ default=Path("data/spectrogram/vctk_cuts_all.jsonl.gz"),
+ help="Path to the manifest file",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=Path,
+ default=Path("data/tokens.txt"),
+ help="Path to the tokens",
+ )
+
+ return parser.parse_args()
+
+
+def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
+ """Write a symbol to ID mapping to a file.
+
+ Note:
+ No need to implement `read_mapping` as it can be done
+ through :func:`k2.SymbolTable.from_file`.
+
+ Args:
+ filename:
+ Filename to save the mapping.
+ sym2id:
+ A dict mapping symbols to IDs.
+ Returns:
+ Return None.
+ """
+ with open(filename, "w", encoding="utf-8") as f:
+ for sym, i in sym2id.items():
+ f.write(f"{sym} {i}\n")
+
+
+def get_token2id(manifest_file: Path) -> Dict[str, int]:
+ """Return a dict that maps token to IDs."""
+ extra_tokens = [
+ "", # 0 for blank
+ "", # 1 for sos and eos symbols.
+ "", # 2 for OOV
+ ]
+ all_tokens = set()
+
+ cut_set = load_manifest(manifest_file)
+
+ for cut in cut_set:
+        # Each cut only contains one supervision
+ assert len(cut.supervisions) == 1, len(cut.supervisions)
+ for t in cut.tokens:
+ all_tokens.add(t)
+
+ all_tokens = extra_tokens + list(all_tokens)
+
+ token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
+ return token2id
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ args = get_args()
+ manifest_file = Path(args.manifest_file)
+ out_file = Path(args.tokens)
+
+ token2id = get_token2id(manifest_file)
+ write_mapping(out_file, token2id)
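As the docstring of ``write_mapping`` notes, the generated file can be read back with ``k2.SymbolTable.from_file``. A small usage sketch, assuming ``data/tokens.txt`` has already been produced and contains the phoneme ``AH0``:

import k2

token_table = k2.SymbolTable.from_file("data/tokens.txt")
print(token_table["AH0"])  # ID of the phoneme token (assumed to be in the table)
print(token_table[0])      # symbol mapped to ID 0, i.e. the blank token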
diff --git a/egs/vctk/TTS/local/prepare_tokens_vctk.py b/egs/vctk/TTS/local/prepare_tokens_vctk.py
new file mode 100755
index 000000000..32e1c7dfa
--- /dev/null
+++ b/egs/vctk/TTS/local/prepare_tokens_vctk.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file reads the texts in the given manifest and saves the new cuts with phoneme tokens.
+"""
+
+import logging
+from pathlib import Path
+
+import g2p_en
+import tacotron_cleaner.cleaners
+from lhotse import CutSet, load_manifest
+from tqdm.auto import tqdm
+
+
+def prepare_tokens_vctk():
+ output_dir = Path("data/spectrogram")
+ prefix = "vctk"
+ suffix = "jsonl.gz"
+ partition = "all"
+
+ cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
+ g2p = g2p_en.G2p()
+
+ new_cuts = []
+ for cut in tqdm(cut_set):
+ # Each cut only contains one supervision
+ assert len(cut.supervisions) == 1, len(cut.supervisions)
+ text = cut.supervisions[0].text
+ # Text normalization
+ text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
+ # Convert to phonemes
+ cut.tokens = g2p(text)
+ new_cuts.append(cut)
+
+ new_cut_set = CutSet.from_cuts(new_cuts)
+ new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ prepare_tokens_vctk()
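To make the tokenization step concrete, here is a small illustration of what the cleaner plus ``g2p_en`` pipeline produces. The exact phoneme sequence depends on the installed ``g2p_en`` version, so the output shown in the comment is indicative only.

import g2p_en
import tacotron_cleaner.cleaners

g2p = g2p_en.G2p()
text = "Please call Stella."
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
tokens = g2p(text)
# e.g. ['P', 'L', 'IY1', 'Z', ' ', 'K', 'AO1', 'L', ' ', 'S', 'T', 'EH1', 'L', 'AH0', '.']
print(tokens)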
diff --git a/egs/vctk/TTS/local/validate_manifest.py b/egs/vctk/TTS/local/validate_manifest.py
new file mode 100755
index 000000000..cd466303e
--- /dev/null
+++ b/egs/vctk/TTS/local/validate_manifest.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script checks the following assumptions of the generated manifest:
+
+- Single supervision per cut
+
+We will add more checks later if needed.
+
+Usage example:
+
+ python3 ./local/validate_manifest.py \
+            ./data/spectrogram/vctk_cuts_all.jsonl.gz
+
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.dataset.speech_synthesis import validate_for_tts
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "manifest",
+ type=Path,
+ help="Path to the manifest file",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+
+ manifest = args.manifest
+ logging.info(f"Validating {manifest}")
+
+ assert manifest.is_file(), f"{manifest} does not exist"
+ cut_set = load_manifest_lazy(manifest)
+ assert isinstance(cut_set, CutSet)
+
+ validate_for_tts(cut_set)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
diff --git a/egs/vctk/TTS/prepare.sh b/egs/vctk/TTS/prepare.sh
new file mode 100755
index 000000000..87150ad31
--- /dev/null
+++ b/egs/vctk/TTS/prepare.sh
@@ -0,0 +1,131 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+stage=0
+stop_stage=100
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+ log "Stage -1: build monotonic_align lib"
+ if [ ! -d vits/monotonic_align/build ]; then
+ cd vits/monotonic_align
+ python setup.py build_ext --inplace
+ cd ../../
+ else
+ log "monotonic_align lib already built"
+ fi
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ # If you have pre-downloaded it to /path/to/VCTK,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/VCTK $dl_dir/VCTK
+ #
+ if [ ! -d $dl_dir/VCTK ]; then
+ lhotse download vctk $dl_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare VCTK manifest"
+ # We assume that you have downloaded the VCTK corpus
+ # to $dl_dir/VCTK
+ mkdir -p data/manifests
+ if [ ! -e data/manifests/.vctk.done ]; then
+ lhotse prepare vctk --use-edinburgh-vctk-url true $dl_dir/VCTK data/manifests
+ touch data/manifests/.vctk.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Compute spectrogram for VCTK"
+ mkdir -p data/spectrogram
+ if [ ! -e data/spectrogram/.vctk.done ]; then
+ ./local/compute_spectrogram_vctk.py
+ touch data/spectrogram/.vctk.done
+ fi
+
+ if [ ! -e data/spectrogram/.vctk-validated.done ]; then
+    log "Validating data/spectrogram for VCTK"
+ ./local/validate_manifest.py \
+ data/spectrogram/vctk_cuts_all.jsonl.gz
+ touch data/spectrogram/.vctk-validated.done
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Prepare phoneme tokens for VCTK"
+ if [ ! -e data/spectrogram/.vctk_with_token.done ]; then
+ ./local/prepare_tokens_vctk.py
+ mv data/spectrogram/vctk_cuts_with_tokens_all.jsonl.gz \
+ data/spectrogram/vctk_cuts_all.jsonl.gz
+ touch data/spectrogram/.vctk_with_token.done
+ fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Split the VCTK cuts into train, valid and test sets"
+ if [ ! -e data/spectrogram/.vctk_split.done ]; then
+ lhotse subset --last 600 \
+ data/spectrogram/vctk_cuts_all.jsonl.gz \
+ data/spectrogram/vctk_cuts_validtest.jsonl.gz
+ lhotse subset --first 100 \
+ data/spectrogram/vctk_cuts_validtest.jsonl.gz \
+ data/spectrogram/vctk_cuts_valid.jsonl.gz
+ lhotse subset --last 500 \
+ data/spectrogram/vctk_cuts_validtest.jsonl.gz \
+ data/spectrogram/vctk_cuts_test.jsonl.gz
+
+ rm data/spectrogram/vctk_cuts_validtest.jsonl.gz
+
+ n=$(( $(gunzip -c data/spectrogram/vctk_cuts_all.jsonl.gz | wc -l) - 600 ))
+ lhotse subset --first $n \
+ data/spectrogram/vctk_cuts_all.jsonl.gz \
+ data/spectrogram/vctk_cuts_train.jsonl.gz
+ touch data/spectrogram/.vctk_split.done
+ fi
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Generate token file"
+ # We assume you have installed g2p_en and espnet_tts_frontend.
+ # If not, please install them with:
+ # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
+ # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
+ if [ ! -e data/tokens.txt ]; then
+ ./local/prepare_token_file.py \
+ --manifest-file data/spectrogram/vctk_cuts_train.jsonl.gz \
+ --tokens data/tokens.txt
+ fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Generate speakers file"
+ if [ ! -e data/speakers.txt ]; then
+ gunzip -c data/manifests/vctk_supervisions_all.jsonl.gz \
+ | jq '.speaker' | sed 's/"//g' \
+ | sort | uniq > data/speakers.txt
+ fi
+fi
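Stage 4 holds out the last 600 cuts, takes the first 100 of them as the validation set and the last 500 as the test set, and keeps everything except those 600 cuts for training. The recipe itself does this with the ``lhotse subset`` CLI as shown above; the following Python sketch only spells out the same split logic with the lhotse API.

from lhotse import load_manifest

cuts = load_manifest("data/spectrogram/vctk_cuts_all.jsonl.gz")

held_out = cuts.subset(last=600)            # the final 600 cuts
valid = held_out.subset(first=100)          # 100 cuts for validation
test = held_out.subset(last=500)            # 500 cuts for testing
train = cuts.subset(first=len(cuts) - 600)  # everything except the held-out 600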
diff --git a/egs/vctk/TTS/shared b/egs/vctk/TTS/shared
new file mode 120000
index 000000000..4c5e91438
--- /dev/null
+++ b/egs/vctk/TTS/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/duration_predictor.py b/egs/vctk/TTS/vits/duration_predictor.py
new file mode 120000
index 000000000..9972b476f
--- /dev/null
+++ b/egs/vctk/TTS/vits/duration_predictor.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/duration_predictor.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py
new file mode 100755
index 000000000..7c9664cc1
--- /dev/null
+++ b/egs/vctk/TTS/vits/export-onnx.py
@@ -0,0 +1,284 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script exports a VITS model from PyTorch to ONNX.
+
+Export the model to ONNX:
+./vits/export-onnx.py \
+ --epoch 1000 \
+ --exp-dir vits/exp \
+ --tokens data/tokens.txt
+
+It will generate two files inside vits/exp:
+ - vits-epoch-1000.onnx
+  - vits-epoch-1000.int8.onnx (quantized model)
+
+See ./test_onnx.py for how to use the exported ONNX models.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict, Tuple
+
+import onnx
+import torch
+import torch.nn as nn
+from onnxruntime.quantization import QuantType, quantize_dynamic
+from tokenizer import Tokenizer
+from train import get_model, get_params
+
+from icefall.checkpoint import load_checkpoint
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=1000,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="vits/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--speakers",
+ type=Path,
+ default=Path("data/speakers.txt"),
+ help="Path to speakers.txt file.",
+ )
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ return parser
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, str]):
+ """Add meta data to an ONNX model. It is changed in-place.
+
+ Args:
+ filename:
+ Filename of the ONNX model to be changed.
+ meta_data:
+ Key-value pairs.
+ """
+ model = onnx.load(filename)
+ for key, value in meta_data.items():
+ meta = model.metadata_props.add()
+ meta.key = key
+ meta.value = value
+
+ onnx.save(model, filename)
+
+
+class OnnxModel(nn.Module):
+ """A wrapper for VITS generator."""
+
+ def __init__(self, model: nn.Module):
+ """
+ Args:
+ model:
+ A VITS generator.
+ """
+ super().__init__()
+ self.model = model
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ tokens_lens: torch.Tensor,
+ noise_scale: float = 0.667,
+ noise_scale_dur: float = 0.8,
+ speaker: int = 20,
+ alpha: float = 1.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Please see the help information of VITS.inference_batch
+
+ Args:
+ tokens:
+ Input text token indexes (1, T_text)
+ tokens_lens:
+ Number of tokens of shape (1,)
+ noise_scale (float):
+ Noise scale parameter for flow.
+ noise_scale_dur (float):
+ Noise scale parameter for duration predictor.
+ speaker (int):
+ Speaker ID.
+ alpha (float):
+ Alpha parameter to control the speed of generated speech.
+
+ Returns:
+ Return a tuple containing:
+            - audio, generated waveform tensor, (B, T_wav)
+ """
+ audio, _, _ = self.model.inference(
+ text=tokens,
+ text_lengths=tokens_lens,
+ noise_scale=noise_scale,
+ noise_scale_dur=noise_scale_dur,
+ sids=speaker,
+ alpha=alpha,
+ )
+ return audio
+
+
+def export_model_onnx(
+ model: nn.Module,
+ model_filename: str,
+ opset_version: int = 11,
+) -> None:
+ """Export the given generator model to ONNX format.
+    The exported model has the following inputs:
+
+      - tokens, a tensor of shape (1, T_text); dtype is torch.int64
+      - tokens_lens, a tensor of shape (1,); dtype is torch.int64
+      - noise_scale, noise_scale_dur and alpha, float tensors of shape (1,)
+      - speaker, a speaker ID tensor of shape (1,); dtype is torch.int64
+
+ and it has one output:
+
+ - audio, a tensor of shape (1, T'); dtype is torch.float32
+
+ Args:
+ model:
+ The VITS generator.
+ model_filename:
+ The filename to save the exported ONNX model.
+ opset_version:
+ The opset version to use.
+ """
+ tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
+ tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
+ noise_scale = torch.tensor([1], dtype=torch.float32)
+ noise_scale_dur = torch.tensor([1], dtype=torch.float32)
+ alpha = torch.tensor([1], dtype=torch.float32)
+ speaker = torch.tensor([1], dtype=torch.int64)
+
+ torch.onnx.export(
+ model,
+ (tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha),
+ model_filename,
+ verbose=False,
+ opset_version=opset_version,
+ input_names=[
+ "tokens",
+ "tokens_lens",
+ "noise_scale",
+ "noise_scale_dur",
+ "speaker",
+ "alpha",
+ ],
+ output_names=["audio"],
+ dynamic_axes={
+ "tokens": {0: "N", 1: "T"},
+ "tokens_lens": {0: "N"},
+ "audio": {0: "N", 1: "T"},
+ "speaker": {0: "N"},
+ },
+ )
+
+ meta_data = {
+ "model_type": "VITS",
+ "version": "1",
+ "model_author": "k2-fsa",
+ "comment": "VITS generator",
+ }
+ logging.info(f"meta_data: {meta_data}")
+
+ add_meta_data(filename=model_filename, meta_data=meta_data)
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ tokenizer = Tokenizer(params.tokens)
+ params.blank_id = tokenizer.blank_id
+ params.oov_id = tokenizer.oov_id
+ params.vocab_size = tokenizer.vocab_size
+
+ with open(args.speakers) as f:
+ speaker_map = {line.strip(): i for i, line in enumerate(f)}
+ params.num_spks = len(speaker_map)
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+
+ model = model.generator
+ model.to("cpu")
+ model.eval()
+
+ model = OnnxModel(model=model)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"generator parameters: {num_param}")
+
+ suffix = f"epoch-{params.epoch}"
+
+ opset_version = 13
+
+    logging.info("Exporting generator")
+ model_filename = params.exp_dir / f"vits-{suffix}.onnx"
+ export_model_onnx(
+ model,
+ model_filename,
+ opset_version=opset_version,
+ )
+ logging.info(f"Exported generator to {model_filename}")
+
+ # Generate int8 quantization models
+ # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
+
+ logging.info("Generate int8 quantization models")
+
+ model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
+ quantize_dynamic(
+ model_input=model_filename,
+ model_output=model_filename_int8,
+ weight_type=QuantType.QUInt8,
+ )
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
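The key/value pairs written by ``add_meta_data`` end up in the ONNX file's ``metadata_props`` and can be read back after export, for example with onnxruntime (a small sketch, assuming the default export path):

import onnxruntime as ort

session = ort.InferenceSession("vits/exp/vits-epoch-1000.onnx")
meta = session.get_modelmeta()
print(meta.custom_metadata_map)  # {'model_type': 'VITS', 'version': '1', ...}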
diff --git a/egs/vctk/TTS/vits/flow.py b/egs/vctk/TTS/vits/flow.py
new file mode 120000
index 000000000..e65d91ea7
--- /dev/null
+++ b/egs/vctk/TTS/vits/flow.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/flow.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/generator.py b/egs/vctk/TTS/vits/generator.py
new file mode 120000
index 000000000..611679bfa
--- /dev/null
+++ b/egs/vctk/TTS/vits/generator.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/generator.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/hifigan.py b/egs/vctk/TTS/vits/hifigan.py
new file mode 120000
index 000000000..5ac025de7
--- /dev/null
+++ b/egs/vctk/TTS/vits/hifigan.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/hifigan.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/infer.py b/egs/vctk/TTS/vits/infer.py
new file mode 100755
index 000000000..06c25f02e
--- /dev/null
+++ b/egs/vctk/TTS/vits/infer.py
@@ -0,0 +1,272 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script performs model inference on the test and validation sets.
+
+Usage:
+./vits/infer.py \
+ --epoch 1000 \
+ --exp-dir ./vits/exp \
+ --max-duration 500
+"""
+
+
+import argparse
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import Dict, List
+
+import k2
+import torch
+import torch.nn as nn
+import torchaudio
+from tokenizer import Tokenizer
+from train import get_model, get_params
+from tts_datamodule import VctkTtsDataModule
+
+from icefall.checkpoint import load_checkpoint
+from icefall.utils import AttributeDict, setup_logger
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=1000,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="vits/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ return parser
+
+
+def infer_dataset(
+ dl: torch.utils.data.DataLoader,
+ subset: str,
+ params: AttributeDict,
+ model: nn.Module,
+ tokenizer: Tokenizer,
+ speaker_map: Dict[str, int],
+) -> None:
+ """Decode dataset.
+ The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ tokenizer:
+ Used to convert text to phonemes.
+ """
+
+    # Background worker that saves audio files to disk.
+ def _save_worker(
+ subset: str,
+ batch_size: int,
+ cut_ids: List[str],
+ audio: torch.Tensor,
+ audio_pred: torch.Tensor,
+ audio_lens: List[int],
+ audio_lens_pred: List[int],
+ ):
+ for i in range(batch_size):
+ torchaudio.save(
+ str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"),
+ audio[i : i + 1, : audio_lens[i]],
+ sample_rate=params.sampling_rate,
+ )
+ torchaudio.save(
+ str(params.save_wav_dir / subset / f"{cut_ids[i]}_pred.wav"),
+ audio_pred[i : i + 1, : audio_lens_pred[i]],
+ sample_rate=params.sampling_rate,
+ )
+
+ device = next(model.parameters()).device
+ num_cuts = 0
+ log_interval = 5
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ futures = []
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ for batch_idx, batch in enumerate(dl):
+ batch_size = len(batch["tokens"])
+
+ tokens = batch["tokens"]
+ tokens = tokenizer.tokens_to_token_ids(tokens)
+ tokens = k2.RaggedTensor(tokens)
+ row_splits = tokens.shape.row_splits(1)
+ tokens_lens = row_splits[1:] - row_splits[:-1]
+ tokens = tokens.to(device)
+ tokens_lens = tokens_lens.to(device)
+ # tensor of shape (B, T)
+ tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+ speakers = (
+ torch.Tensor([speaker_map[sid] for sid in batch["speakers"]])
+ .int()
+ .to(device)
+ )
+
+ audio = batch["audio"]
+ audio_lens = batch["audio_lens"].tolist()
+ cut_ids = [cut.id for cut in batch["cut"]]
+
+ audio_pred, _, durations = model.inference_batch(
+ text=tokens,
+ text_lengths=tokens_lens,
+ sids=speakers,
+ )
+ audio_pred = audio_pred.detach().cpu()
+ # convert to samples
+ audio_lens_pred = (
+ (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
+ )
+
+ futures.append(
+ executor.submit(
+ _save_worker,
+ subset,
+ batch_size,
+ cut_ids,
+ audio,
+ audio_pred,
+ audio_lens,
+ audio_lens_pred,
+ )
+ )
+
+ num_cuts += batch_size
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
+ )
+ # return results
+ for f in futures:
+ f.result()
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ VctkTtsDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ params.suffix = f"epoch-{params.epoch}"
+
+ params.res_dir = params.exp_dir / "infer" / params.suffix
+ params.save_wav_dir = params.res_dir / "wav"
+ params.save_wav_dir.mkdir(parents=True, exist_ok=True)
+
+ setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
+ logging.info("Infer started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ tokenizer = Tokenizer(params.tokens)
+ params.blank_id = tokenizer.blank_id
+ params.oov_id = tokenizer.oov_id
+ params.vocab_size = tokenizer.vocab_size
+
+    # we need cut ids to name the saved audio files.
+ args.return_cuts = True
+ vctk = VctkTtsDataModule(args)
+ speaker_map = vctk.speakers()
+ params.num_spks = len(speaker_map)
+
+ logging.info(f"Device: {device}")
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+
+ model.to(device)
+ model.eval()
+
+ num_param_g = sum([p.numel() for p in model.generator.parameters()])
+ logging.info(f"Number of parameters in generator: {num_param_g}")
+ num_param_d = sum([p.numel() for p in model.discriminator.parameters()])
+ logging.info(f"Number of parameters in discriminator: {num_param_d}")
+ logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
+
+ test_cuts = vctk.test_cuts()
+ test_dl = vctk.test_dataloaders(test_cuts)
+
+ valid_cuts = vctk.valid_cuts()
+ valid_dl = vctk.valid_dataloaders(valid_cuts)
+
+ infer_sets = {"test": test_dl, "valid": valid_dl}
+
+ for subset, dl in infer_sets.items():
+ save_wav_dir = params.res_dir / "wav" / subset
+ save_wav_dir.mkdir(parents=True, exist_ok=True)
+
+ logging.info(f"Processing {subset} set, saving to {save_wav_dir}")
+
+ infer_dataset(
+ dl=dl,
+ subset=subset,
+ params=params,
+ model=model,
+ tokenizer=tokenizer,
+ speaker_map=speaker_map,
+ )
+
+ logging.info(f"Wav files are saved to {params.save_wav_dir}")
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
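
The length bookkeeping in `infer_dataset` above is easy to get wrong, so here is a minimal sketch (not part of the patch; the helper name and the example tensors are made up for illustration) of how per-token durations predicted in frames are converted to waveform lengths in samples, mirroring the `durations.sum(1) * params.frame_shift` line above:

import torch

def durations_to_sample_lengths(durations: torch.Tensor, frame_shift: int = 256) -> list:
    # durations: (B, num_tokens), predicted number of feature frames per token
    num_frames = durations.sum(dim=1)  # total frames per utterance
    num_samples = (num_frames * frame_shift).to(torch.int64)  # hop size of 256 samples per frame
    return num_samples.tolist()

# Two utterances whose tokens span 100 and 80 frames in total:
print(durations_to_sample_lengths(torch.tensor([[60.0, 40.0], [50.0, 30.0]])))  # [25600, 20480]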
diff --git a/egs/vctk/TTS/vits/loss.py b/egs/vctk/TTS/vits/loss.py
new file mode 120000
index 000000000..672e5ff68
--- /dev/null
+++ b/egs/vctk/TTS/vits/loss.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/loss.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/monotonic_align b/egs/vctk/TTS/vits/monotonic_align
new file mode 120000
index 000000000..71934e7cc
--- /dev/null
+++ b/egs/vctk/TTS/vits/monotonic_align
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/monotonic_align
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/posterior_encoder.py b/egs/vctk/TTS/vits/posterior_encoder.py
new file mode 120000
index 000000000..41d64a3a6
--- /dev/null
+++ b/egs/vctk/TTS/vits/posterior_encoder.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/posterior_encoder.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/residual_coupling.py b/egs/vctk/TTS/vits/residual_coupling.py
new file mode 120000
index 000000000..f979adbf0
--- /dev/null
+++ b/egs/vctk/TTS/vits/residual_coupling.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/residual_coupling.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/test_onnx.py b/egs/vctk/TTS/vits/test_onnx.py
new file mode 100755
index 000000000..757e67fc1
--- /dev/null
+++ b/egs/vctk/TTS/vits/test_onnx.py
@@ -0,0 +1,138 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script is used to test the ONNX model exported by vits/export-onnx.py.
+
+Use the onnx model to generate a wav:
+./vits/test_onnx.py \
+ --model-filename vits/exp/vits-epoch-1000.onnx \
+ --tokens data/tokens.txt
+"""
+
+
+import argparse
+import logging
+from pathlib import Path
+
+import onnxruntime as ort
+import torch
+import torchaudio
+from tokenizer import Tokenizer
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--model-filename",
+ type=str,
+ required=True,
+ help="Path to the onnx model.",
+ )
+
+ parser.add_argument(
+ "--speakers",
+ type=Path,
+ default=Path("data/speakers.txt"),
+ help="Path to speakers.txt file.",
+ )
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ return parser
+
+
+class OnnxModel:
+ def __init__(self, model_filename: str):
+ session_opts = ort.SessionOptions()
+ session_opts.inter_op_num_threads = 1
+ session_opts.intra_op_num_threads = 4
+
+ self.session_opts = session_opts
+
+ self.model = ort.InferenceSession(
+ model_filename,
+ sess_options=self.session_opts,
+ providers=["CPUExecutionProvider"],
+ )
+ logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
+
+ def __call__(
+ self, tokens: torch.Tensor, tokens_lens: torch.Tensor, speaker: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Args:
+ tokens:
+ A 1-D tensor of shape (1, T)
+ Returns:
+ A tensor of shape (1, T')
+ """
+ noise_scale = torch.tensor([0.667], dtype=torch.float32)
+ noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
+ alpha = torch.tensor([1.0], dtype=torch.float32)
+
+ out = self.model.run(
+ [
+ self.model.get_outputs()[0].name,
+ ],
+ {
+ self.model.get_inputs()[0].name: tokens.numpy(),
+ self.model.get_inputs()[1].name: tokens_lens.numpy(),
+ self.model.get_inputs()[2].name: noise_scale.numpy(),
+ self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
+ self.model.get_inputs()[4].name: speaker.numpy(),
+ self.model.get_inputs()[5].name: alpha.numpy(),
+ },
+ )[0]
+ return torch.from_numpy(out)
+
+
+def main():
+ args = get_parser().parse_args()
+
+ tokenizer = Tokenizer(args.tokens)
+
+ with open(args.speakers) as f:
+ speaker_map = {line.strip(): i for i, line in enumerate(f)}
+ args.num_spks = len(speaker_map)
+
+ logging.info("About to create onnx model")
+ model = OnnxModel(args.model_filename)
+
+ text = "I went there to see the land, the people and how their system works, end quote."
+ tokens = tokenizer.texts_to_token_ids([text])
+ tokens = torch.tensor(tokens) # (1, T)
+    tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)  # (1,)
+ speaker = torch.tensor([1], dtype=torch.int64) # (1, )
+ audio = model(tokens, tokens_lens, speaker) # (1, T')
+
+    torchaudio.save("test_onnx.wav", audio, sample_rate=22050)
+ logging.info("Saved to test_onnx.wav")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
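
`OnnxModel.__call__` above feeds inputs positionally via `get_inputs()[i]`, which silently breaks if the export-time input order changes (as a later patch in this series does). A small hedged alternative, assuming only the public onnxruntime API and that the keyword names match the `input_names` passed to `torch.onnx.export`, is to key the feed dict by input name:

import onnxruntime as ort
import torch

def run_by_name(session: ort.InferenceSession, **named_inputs: torch.Tensor) -> torch.Tensor:
    # Fail early if the caller's names do not match the exported graph inputs.
    expected = {i.name for i in session.get_inputs()}
    assert expected == set(named_inputs), (expected, set(named_inputs))
    feed = {name: tensor.numpy() for name, tensor in named_inputs.items()}
    (audio,) = session.run([session.get_outputs()[0].name], feed)
    return torch.from_numpy(audio)

With the VCTK export above this would be called as, e.g., run_by_name(sess, tokens=..., tokens_lens=..., noise_scale=..., noise_scale_dur=..., speaker=..., alpha=...), so reordering inputs at export time no longer requires touching the test script.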
diff --git a/egs/vctk/TTS/vits/text_encoder.py b/egs/vctk/TTS/vits/text_encoder.py
new file mode 120000
index 000000000..0efba277e
--- /dev/null
+++ b/egs/vctk/TTS/vits/text_encoder.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/text_encoder.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/tokenizer.py b/egs/vctk/TTS/vits/tokenizer.py
new file mode 120000
index 000000000..057b0dc4b
--- /dev/null
+++ b/egs/vctk/TTS/vits/tokenizer.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/tokenizer.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/train.py b/egs/vctk/TTS/vits/train.py
new file mode 100755
index 000000000..56f167a17
--- /dev/null
+++ b/egs/vctk/TTS/vits/train.py
@@ -0,0 +1,1000 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import numpy as np
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from lhotse.cut import Cut
+from lhotse.utils import fix_random_seed
+from tokenizer import Tokenizer
+from torch.cuda.amp import GradScaler, autocast
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.utils.tensorboard import SummaryWriter
+from tts_datamodule import VctkTtsDataModule
+from utils import MetricsTracker, plot_feature, save_checkpoint
+from vits import VITS
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import AttributeDict, setup_logger, str2bool
+
+LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=1000,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="vits/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ parser.add_argument(
+ "--lr", type=float, default=2.0e-4, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=20,
+ help="""Save checkpoint after processing this number of epochs"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.cur_epoch % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
+ Since it will take around 1000 epochs, we suggest using a large
+ save_every_n to save disk space.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+    - feature_dim: The model input dim. It has to match the one used
+                   in computing features.
+
+    - sampling_rate: Audio sampling rate in Hz.
+
+    - frame_shift / frame_length: STFT hop size and window size, in samples.
+
+    - n_mels: Number of Mel bins used by the Mel-spectrogram loss.
+
+    - lambda_*: Scaling coefficients for the adversarial, Mel, feature-matching,
+                duration and KL losses.
+ """
+ params = AttributeDict(
+ {
+ # training params
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": -1, # 0
+ "log_interval": 50,
+ "valid_interval": 200,
+ "env_info": get_env_info(),
+ "sampling_rate": 22050,
+ "frame_shift": 256,
+ "frame_length": 1024,
+ "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length
+ "n_mels": 80,
+ "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss
+ "lambda_mel": 45.0, # loss scaling coefficient for Mel loss
+ "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss
+ "lambda_dur": 1.0, # loss scaling coefficient for duration loss
+ "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss
+ }
+ )
+
+ return params
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict, model: nn.Module
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(filename, model=model)
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ return saved_params
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ mel_loss_params = {
+ "n_mels": params.n_mels,
+ "frame_length": params.frame_length,
+ "frame_shift": params.frame_shift,
+ }
+ generator_params = {
+ "hidden_channels": 192,
+ "spks": params.num_spks,
+ "langs": None,
+ "spk_embed_dim": None,
+ "global_channels": 256,
+ "segment_size": 32,
+ "text_encoder_attention_heads": 2,
+ "text_encoder_ffn_expand": 4,
+ "text_encoder_cnn_module_kernel": 5,
+ "text_encoder_blocks": 6,
+ "text_encoder_dropout_rate": 0.1,
+ "decoder_kernel_size": 7,
+ "decoder_channels": 512,
+ "decoder_upsample_scales": [8, 8, 2, 2],
+ "decoder_upsample_kernel_sizes": [16, 16, 4, 4],
+ "decoder_resblock_kernel_sizes": [3, 7, 11],
+ "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ "use_weight_norm_in_decoder": True,
+ "posterior_encoder_kernel_size": 5,
+ "posterior_encoder_layers": 16,
+ "posterior_encoder_stacks": 1,
+ "posterior_encoder_base_dilation": 1,
+ "posterior_encoder_dropout_rate": 0.0,
+ "use_weight_norm_in_posterior_encoder": True,
+ "flow_flows": 4,
+ "flow_kernel_size": 5,
+ "flow_base_dilation": 1,
+ "flow_layers": 4,
+ "flow_dropout_rate": 0.0,
+ "use_weight_norm_in_flow": True,
+ "use_only_mean_in_flow": True,
+ "stochastic_duration_predictor_kernel_size": 3,
+ "stochastic_duration_predictor_dropout_rate": 0.5,
+ "stochastic_duration_predictor_flows": 4,
+ "stochastic_duration_predictor_dds_conv_layers": 3,
+ }
+ model = VITS(
+ vocab_size=params.vocab_size,
+ feature_dim=params.feature_dim,
+ sampling_rate=params.sampling_rate,
+ generator_params=generator_params,
+ mel_loss_params=mel_loss_params,
+ lambda_adv=params.lambda_adv,
+ lambda_mel=params.lambda_mel,
+ lambda_feat_match=params.lambda_feat_match,
+ lambda_dur=params.lambda_dur,
+ lambda_kl=params.lambda_kl,
+ )
+ return model
+
+
+def prepare_input(
+ batch: dict,
+ tokenizer: Tokenizer,
+ device: torch.device,
+ speaker_map: Dict[str, int],
+):
+ """Parse batch data"""
+ audio = batch["audio"].to(device)
+ features = batch["features"].to(device)
+ audio_lens = batch["audio_lens"].to(device)
+ features_lens = batch["features_lens"].to(device)
+ tokens = batch["tokens"]
+ speakers = (
+ torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device)
+ )
+
+ tokens = tokenizer.tokens_to_token_ids(tokens)
+ tokens = k2.RaggedTensor(tokens)
+ row_splits = tokens.shape.row_splits(1)
+ tokens_lens = row_splits[1:] - row_splits[:-1]
+ tokens = tokens.to(device)
+ tokens_lens = tokens_lens.to(device)
+ # a tensor of shape (B, T)
+ tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+
+ return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ tokenizer: Tokenizer,
+ optimizer_g: Optimizer,
+ optimizer_d: Optimizer,
+ scheduler_g: LRSchedulerType,
+ scheduler_d: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ speaker_map: Dict[str, int],
+ scaler: GradScaler,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ tokenizer:
+ Used to convert text to phonemes.
+      optimizer_g:
+        The optimizer for the generator.
+      optimizer_d:
+        The optimizer for the discriminator.
+      scheduler_g:
+        The learning rate scheduler for the generator; step() is called once per epoch.
+      scheduler_d:
+        The learning rate scheduler for the discriminator; step() is called once per epoch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+        Number of processes participating in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the current process in DDP training. If DDP is not used,
+        it should be set to 0.
+ """
+ model.train()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+    # used to summarize the stats over iterations in one epoch
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ params=params,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ scheduler_g=scheduler_g,
+ scheduler_d=scheduler_d,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+
+ batch_size = len(batch["tokens"])
+ (
+ audio,
+ audio_lens,
+ features,
+ features_lens,
+ tokens,
+ tokens_lens,
+ speakers,
+ ) = prepare_input(batch, tokenizer, device, speaker_map)
+
+ loss_info = MetricsTracker()
+ loss_info["samples"] = batch_size
+
+ try:
+ with autocast(enabled=params.use_fp16):
+ # forward discriminator
+ loss_d, stats_d = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ sids=speakers,
+ forward_generator=False,
+ )
+ for k, v in stats_d.items():
+ loss_info[k] = v * batch_size
+ # update discriminator
+ optimizer_d.zero_grad()
+ scaler.scale(loss_d).backward()
+ scaler.step(optimizer_d)
+
+ with autocast(enabled=params.use_fp16):
+ # forward generator
+ loss_g, stats_g = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ sids=speakers,
+ forward_generator=True,
+ return_sample=params.batch_idx_train % params.log_interval == 0,
+ )
+ for k, v in stats_g.items():
+ if "returned_sample" not in k:
+ loss_info[k] = v * batch_size
+ # update generator
+ optimizer_g.zero_grad()
+ scaler.scale(loss_g).backward()
+ scaler.step(optimizer_g)
+ scaler.update()
+
+ # summary stats
+ tot_loss = tot_loss + loss_info
+ except: # noqa
+ save_bad_model()
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if params.batch_idx_train % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (
+ cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
+ ):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if params.batch_idx_train % params.log_interval == 0:
+ cur_lr_g = max(scheduler_g.get_last_lr())
+ cur_lr_d = max(scheduler_d.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+ f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
+ f"loss[{loss_info}], tot_loss[{tot_loss}], "
+ f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate_g", cur_lr_g, params.batch_idx_train
+ )
+ tb_writer.add_scalar(
+ "train/learning_rate_d", cur_lr_d, params.batch_idx_train
+ )
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+ if "returned_sample" in stats_g:
+ speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
+ tb_writer.add_audio(
+ "train/speech_hat_",
+ speech_hat_,
+ params.batch_idx_train,
+ params.sampling_rate,
+ )
+ tb_writer.add_audio(
+ "train/speech_",
+ speech_,
+ params.batch_idx_train,
+ params.sampling_rate,
+ )
+ tb_writer.add_image(
+ "train/mel_hat_",
+ plot_feature(mel_hat_),
+ params.batch_idx_train,
+ dataformats="HWC",
+ )
+ tb_writer.add_image(
+ "train/mel_",
+ plot_feature(mel_),
+ params.batch_idx_train,
+ dataformats="HWC",
+ )
+
+ if (
+ params.batch_idx_train % params.valid_interval == 0
+ and not params.print_diagnostics
+ ):
+ logging.info("Computing validation loss")
+ valid_info, (speech_hat, speech) = compute_validation_loss(
+ params=params,
+ model=model,
+ tokenizer=tokenizer,
+ valid_dl=valid_dl,
+ speaker_map=speaker_map,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+ tb_writer.add_audio(
+ "train/valdi_speech_hat",
+ speech_hat,
+ params.batch_idx_train,
+ params.sampling_rate,
+ )
+ tb_writer.add_audio(
+ "train/valdi_speech",
+ speech,
+ params.batch_idx_train,
+ params.sampling_rate,
+ )
+
+ loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ tokenizer: Tokenizer,
+ valid_dl: torch.utils.data.DataLoader,
+ speaker_map: Dict[str, int],
+ world_size: int = 1,
+ rank: int = 0,
+) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
+ """Run the validation process."""
+ model.eval()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+    # used to summarize the stats over iterations
+ tot_loss = MetricsTracker()
+ returned_sample = None
+
+ with torch.no_grad():
+ for batch_idx, batch in enumerate(valid_dl):
+ batch_size = len(batch["tokens"])
+ (
+ audio,
+ audio_lens,
+ features,
+ features_lens,
+ tokens,
+ tokens_lens,
+ speakers,
+ ) = prepare_input(batch, tokenizer, device, speaker_map)
+
+ loss_info = MetricsTracker()
+ loss_info["samples"] = batch_size
+
+ # forward discriminator
+ loss_d, stats_d = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ sids=speakers,
+ forward_generator=False,
+ )
+ assert loss_d.requires_grad is False
+ for k, v in stats_d.items():
+ loss_info[k] = v * batch_size
+
+ # forward generator
+ loss_g, stats_g = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ sids=speakers,
+ forward_generator=True,
+ )
+ assert loss_g.requires_grad is False
+ for k, v in stats_g.items():
+ loss_info[k] = v * batch_size
+
+ # summary stats
+ tot_loss = tot_loss + loss_info
+
+ # infer for first batch:
+ if batch_idx == 0 and rank == 0:
+ inner_model = model.module if isinstance(model, DDP) else model
+ audio_pred, _, duration = inner_model.inference(
+ text=tokens[0, : tokens_lens[0].item()],
+ sids=speakers[0],
+ )
+ audio_pred = audio_pred.data.cpu().numpy()
+ audio_len_pred = (
+ (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
+ )
+ assert audio_len_pred == len(audio_pred), (
+ audio_len_pred,
+ len(audio_pred),
+ )
+ audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
+ returned_sample = (audio_pred, audio_gt)
+
+ if world_size > 1:
+ tot_loss.reduce(device)
+
+ loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss, returned_sample
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ tokenizer: Tokenizer,
+ optimizer_g: torch.optim.Optimizer,
+ optimizer_d: torch.optim.Optimizer,
+ speaker_map: Dict[str, int],
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ (
+ audio,
+ audio_lens,
+ features,
+ features_lens,
+ tokens,
+ tokens_lens,
+ speakers,
+ ) = prepare_input(batch, tokenizer, device, speaker_map)
+ try:
+ # for discriminator
+ with autocast(enabled=params.use_fp16):
+ loss_d, stats_d = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ sids=speakers,
+ forward_generator=False,
+ )
+ optimizer_d.zero_grad()
+ loss_d.backward()
+ # for generator
+ with autocast(enabled=params.use_fp16):
+ loss_g, stats_g = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ sids=speakers,
+ forward_generator=True,
+ )
+ optimizer_g.zero_grad()
+ loss_g.backward()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ tokenizer = Tokenizer(params.tokens)
+ params.blank_id = tokenizer.blank_id
+ params.oov_id = tokenizer.oov_id
+ params.vocab_size = tokenizer.vocab_size
+
+ vctk = VctkTtsDataModule(args)
+
+ train_cuts = vctk.train_cuts()
+ speaker_map = vctk.speakers()
+ params.num_spks = len(speaker_map)
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+ generator = model.generator
+ discriminator = model.discriminator
+
+ num_param_g = sum([p.numel() for p in generator.parameters()])
+ logging.info(f"Number of parameters in generator: {num_param_g}")
+ num_param_d = sum([p.numel() for p in discriminator.parameters()])
+ logging.info(f"Number of parameters in discriminator: {num_param_d}")
+ logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer_g = torch.optim.AdamW(
+ generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
+ )
+ optimizer_d = torch.optim.AdamW(
+ discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
+ )
+
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)
+
+ if checkpoints is not None:
+ # load state_dict for optimizers
+ if "optimizer_g" in checkpoints:
+ logging.info("Loading optimizer_g state dict")
+ optimizer_g.load_state_dict(checkpoints["optimizer_g"])
+ if "optimizer_d" in checkpoints:
+ logging.info("Loading optimizer_d state dict")
+ optimizer_d.load_state_dict(checkpoints["optimizer_d"])
+
+ # load state_dict for schedulers
+ if "scheduler_g" in checkpoints:
+ logging.info("Loading scheduler_g state dict")
+ scheduler_g.load_state_dict(checkpoints["scheduler_g"])
+ if "scheduler_d" in checkpoints:
+ logging.info("Loading scheduler_d state dict")
+ scheduler_d.load_state_dict(checkpoints["scheduler_d"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with a duration between 1 and 20 seconds.
+        # You should use ../local/display_manifest_statistics.py to get
+        # the utterance duration distribution of your dataset and select
+        # suitable thresholds.
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+ train_dl = vctk.train_dataloaders(train_cuts)
+
+ valid_cuts = vctk.valid_cuts()
+ valid_dl = vctk.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ tokenizer=tokenizer,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ speaker_map=speaker_map,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ logging.info(f"Start epoch {epoch}")
+
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ params.cur_epoch = epoch
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ tokenizer=tokenizer,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ scheduler_g=scheduler_g,
+ scheduler_d=scheduler_d,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ speaker_map=speaker_map,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint(
+ filename=filename,
+ params=params,
+ model=model,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ scheduler_g=scheduler_g,
+ scheduler_d=scheduler_d,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ if rank == 0:
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+ # step per epoch
+ scheduler_g.step()
+ scheduler_d.step()
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def main():
+ parser = get_parser()
+ VctkTtsDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
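
For readers skimming `train_one_epoch` above, the core of the loop is a standard alternating GAN update that shares one `GradScaler` between both optimizers. The following condensed sketch is an illustration, not a drop-in replacement; it assumes `model` follows the `forward_generator` convention used in this recipe and returns a scalar loss plus a stats dict:

import torch
from torch.cuda.amp import GradScaler, autocast

def gan_step(model, inputs: dict, optimizer_d, optimizer_g, scaler: GradScaler, use_fp16: bool) -> None:
    # 1) Discriminator pass and update.
    with autocast(enabled=use_fp16):
        loss_d, _ = model(**inputs, forward_generator=False)
    optimizer_d.zero_grad()
    scaler.scale(loss_d).backward()
    scaler.step(optimizer_d)

    # 2) Generator pass and update; scaler.update() runs once per iteration,
    #    after both optimizers have stepped.
    with autocast(enabled=use_fp16):
        loss_g, _ = model(**inputs, forward_generator=True)
    optimizer_g.zero_grad()
    scaler.scale(loss_g).backward()
    scaler.step(optimizer_g)
    scaler.update()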
diff --git a/egs/vctk/TTS/vits/transform.py b/egs/vctk/TTS/vits/transform.py
new file mode 120000
index 000000000..962647408
--- /dev/null
+++ b/egs/vctk/TTS/vits/transform.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/transform.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/tts_datamodule.py b/egs/vctk/TTS/vits/tts_datamodule.py
new file mode 100644
index 000000000..8b2a96b09
--- /dev/null
+++ b/egs/vctk/TTS/vits/tts_datamodule.py
@@ -0,0 +1,338 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
+from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ PrecomputedFeatures,
+ SimpleCutSampler,
+ SpecAugment,
+ SpeechSynthesisDataset,
+)
+from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
+ AudioSamples,
+ OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class VctkTtsDataModule:
+ """
+    DataModule for TTS experiments on VCTK.
+    It assumes there is always one train and one valid dataloader,
+    but there can be multiple test dataloaders.
+
+    It contains the common data pipeline modules used in TTS
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - on-the-fly feature extraction
+
+    This class can be adapted for other corpora used in TTS tasks.
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="TTS data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/spectrogram"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--speakers",
+ type=Path,
+ default=Path("data/speakers.txt"),
+ help="Path to speakers.txt file.",
+ )
+ group.add_argument(
+ "--max-duration",
+            type=float,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, each batch will have the "
+ "field: batch['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=8,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ logging.info("About to create train dataset")
+ train = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ sampling_rate = 22050
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+                frame_length=1024 / sampling_rate,  # (in seconds)
+                frame_shift=256 / sampling_rate,  # (in seconds)
+ use_fft_mag=True,
+ )
+ train = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ sampling_rate = 22050
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+                frame_length=1024 / sampling_rate,  # (in seconds)
+                frame_shift=256 / sampling_rate,  # (in seconds)
+ use_fft_mag=True,
+ )
+ validate = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create valid dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.info("About to create test dataset")
+ if self.args.on_the_fly_feats:
+ sampling_rate = 22050
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+                frame_length=1024 / sampling_rate,  # (in seconds)
+                frame_shift=256 / sampling_rate,  # (in seconds)
+ use_fft_mag=True,
+ )
+ test = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ test = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ test_sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=test_sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info("About to get train cuts")
+ return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz")
+
+ @lru_cache()
+ def valid_cuts(self) -> CutSet:
+ logging.info("About to get validation cuts")
+ return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz")
+
+ @lru_cache()
+ def test_cuts(self) -> CutSet:
+ logging.info("About to get test cuts")
+ return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz")
+
+ @lru_cache()
+ def speakers(self) -> Dict[str, int]:
+ logging.info("About to get speakers")
+ with open(self.args.speakers) as f:
+ speakers = {line.strip(): i for i, line in enumerate(f)}
+ return speakers
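
A short usage sketch (a hypothetical driver, not part of the patch) of how the data module above is wired up from command-line arguments, matching what train.py and infer.py do:

import argparse
from tts_datamodule import VctkTtsDataModule  # assumes egs/vctk/TTS/vits is on sys.path

parser = argparse.ArgumentParser()
VctkTtsDataModule.add_arguments(parser)
args = parser.parse_args(["--manifest-dir", "data/spectrogram", "--max-duration", "200"])

vctk = VctkTtsDataModule(args)
speaker_map = vctk.speakers()  # {speaker name: integer id}, read from data/speakers.txt
train_dl = vctk.train_dataloaders(vctk.train_cuts())
valid_dl = vctk.valid_dataloaders(vctk.valid_cuts())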
diff --git a/egs/vctk/TTS/vits/utils.py b/egs/vctk/TTS/vits/utils.py
new file mode 120000
index 000000000..085e764b4
--- /dev/null
+++ b/egs/vctk/TTS/vits/utils.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/utils.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/vits.py b/egs/vctk/TTS/vits/vits.py
new file mode 120000
index 000000000..1f58cf6fe
--- /dev/null
+++ b/egs/vctk/TTS/vits/vits.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/vits.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/vits/wavenet.py b/egs/vctk/TTS/vits/wavenet.py
new file mode 120000
index 000000000..28f0a78ee
--- /dev/null
+++ b/egs/vctk/TTS/vits/wavenet.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/vits/wavenet.py
\ No newline at end of file
diff --git a/requirements-tts.txt b/requirements-tts.txt
new file mode 100644
index 000000000..c30e23d54
--- /dev/null
+++ b/requirements-tts.txt
@@ -0,0 +1,6 @@
+# for TTS recipes
+matplotlib==3.8.2
+cython==3.0.6
+numba==0.58.1
+g2p_en==2.1.0
+espnet_tts_frontend==0.0.3
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 9502fcbd2..a1a46ae64 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,5 @@ tensorboard
typeguard
dill
black==22.3.0
+onnx==1.15.0
+onnxruntime==1.16.3
\ No newline at end of file
From b87ed26c09e9f5bb29174dd01f13670fb6124583 Mon Sep 17 00:00:00 2001
From: Yifan Yang <64255737+yfyeung@users.noreply.github.com>
Date: Wed, 6 Dec 2023 14:33:45 +0800
Subject: [PATCH 011/123] Normalize dockerfile (#1400)
---
docker/torch1.12.1-cuda11.3.dockerfile | 3 +--
docker/torch1.13.0-cuda11.6.dockerfile | 3 +--
docker/torch1.9.0-cuda10.2.dockerfile | 3 +--
docker/torch2.0.0-cuda11.7.dockerfile | 3 +--
docker/torch2.1.0-cuda11.8.dockerfile | 3 +--
docker/torch2.1.0-cuda12.1.dockerfile | 3 +--
6 files changed, 6 insertions(+), 12 deletions(-)
diff --git a/docker/torch1.12.1-cuda11.3.dockerfile b/docker/torch1.12.1-cuda11.3.dockerfile
index ed746abe3..deb5715cc 100644
--- a/docker/torch1.12.1-cuda11.3.dockerfile
+++ b/docker/torch1.12.1-cuda11.3.dockerfile
@@ -18,7 +18,7 @@ RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
- libssl-dev \
+ libssl-dev \
autoconf \
automake \
bzip2 \
@@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
- \
kaldi_native_io \
kaldialign \
kaldifst \
diff --git a/docker/torch1.13.0-cuda11.6.dockerfile b/docker/torch1.13.0-cuda11.6.dockerfile
index 9657866e5..afc6c1b84 100644
--- a/docker/torch1.13.0-cuda11.6.dockerfile
+++ b/docker/torch1.13.0-cuda11.6.dockerfile
@@ -18,7 +18,7 @@ RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
- libssl-dev \
+ libssl-dev \
autoconf \
automake \
bzip2 \
@@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
- \
kaldi_native_io \
kaldialign \
kaldifst \
diff --git a/docker/torch1.9.0-cuda10.2.dockerfile b/docker/torch1.9.0-cuda10.2.dockerfile
index a92af7ad0..9ff225b54 100644
--- a/docker/torch1.9.0-cuda10.2.dockerfile
+++ b/docker/torch1.9.0-cuda10.2.dockerfile
@@ -25,7 +25,7 @@ RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
- libssl-dev \
+ libssl-dev \
autoconf \
automake \
bzip2 \
@@ -58,7 +58,6 @@ RUN pip uninstall -y tqdm && \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
- \
kaldi_native_io \
kaldialign \
kaldifst \
diff --git a/docker/torch2.0.0-cuda11.7.dockerfile b/docker/torch2.0.0-cuda11.7.dockerfile
index 07296e6f0..db8076560 100644
--- a/docker/torch2.0.0-cuda11.7.dockerfile
+++ b/docker/torch2.0.0-cuda11.7.dockerfile
@@ -18,7 +18,7 @@ RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
- libssl-dev \
+ libssl-dev \
autoconf \
automake \
bzip2 \
@@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
- \
kaldi_native_io \
kaldialign \
kaldifst \
diff --git a/docker/torch2.1.0-cuda11.8.dockerfile b/docker/torch2.1.0-cuda11.8.dockerfile
index e500e9a6a..b006b0d96 100644
--- a/docker/torch2.1.0-cuda11.8.dockerfile
+++ b/docker/torch2.1.0-cuda11.8.dockerfile
@@ -18,7 +18,7 @@ RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
- libssl-dev \
+ libssl-dev \
autoconf \
automake \
bzip2 \
@@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
- \
kaldi_native_io \
kaldialign \
kaldifst \
diff --git a/docker/torch2.1.0-cuda12.1.dockerfile b/docker/torch2.1.0-cuda12.1.dockerfile
index c3f12323e..1b078dc22 100644
--- a/docker/torch2.1.0-cuda12.1.dockerfile
+++ b/docker/torch2.1.0-cuda12.1.dockerfile
@@ -18,7 +18,7 @@ RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
- libssl-dev \
+ libssl-dev \
autoconf \
automake \
bzip2 \
@@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
- \
kaldi_native_io \
kaldialign \
kaldifst \
From bda72f86fffe591d334630da522dba4cf5c66341 Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Fri, 8 Dec 2023 06:32:40 +0800
Subject: [PATCH 012/123] minor adjustments to the VITS recipes for onnx
runtime (#1405)
---
egs/ljspeech/TTS/vits/export-onnx.py | 4 ++--
egs/ljspeech/TTS/vits/test_onnx.py | 4 ++--
egs/vctk/TTS/vits/export-onnx.py | 4 ++--
egs/vctk/TTS/vits/test_onnx.py | 6 +++---
4 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py
index 2068adeea..bca6aec99 100755
--- a/egs/ljspeech/TTS/vits/export-onnx.py
+++ b/egs/ljspeech/TTS/vits/export-onnx.py
@@ -176,7 +176,7 @@ def export_model_onnx(
torch.onnx.export(
model,
- (tokens, tokens_lens, noise_scale, noise_scale_dur, alpha),
+ (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur),
model_filename,
verbose=False,
opset_version=opset_version,
@@ -184,8 +184,8 @@ def export_model_onnx(
"tokens",
"tokens_lens",
"noise_scale",
- "noise_scale_dur",
"alpha",
+ "noise_scale_dur",
],
output_names=["audio"],
dynamic_axes={
diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py
index 686fee2a0..fcbc1d663 100755
--- a/egs/ljspeech/TTS/vits/test_onnx.py
+++ b/egs/ljspeech/TTS/vits/test_onnx.py
@@ -92,8 +92,8 @@ class OnnxModel:
self.model.get_inputs()[0].name: tokens.numpy(),
self.model.get_inputs()[1].name: tokens_lens.numpy(),
self.model.get_inputs()[2].name: noise_scale.numpy(),
- self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
- self.model.get_inputs()[4].name: alpha.numpy(),
+ self.model.get_inputs()[3].name: alpha.numpy(),
+ self.model.get_inputs()[4].name: noise_scale_dur.numpy(),
},
)[0]
return torch.from_numpy(out)
diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py
index 7c9664cc1..cfc74fd0a 100755
--- a/egs/vctk/TTS/vits/export-onnx.py
+++ b/egs/vctk/TTS/vits/export-onnx.py
@@ -187,7 +187,7 @@ def export_model_onnx(
torch.onnx.export(
model,
- (tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha),
+ (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur, speaker),
model_filename,
verbose=False,
opset_version=opset_version,
@@ -195,9 +195,9 @@ def export_model_onnx(
"tokens",
"tokens_lens",
"noise_scale",
+ "alpha",
"noise_scale_dur",
"speaker",
- "alpha",
],
output_names=["audio"],
dynamic_axes={
diff --git a/egs/vctk/TTS/vits/test_onnx.py b/egs/vctk/TTS/vits/test_onnx.py
index 757e67fc1..d85c0a27b 100755
--- a/egs/vctk/TTS/vits/test_onnx.py
+++ b/egs/vctk/TTS/vits/test_onnx.py
@@ -101,9 +101,9 @@ class OnnxModel:
self.model.get_inputs()[0].name: tokens.numpy(),
self.model.get_inputs()[1].name: tokens_lens.numpy(),
self.model.get_inputs()[2].name: noise_scale.numpy(),
- self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
- self.model.get_inputs()[4].name: speaker.numpy(),
- self.model.get_inputs()[5].name: alpha.numpy(),
+ self.model.get_inputs()[3].name: alpha.numpy(),
+ self.model.get_inputs()[4].name: noise_scale_dur.numpy(),
+ self.model.get_inputs()[5].name: speaker.numpy(),
},
)[0]
return torch.from_numpy(out)
From e9ec827de76856e38af7a884b878ca3a84f64bb9 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Fri, 8 Dec 2023 14:29:24 +0800
Subject: [PATCH 013/123] Rename zipformer2 to zipformer_for_ncnn_export_only
to avoid confusion. (#1407)
---
.../do_not_use_it_directly.py | 2 +-
.../ASR/pruned_transducer_stateless7_streaming/zipformer2.py | 1 -
.../zipformer_for_ncnn_export_only.py | 1 +
.../do_not_use_it_directly.py | 2 +-
.../ASR/pruned_transducer_stateless7_streaming/zipformer2.py | 1 -
.../zipformer_for_ncnn_export_only.py | 1 +
.../do_not_use_it_directly.py | 2 +-
.../ASR/pruned_transducer_stateless7_streaming/zipformer2.py | 1 -
.../zipformer_for_ncnn_export_only.py | 1 +
.../do_not_use_it_directly.py | 2 +-
.../{zipformer2.py => zipformer_for_ncnn_export_only.py} | 0
.../pruned_transducer_stateless7_streaming_multi/zipformer2.py | 1 -
12 files changed, 7 insertions(+), 8 deletions(-)
delete mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py
delete mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py
delete mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py
rename egs/librispeech/ASR/pruned_transducer_stateless7_streaming/{zipformer2.py => zipformer_for_ncnn_export_only.py} (100%)
delete mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index 3c13c19c6..0fba3b58f 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -66,7 +66,7 @@ from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
-from zipformer2 import Zipformer
+from zipformer_for_ncnn_export_only import Zipformer
from icefall import diagnostics
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
deleted file mode 120000
index 12dbda888..000000000
--- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py
new file mode 120000
index 000000000..d301e1f9b
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index 61a3f27db..0426bc9a3 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -67,7 +67,7 @@ from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
-from zipformer2 import Zipformer
+from zipformer_for_ncnn_export_only import Zipformer
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
deleted file mode 120000
index 12dbda888..000000000
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
\ No newline at end of file
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py
new file mode 120000
index 000000000..d301e1f9b
--- /dev/null
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py
\ No newline at end of file
diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index acde72d80..685f6ece6 100755
--- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
+++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -70,7 +70,7 @@ from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
-from zipformer2 import Zipformer
+from zipformer_for_ncnn_export_only import Zipformer
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
deleted file mode 120000
index 12dbda888..000000000
--- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
\ No newline at end of file
diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py
new file mode 120000
index 000000000..d301e1f9b
--- /dev/null
+++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index cd26db6f3..9a6d2155b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -69,7 +69,7 @@ from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
-from zipformer2 import Zipformer
+from zipformer_for_ncnn_export_only import Zipformer
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py
similarity index 100%
rename from egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
rename to egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py
deleted file mode 120000
index d3625f478..000000000
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py
+++ /dev/null
@@ -1 +0,0 @@
-../pruned_transducer_stateless7_streaming/zipformer2.py
\ No newline at end of file
From df56aff31ea9b95aa3d9672398a0771dcb8eacc5 Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Fri, 8 Dec 2023 21:11:31 +0800
Subject: [PATCH 014/123] minor fixes to the vits onnx export scripts
(#1408)
---
egs/ljspeech/TTS/vits/export-onnx.py | 2 +-
egs/vctk/TTS/vits/export-onnx.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py
index bca6aec99..36a9de27f 100755
--- a/egs/ljspeech/TTS/vits/export-onnx.py
+++ b/egs/ljspeech/TTS/vits/export-onnx.py
@@ -115,8 +115,8 @@ class OnnxModel(nn.Module):
tokens: torch.Tensor,
tokens_lens: torch.Tensor,
noise_scale: float = 0.667,
- noise_scale_dur: float = 0.8,
alpha: float = 1.0,
+ noise_scale_dur: float = 0.8,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of VITS.inference_batch
diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py
index cfc74fd0a..667ac284b 100755
--- a/egs/vctk/TTS/vits/export-onnx.py
+++ b/egs/vctk/TTS/vits/export-onnx.py
@@ -121,9 +121,9 @@ class OnnxModel(nn.Module):
tokens: torch.Tensor,
tokens_lens: torch.Tensor,
noise_scale: float = 0.667,
+ alpha: float = 1.0,
noise_scale_dur: float = 0.8,
speaker: int = 20,
- alpha: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of VITS.inference_batch
From b0f70c9d042da734a5df988b98412e5def6b8072 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sun, 10 Dec 2023 11:38:39 +0800
Subject: [PATCH 015/123] Fix torch.jit.script() export for
pruned_transducer_stateless2 (#1410)
---
egs/librispeech/ASR/pruned_transducer_stateless2/export.py | 2 ++
egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py | 1 +
.../ASR/pruned_transducer_stateless2/scaling_converter.py | 1 +
3 files changed, 4 insertions(+)
create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py
create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless2/scaling_converter.py
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
index e02afa892..e2db98f73 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
@@ -49,6 +49,7 @@ from pathlib import Path
import k2
import torch
+from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
@@ -198,6 +199,7 @@ def main():
model.eval()
if params.jit:
+ convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py b/egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py
new file mode 120000
index 000000000..4f377cd01
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py
@@ -0,0 +1 @@
+../lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling_converter.py
new file mode 120000
index 000000000..3b667058d
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling_converter.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
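Note: a toy, self-contained sketch of the idea behind convert_scaled_to_non_scaled(); ToyScaledLinear and toy_convert are hypothetical stand-ins, not icefall code, and only illustrate swapping custom modules for plain torch ones before torch.jit.script():

import torch
import torch.nn as nn

# Hypothetical stand-in for a Scaled* layer (the real ScaledLinear carries extra
# training-time state that the torchscript export does not need).
class ToyScaledLinear(nn.Linear):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x) * 1.0  # pretend a learned scale is applied here

def toy_convert(model: nn.Module) -> None:
    # Recursively replace ToyScaledLinear children with plain nn.Linear, in place.
    for name, child in model.named_children():
        if isinstance(child, ToyScaledLinear):
            plain = nn.Linear(child.in_features, child.out_features, bias=child.bias is not None)
            plain.load_state_dict(child.state_dict())
            setattr(model, name, plain)
        else:
            toy_convert(child)

m = nn.Sequential(ToyScaledLinear(4, 4), nn.ReLU())
toy_convert(m)
scripted = torch.jit.script(m)  # scripting now only sees standard modules
scripted.save("toy_cpu_jit.pt")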
From 20a82c9abf6664b645c28dd6b3d629cf85e40231 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 12 Dec 2023 18:13:26 +0800
Subject: [PATCH 016/123] first commit (#1411)
---
.github/scripts/multi-zh-hans.sh | 88 +++++++++++++++++++
.github/workflows/multi-zh-hans.yml | 84 ++++++++++++++++++
.../ASR/zipformer/export-onnx-streaming.py | 14 ++-
3 files changed, 182 insertions(+), 4 deletions(-)
create mode 100755 .github/scripts/multi-zh-hans.sh
create mode 100644 .github/workflows/multi-zh-hans.yml
diff --git a/.github/scripts/multi-zh-hans.sh b/.github/scripts/multi-zh-hans.sh
new file mode 100755
index 000000000..4ede7f43e
--- /dev/null
+++ b/.github/scripts/multi-zh-hans.sh
@@ -0,0 +1,88 @@
+#!/usr/bin/env bash
+
+set -ex
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "pwd: $PWD"
+
+cd egs/multi_zh-hans/ASR
+
+repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05
+log "Downloading pre-trained model from $repo_url"
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+cd exp/
+git lfs pull --include pretrained.pt
+rm -fv epoch-20.pt
+rm -fv *.onnx
+ln -s pretrained.pt epoch-20.pt
+cd ../data/lang_bpe_2000
+git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model
+popd
+
+log "----------------------------------------"
+log "Export streaming ONNX transducer models "
+log "----------------------------------------"
+
+./zipformer/export-onnx-streaming.py \
+ --exp-dir $repo/exp \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ --causal 1 \
+ --avg 1 \
+ --epoch 20 \
+ --use-averaged-model 0 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --use-ctc 0
+
+ls -lh $repo/exp
+
+log "------------------------------------------------------------"
+log "Test export streaming ONNX transducer models (Python code) "
+log "------------------------------------------------------------"
+
+log "test fp32"
+./zipformer/onnx_pretrained-streaming.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.onnx \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ $repo/test_wavs/DEV_T0000000000.wav
+
+log "test int8"
+./zipformer/onnx_pretrained-streaming.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ $repo/test_wavs/DEV_T0000000000.wav
+
+log "Upload models to huggingface"
+git config --global user.name "k2-fsa"
+git config --global user.email "xxx@gmail.com"
+
+url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-multi-zh-hans-2023-12-12
+GIT_LFS_SKIP_SMUDGE=1 git clone $url
+dst=$(basename $url)
+cp -v $repo/exp/*.onnx $dst
+cp -v $repo/data/lang_bpe_2000/tokens.txt $dst
+mkdir -p $dst/test_wavs
+cp -v $repo/test_wavs/*.wav $dst/test_wavs
+cd $dst
+git lfs track "*.onnx"
+git add .
+git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true
+
+log "Upload models to https://github.com/k2-fsa/sherpa-onnx"
+rm -rf .git
+rm -fv .gitattributes
+cd ..
+tar cjfv $dst.tar.bz2 $dst
+mv -v $dst.tar.bz2 ../../../
diff --git a/.github/workflows/multi-zh-hans.yml b/.github/workflows/multi-zh-hans.yml
new file mode 100644
index 000000000..439300b5f
--- /dev/null
+++ b/.github/workflows/multi-zh-hans.yml
@@ -0,0 +1,84 @@
+name: run-multi-zh-hans
+
+on:
+ push:
+ branches:
+ - master
+ - upload-ctc-model
+
+ pull_request:
+ branches:
+ - master
+
+ workflow_dispatch:
+
+concurrency:
+ group: run-multi-zh-hans-${{ github.ref }}
+ cancel-in-progress: true
+
+permissions:
+ contents: write
+
+jobs:
+ multi-zh-hans:
+ runs-on: ${{ matrix.os }}
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ubuntu-latest]
+ python-version: [3.8]
+
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+ cache-dependency-path: '**/requirements-ci.txt'
+
+ - name: Install Python dependencies
+ run: |
+ grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+ pip uninstall -y protobuf
+ pip install --no-binary protobuf protobuf==3.20.*
+
+ - name: Cache kaldifeat
+ id: my-cache
+ uses: actions/cache@v2
+ with:
+ path: |
+ ~/tmp/kaldifeat
+ key: cache-tmp-${{ matrix.python-version }}-2023-05-22
+
+ - name: Install kaldifeat
+ if: steps.my-cache.outputs.cache-hit != 'true'
+ shell: bash
+ run: |
+ .github/scripts/install-kaldifeat.sh
+
+ - name: export-model
+ shell: bash
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ run: |
+ sudo apt-get -qq install git-lfs tree
+ export PYTHONPATH=$PWD:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+ export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+ .github/scripts/multi-zh-hans.sh
+ ls -lh
+
+ - name: upload model to https://github.com/k2-fsa/sherpa-onnx
+ uses: svenstaro/upload-release-action@v2
+ with:
+ file_glob: true
+ file: ./*.tar.bz2
+ overwrite: true
+ repo_name: k2-fsa/sherpa-onnx
+ repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
+ tag: asr-models
diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
index e2c7d7d95..6bc9b1858 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
@@ -614,7 +614,9 @@ def main():
)
logging.info(f"averaging {filenames}")
model.to(device)
- model.load_state_dict(average_checkpoints(filenames, device=device))
+ model.load_state_dict(
+ average_checkpoints(filenames, device=device), strict=False
+ )
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
@@ -625,7 +627,9 @@ def main():
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
- model.load_state_dict(average_checkpoints(filenames, device=device))
+ model.load_state_dict(
+ average_checkpoints(filenames, device=device), strict=False
+ )
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
@@ -653,7 +657,8 @@ def main():
filename_start=filename_start,
filename_end=filename_end,
device=device,
- )
+ ),
+ strict=False,
)
else:
assert params.avg > 0, params.avg
@@ -671,7 +676,8 @@ def main():
filename_start=filename_start,
filename_end=filename_end,
device=device,
- )
+ ),
+ strict=False,
)
model.to("cpu")
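Note: the strict=False added above lets load_state_dict() report key mismatches instead of raising, for example when a checkpoint carries a head the exported model was not built with (this rationale is an assumption based on the diff). A self-contained toy illustration:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
# The checkpoint has one extra entry the model does not define.
ckpt = {
    "weight": torch.zeros(4, 4),
    "bias": torch.zeros(4),
    "ctc_output.weight": torch.zeros(4, 4),
}
result = model.load_state_dict(ckpt, strict=False)
print(result.missing_keys)     # []
print(result.unexpected_keys)  # ['ctc_output.weight'] is reported instead of raising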
From 9e9fe7954d13c5ee8f10e990b358cf8c752a24e6 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 12 Dec 2023 18:57:04 +0800
Subject: [PATCH 017/123] Upload gigaspeech zipformer models in CI (#1412)
---
.github/scripts/multi-zh-hans.sh | 3 +-
.../run-gigaspeech-zipformer-2023-10-17.sh | 74 +++++++++++++++++--
.github/workflows/multi-zh-hans.yml | 5 --
.../run-gigaspeech-zipformer-2023-10-17.yml | 14 ++++
4 files changed, 85 insertions(+), 11 deletions(-)
diff --git a/.github/scripts/multi-zh-hans.sh b/.github/scripts/multi-zh-hans.sh
index 4ede7f43e..2dd1bce42 100755
--- a/.github/scripts/multi-zh-hans.sh
+++ b/.github/scripts/multi-zh-hans.sh
@@ -45,7 +45,7 @@ log "----------------------------------------"
ls -lh $repo/exp
log "------------------------------------------------------------"
-log "Test export streaming ONNX transducer models (Python code) "
+log "Test exported streaming ONNX transducer models (Python code)"
log "------------------------------------------------------------"
log "test fp32"
@@ -73,6 +73,7 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $url
dst=$(basename $url)
cp -v $repo/exp/*.onnx $dst
cp -v $repo/data/lang_bpe_2000/tokens.txt $dst
+cp -v $repo/data/lang_bpe_2000/bpe.model $dst
mkdir -p $dst/test_wavs
cp -v $repo/test_wavs/*.wav $dst/test_wavs
cd $dst
diff --git a/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh b/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh
index 6bb0b9ebc..329896ef6 100755
--- a/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh
+++ b/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh
@@ -26,16 +26,80 @@ git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/jit_script.pt"
git lfs pull --include "exp/pretrained.pt"
-ln -s pretrained.pt epoch-99.pt
-ls -lh *.pt
+rm epoch-30.pt
+ln -s pretrained.pt epoch-30.pt
+rm *.onnx
+ls -lh
popd
+log "----------------------------------------"
+log "Export ONNX transducer models "
+log "----------------------------------------"
+
+./zipformer/export-onnx.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --use-averaged-model 0 \
+ --epoch 30 \
+ --avg 1 \
+ --exp-dir $repo/exp
+
+ls -lh $repo/exp
+
+log "------------------------------------------------------------"
+log "Test exported ONNX transducer models (Python code) "
+log "------------------------------------------------------------"
+
+log "test fp32"
+./zipformer/onnx_pretrained.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-30-avg-1.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-30-avg-1.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-30-avg-1.onnx \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+log "test int8"
+./zipformer/onnx_pretrained.py \
+ --encoder-model-filename $repo/exp/encoder-epoch-30-avg-1.int8.onnx \
+ --decoder-model-filename $repo/exp/decoder-epoch-30-avg-1.onnx \
+ --joiner-model-filename $repo/exp/joiner-epoch-30-avg-1.int8.onnx \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+log "Upload models to huggingface"
+git config --global user.name "k2-fsa"
+git config --global user.email "xxx@gmail.com"
+
+url=https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-gigaspeech-2023-12-12
+GIT_LFS_SKIP_SMUDGE=1 git clone $url
+dst=$(basename $url)
+cp -v $repo/exp/*.onnx $dst
+cp -v $repo/data/lang_bpe_500/tokens.txt $dst
+cp -v $repo/data/lang_bpe_500/bpe.model $dst
+mkdir -p $dst/test_wavs
+cp -v $repo/test_wavs/*.wav $dst/test_wavs
+cd $dst
+git lfs track "*.onnx"
+git add .
+git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true
+
+log "Upload models to https://github.com/k2-fsa/sherpa-onnx"
+rm -rf .git
+rm -fv .gitattributes
+cd ..
+tar cjfv $dst.tar.bz2 $dst
+ls -lh
+mv -v $dst.tar.bz2 ../../../
+
log "Export to torchscript model"
./zipformer/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--tokens $repo/data/lang_bpe_500/tokens.txt \
- --epoch 99 \
+ --epoch 30 \
--avg 1 \
--jit 1
@@ -67,7 +131,7 @@ echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
mkdir -p zipformer/exp
- ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt
+ ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-30.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
ls -lh data
@@ -83,7 +147,7 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
./zipformer/decode.py \
--decoding-method $method \
- --epoch 999 \
+ --epoch 30 \
--avg 1 \
--use-averaged-model 0 \
--max-duration $max_duration \
diff --git a/.github/workflows/multi-zh-hans.yml b/.github/workflows/multi-zh-hans.yml
index 439300b5f..9081047de 100644
--- a/.github/workflows/multi-zh-hans.yml
+++ b/.github/workflows/multi-zh-hans.yml
@@ -2,11 +2,6 @@ name: run-multi-zh-hans
on:
push:
- branches:
- - master
- - upload-ctc-model
-
- pull_request:
branches:
- master
diff --git a/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml b/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml
index 7572f4b5f..87090e310 100644
--- a/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml
+++ b/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml
@@ -21,6 +21,7 @@ on:
push:
branches:
- master
+
pull_request:
types: [labeled]
@@ -33,6 +34,8 @@ on:
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
+ workflow_dispatch:
+
concurrency:
group: run_gigaspeech_2023_10_17_zipformer-${{ github.ref }}
cancel-in-progress: true
@@ -85,6 +88,7 @@ jobs:
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
mkdir -p egs/gigaspeech/ASR/data
ln -sfv ~/tmp/fbank-libri egs/gigaspeech/ASR/data/fbank
@@ -97,6 +101,16 @@ jobs:
.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh
+ - name: upload model to https://github.com/k2-fsa/sherpa-onnx
+ uses: svenstaro/upload-release-action@v2
+ with:
+ file_glob: true
+ file: ./*.tar.bz2
+ overwrite: true
+ repo_name: k2-fsa/sherpa-onnx
+ repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
+ tag: asr-models
+
- name: Display decoding results for gigaspeech zipformer
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
shell: bash
From d0da509055468f600808a3d7cd8885649d12b96d Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Wed, 13 Dec 2023 10:33:28 +0800
Subject: [PATCH 018/123] Support ONNX export for Streaming CTC Encoder (#1413)
* Create export-onnx-streaming-ctc.py
* doc_str updated
Co-authored-by: Fangjun Kuang
---------
Co-authored-by: Fangjun Kuang
---
.../zipformer/export-onnx-streaming-ctc.py | 570 ++++++++++++++++++
1 file changed, 570 insertions(+)
create mode 100755 egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
new file mode 100755
index 000000000..3c0f74005
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
@@ -0,0 +1,570 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang, Zengrui Jin)
+# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
+
+"""
+This script exports a CTC model from PyTorch to ONNX.
+
+
+1. Download the pre-trained streaming model with CTC head
+
+2. Export the model to ONNX
+
+./zipformer/export-onnx-streaming-ctc.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir $repo/exp \
+ --num-encoder-layers "2,2,3,4,3,2" \
+ --downsampling-factor "1,2,4,8,4,2" \
+ --feedforward-dim "512,768,1024,1536,1024,768" \
+ --num-heads "4,4,4,8,4,4" \
+ --encoder-dim "192,256,384,512,384,256" \
+ --query-head-dim 32 \
+ --value-head-dim 12 \
+ --pos-head-dim 4 \
+ --pos-dim 48 \
+ --encoder-unmasked-dim "192,192,256,256,256,192" \
+ --cnn-module-kernel "31,31,15,15,15,31" \
+ --decoder-dim 512 \
+ --joiner-dim 512 \
+ --causal True \
+ --chunk-size 16 \
+ --left-context-frames 64 \
+ --use-ctc 1
+
+The --chunk-size in training is "16,32,64,-1", so we select one of them
+(excluding -1) during streaming export. The same applies to `--left-context-frames`,
+whose value is "64,128,256,-1".
+
+It will generate the following file inside $repo/exp:
+
+ - ctc-epoch-99-avg-1-chunk-16-left-64.onnx
+
+See ./onnx_pretrained-streaming-ctc.py for how to use the exported ONNX models.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import k2
+import onnx
+import torch
+import torch.nn as nn
+from onnxruntime.quantization import QuantType, quantize_dynamic
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_model, get_params
+from zipformer import Zipformer2
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import num_tokens, str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="""It specifies the checkpoint to use for averaging.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/lang_bpe_500/tokens.txt",
+ help="Path to the tokens.txt",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, str]):
+ """Add meta data to an ONNX model. It is changed in-place.
+
+ Args:
+ filename:
+ Filename of the ONNX model to be changed.
+ meta_data:
+ Key-value pairs.
+ """
+ model = onnx.load(filename)
+ for key, value in meta_data.items():
+ meta = model.metadata_props.add()
+ meta.key = key
+ meta.value = value
+
+ onnx.save(model, filename)
+
+
+class OnnxModel(nn.Module):
+ """A wrapper for Zipformer and the ctc_head"""
+
+ def __init__(
+ self,
+ encoder: Zipformer2,
+ encoder_embed: nn.Module,
+ ctc_output: nn.Module,
+ ):
+ """
+ Args:
+ encoder:
+ A Zipformer encoder.
+ encoder_embed:
+ The subsampling (embedding) module applied before the encoder.
+ ctc_output:
+ The ctc head.
+ """
+ super().__init__()
+ self.encoder = encoder
+ self.encoder_embed = encoder_embed
+ self.ctc_output = ctc_output
+ self.chunk_size = encoder.chunk_size[0]
+ self.left_context_len = encoder.left_context_frames[0]
+ self.pad_length = 7 + 2 * 3
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ states: List[torch.Tensor],
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ N = x.size(0)
+ T = self.chunk_size * 2 + self.pad_length
+ x_lens = torch.tensor([T] * N, device=x.device)
+ left_context_len = self.left_context_len
+
+ cached_embed_left_pad = states[-2]
+ x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
+ x=x,
+ x_lens=x_lens,
+ cached_left_pad=cached_embed_left_pad,
+ )
+ assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)
+
+ src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool)
+
+ # processed_mask is used to mask out initial states
+ processed_mask = torch.arange(left_context_len, device=x.device).expand(
+ x.size(0), left_context_len
+ )
+ processed_lens = states[-1] # (batch,)
+ # (batch, left_context_size)
+ processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
+ # Update processed lengths
+ new_processed_lens = processed_lens + x_lens
+ # (batch, left_context_size + chunk_size)
+ src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
+
+ x = x.permute(1, 0, 2)
+ encoder_states = states[:-2]
+ logging.info(f"len_encoder_states={len(encoder_states)}")
+ (
+ encoder_out,
+ encoder_out_lens,
+ new_encoder_states,
+ ) = self.encoder.streaming_forward(
+ x=x,
+ x_lens=x_lens,
+ states=encoder_states,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ encoder_out = encoder_out.permute(1, 0, 2)
+ encoder_out = self.ctc_output(encoder_out)
+ # Now encoder_out is of shape (N, T, ctc_output_dim)
+
+ new_states = new_encoder_states + [
+ new_cached_embed_left_pad,
+ new_processed_lens,
+ ]
+
+ return encoder_out, new_states
+
+ def get_init_states(
+ self,
+ batch_size: int = 1,
+ device: torch.device = torch.device("cpu"),
+ ) -> List[torch.Tensor]:
+ """
+ Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+ is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ states[-2] is the cached left padding for the ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ states[-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at a 50 Hz frame rate, after encoder_embed) for each sample in the batch.
+ """
+ states = self.encoder.get_init_states(batch_size, device)
+
+ embed_states = self.encoder_embed.get_init_states(batch_size, device)
+
+ states.append(embed_states)
+
+ processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
+ states.append(processed_lens)
+
+ return states
+
+
+def export_streaming_ctc_model_onnx(
+ model: OnnxModel,
+ encoder_filename: str,
+ opset_version: int = 11,
+) -> None:
+ model.encoder.__class__.forward = model.encoder.__class__.streaming_forward
+
+ decode_chunk_len = model.chunk_size * 2
+ # encoder_embed subsamples the features: T input frames become (T - 7) // 2
+ # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
+ T = decode_chunk_len + model.pad_length
+
+ x = torch.rand(1, T, 80, dtype=torch.float32)
+ init_state = model.get_init_states()
+ num_encoders = len(model.encoder.encoder_dim)
+ logging.info(f"num_encoders: {num_encoders}")
+ logging.info(f"len(init_state): {len(init_state)}")
+
+ inputs = {}
+ input_names = ["x"]
+
+ outputs = {}
+ output_names = ["log_probs"]
+
+ def build_inputs_outputs(tensors, i):
+ assert len(tensors) == 6, len(tensors)
+
+ # (downsample_left, batch_size, key_dim)
+ name = f"cached_key_{i}"
+ logging.info(f"{name}.shape: {tensors[0].shape}")
+ inputs[name] = {1: "N"}
+ outputs[f"new_{name}"] = {1: "N"}
+ input_names.append(name)
+ output_names.append(f"new_{name}")
+
+ # (1, batch_size, downsample_left, nonlin_attn_head_dim)
+ name = f"cached_nonlin_attn_{i}"
+ logging.info(f"{name}.shape: {tensors[1].shape}")
+ inputs[name] = {1: "N"}
+ outputs[f"new_{name}"] = {1: "N"}
+ input_names.append(name)
+ output_names.append(f"new_{name}")
+
+ # (downsample_left, batch_size, value_dim)
+ name = f"cached_val1_{i}"
+ logging.info(f"{name}.shape: {tensors[2].shape}")
+ inputs[name] = {1: "N"}
+ outputs[f"new_{name}"] = {1: "N"}
+ input_names.append(name)
+ output_names.append(f"new_{name}")
+
+ # (downsample_left, batch_size, value_dim)
+ name = f"cached_val2_{i}"
+ logging.info(f"{name}.shape: {tensors[3].shape}")
+ inputs[name] = {1: "N"}
+ outputs[f"new_{name}"] = {1: "N"}
+ input_names.append(name)
+ output_names.append(f"new_{name}")
+
+ # (batch_size, embed_dim, conv_left_pad)
+ name = f"cached_conv1_{i}"
+ logging.info(f"{name}.shape: {tensors[4].shape}")
+ inputs[name] = {0: "N"}
+ outputs[f"new_{name}"] = {0: "N"}
+ input_names.append(name)
+ output_names.append(f"new_{name}")
+
+ # (batch_size, embed_dim, conv_left_pad)
+ name = f"cached_conv2_{i}"
+ logging.info(f"{name}.shape: {tensors[5].shape}")
+ inputs[name] = {0: "N"}
+ outputs[f"new_{name}"] = {0: "N"}
+ input_names.append(name)
+ output_names.append(f"new_{name}")
+
+ num_encoder_layers = ",".join(map(str, model.encoder.num_encoder_layers))
+ encoder_dims = ",".join(map(str, model.encoder.encoder_dim))
+ cnn_module_kernels = ",".join(map(str, model.encoder.cnn_module_kernel))
+ ds = model.encoder.downsampling_factor
+ left_context_len = model.left_context_len
+ left_context_len = [left_context_len // k for k in ds]
+ left_context_len = ",".join(map(str, left_context_len))
+ query_head_dims = ",".join(map(str, model.encoder.query_head_dim))
+ value_head_dims = ",".join(map(str, model.encoder.value_head_dim))
+ num_heads = ",".join(map(str, model.encoder.num_heads))
+
+ meta_data = {
+ "model_type": "zipformer2",
+ "version": "1",
+ "model_author": "k2-fsa",
+ "comment": "streaming ctc zipformer2",
+ "decode_chunk_len": str(decode_chunk_len), # 32
+ "T": str(T), # 32+7+2*3=45
+ "num_encoder_layers": num_encoder_layers,
+ "encoder_dims": encoder_dims,
+ "cnn_module_kernels": cnn_module_kernels,
+ "left_context_len": left_context_len,
+ "query_head_dims": query_head_dims,
+ "value_head_dims": value_head_dims,
+ "num_heads": num_heads,
+ }
+ logging.info(f"meta_data: {meta_data}")
+
+ for i in range(len(init_state[:-2]) // 6):
+ build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i)
+
+ # (batch_size, channels, left_pad, freq)
+ embed_states = init_state[-2]
+ name = "embed_states"
+ logging.info(f"{name}.shape: {embed_states.shape}")
+ inputs[name] = {0: "N"}
+ outputs[f"new_{name}"] = {0: "N"}
+ input_names.append(name)
+ output_names.append(f"new_{name}")
+
+ # (batch_size,)
+ processed_lens = init_state[-1]
+ name = "processed_lens"
+ logging.info(f"{name}.shape: {processed_lens.shape}")
+ inputs[name] = {0: "N"}
+ outputs[f"new_{name}"] = {0: "N"}
+ input_names.append(name)
+ output_names.append(f"new_{name}")
+
+ logging.info(inputs)
+ logging.info(outputs)
+ logging.info(input_names)
+ logging.info(output_names)
+
+ torch.onnx.export(
+ model,
+ (x, init_state),
+ encoder_filename,
+ verbose=False,
+ opset_version=opset_version,
+ input_names=input_names,
+ output_names=output_names,
+ dynamic_axes={
+ "x": {0: "N"},
+ "log_probs": {0: "N"},
+ **inputs,
+ **outputs,
+ },
+ )
+
+ add_meta_data(filename=encoder_filename, meta_data=meta_data)
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ model.to(device)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to("cpu")
+ model.eval()
+
+ convert_scaled_to_non_scaled(model, inplace=True)
+
+ model = OnnxModel(
+ encoder=model.encoder,
+ encoder_embed=model.encoder_embed,
+ ctc_output=model.ctc_output,
+ )
+
+ total_num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"total parameters: {total_num_param}")
+
+ if params.iter > 0:
+ suffix = f"iter-{params.iter}"
+ else:
+ suffix = f"epoch-{params.epoch}"
+
+ suffix += f"-avg-{params.avg}"
+ suffix += f"-chunk-{params.chunk_size}"
+ suffix += f"-left-{params.left_context_frames}"
+
+ opset_version = 13
+
+ logging.info("Exporting model")
+ model_filename = params.exp_dir / f"ctc-{suffix}.onnx"
+ export_streaming_ctc_model_onnx(
+ model,
+ model_filename,
+ opset_version=opset_version,
+ )
+ logging.info(f"Exported model to {model_filename}")
+
+ # Generate int8 quantization models
+ # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
+
+ logging.info("Generate int8 quantization models")
+
+ model_filename_int8 = params.exp_dir / f"ctc-{suffix}.int8.onnx"
+ quantize_dynamic(
+ model_input=model_filename,
+ model_output=model_filename_int8,
+ op_types_to_quantize=["MatMul"],
+ weight_type=QuantType.QInt8,
+ )
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
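Note: a short sketch of reading back, with onnxruntime, the metadata that add_meta_data() attaches to the exported model; the filename matches the example given in the doc string above:

import onnxruntime as ort

sess = ort.InferenceSession(
    "ctc-epoch-99-avg-1-chunk-16-left-64.onnx",
    providers=["CPUExecutionProvider"],
)
meta = sess.get_modelmeta().custom_metadata_map
# All values are stored as strings, e.g. decode_chunk_len="32", T="45".
print(meta["decode_chunk_len"], meta["T"], meta["num_encoder_layers"])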
From f85f0252a9f05c10e6782a693c87a1fede2f83c7 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 13 Dec 2023 17:34:12 +0800
Subject: [PATCH 019/123] Add greedy search for streaming zipformer CTC.
(#1415)
---
.github/scripts/multi-zh-hans.sh | 79 +++-
.../onnx_pretrained-streaming-ctc.py | 426 ++++++++++++++++++
.../zipformer/onnx_pretrained-streaming.py | 2 +-
.../zipformer/export-onnx-streaming-ctc.py | 1 +
.../onnx_pretrained-streaming-ctc.py | 1 +
5 files changed, 503 insertions(+), 6 deletions(-)
create mode 100755 egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py
create mode 120000 egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming-ctc.py
create mode 120000 egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming-ctc.py
diff --git a/.github/scripts/multi-zh-hans.sh b/.github/scripts/multi-zh-hans.sh
index 2dd1bce42..427d8887b 100755
--- a/.github/scripts/multi-zh-hans.sh
+++ b/.github/scripts/multi-zh-hans.sh
@@ -2,6 +2,10 @@
set -ex
+git config --global user.name "k2-fsa"
+git config --global user.email "csukuangfj@gmail.com"
+git config --global lfs.allowincompletepush true
+
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
@@ -24,9 +28,73 @@ rm -fv epoch-20.pt
rm -fv *.onnx
ln -s pretrained.pt epoch-20.pt
cd ../data/lang_bpe_2000
+ls -lh
git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model
+git lfs pull --include "*.model"
+ls -lh
popd
+log "----------------------------------------"
+log "Export streaming ONNX CTC models "
+log "----------------------------------------"
+./zipformer/export-onnx-streaming-ctc.py \
+ --exp-dir $repo/exp \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ --causal 1 \
+ --avg 1 \
+ --epoch 20 \
+ --use-averaged-model 0 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --use-ctc 1
+
+ls -lh $repo/exp/
+
+log "------------------------------------------------------------"
+log "Test exported streaming ONNX CTC models (greedy search) "
+log "------------------------------------------------------------"
+
+test_wavs=(
+DEV_T0000000000.wav
+DEV_T0000000001.wav
+DEV_T0000000002.wav
+TEST_MEETING_T0000000113.wav
+TEST_MEETING_T0000000219.wav
+TEST_MEETING_T0000000351.wav
+)
+
+for w in ${test_wavs[@]}; do
+ ./zipformer/onnx_pretrained-streaming-ctc.py \
+ --model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ $repo/test_wavs/$w
+done
+
+log "Upload onnx CTC models to huggingface"
+url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13
+GIT_LFS_SKIP_SMUDGE=1 git clone $url
+dst=$(basename $url)
+cp -v $repo/exp/ctc*.onnx $dst
+cp -v $repo/data/lang_bpe_2000/tokens.txt $dst
+cp -v $repo/data/lang_bpe_2000/bpe.model $dst
+mkdir -p $dst/test_wavs
+cp -v $repo/test_wavs/*.wav $dst/test_wavs
+cd $dst
+git lfs track "*.onnx" "bpe.model"
+ls -lh
+file bpe.model
+git status
+git add .
+git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true
+
+log "Upload models to https://github.com/k2-fsa/sherpa-onnx"
+rm -rf .git
+rm -fv .gitattributes
+cd ..
+tar cjfv $dst.tar.bz2 $dst
+ls -lh *.tar.bz2
+mv -v $dst.tar.bz2 ../../../
+
log "----------------------------------------"
log "Export streaming ONNX transducer models "
log "----------------------------------------"
@@ -64,20 +132,20 @@ log "test int8"
--tokens $repo/data/lang_bpe_2000/tokens.txt \
$repo/test_wavs/DEV_T0000000000.wav
-log "Upload models to huggingface"
-git config --global user.name "k2-fsa"
-git config --global user.email "xxx@gmail.com"
+log "Upload onnx transducer models to huggingface"
url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-multi-zh-hans-2023-12-12
GIT_LFS_SKIP_SMUDGE=1 git clone $url
dst=$(basename $url)
-cp -v $repo/exp/*.onnx $dst
+cp -v $repo/exp/encoder*.onnx $dst
+cp -v $repo/exp/decoder*.onnx $dst
+cp -v $repo/exp/joiner*.onnx $dst
cp -v $repo/data/lang_bpe_2000/tokens.txt $dst
cp -v $repo/data/lang_bpe_2000/bpe.model $dst
mkdir -p $dst/test_wavs
cp -v $repo/test_wavs/*.wav $dst/test_wavs
cd $dst
-git lfs track "*.onnx"
+git lfs track "*.onnx" bpe.model
git add .
git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true
@@ -86,4 +154,5 @@ rm -rf .git
rm -fv .gitattributes
cd ..
tar cjfv $dst.tar.bz2 $dst
+ls -lh *.tar.bz2
mv -v $dst.tar.bz2 ../../../
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py
new file mode 100755
index 000000000..44546cae5
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py
@@ -0,0 +1,426 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
+
+"""
+This script loads ONNX models exported by ./export-onnx-streaming-ctc.py
+and uses them to decode waves.
+
+We use the pre-trained model from
+https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05
+as an example to show how to use this file.
+
+1. Download the pre-trained model
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-99.pt
+popd
+
+2. Export the model to ONNX
+
+./zipformer/export-onnx-streaming-ctc.py \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ --use-averaged-model 0 \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir $repo/exp \
+ --causal True \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --use-ctc 1
+
+It will generate the following 2 files inside $repo/exp:
+
+ - ctc-epoch-99-avg-1-chunk-16-left-128.int8.onnx
+ - ctc-epoch-99-avg-1-chunk-16-left-128.onnx
+
+You can use either the ``int8.onnx`` model or just the ``.onnx`` model.
+
+3. Run this file with the exported ONNX models
+
+./zipformer/onnx_pretrained-streaming-ctc.py \
+ --model-filename $repo/exp/ctc-epoch-99-avg-1-chunk-16-left-128.onnx \
+ --tokens $repo/data/lang_bpe_2000/tokens.txt \
+ $repo/test_wavs/DEV_T0000000001.wav
+
+Note: Even though this script only supports decoding a single file,
+the exported ONNX models do support batch processing.
+"""
+
+import argparse
+import logging
+from typing import Dict, List, Tuple
+
+import numpy as np
+import onnxruntime as ort
+import torch
+import torchaudio
+from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--model-filename",
+ type=str,
+ required=True,
+ help="Path to the streaming CTC onnx model. ",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ help="""Path to tokens.txt.""",
+ )
+
+ parser.add_argument(
+ "sound_file",
+ type=str,
+ help="The input sound file to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ return parser
+
+
+class OnnxModel:
+ def __init__(
+ self,
+ model_filename: str,
+ ):
+ session_opts = ort.SessionOptions()
+ session_opts.inter_op_num_threads = 1
+ session_opts.intra_op_num_threads = 1
+
+ self.session_opts = session_opts
+
+ self.init_model(model_filename)
+
+ def init_model(self, model_filename: str):
+ self.model = ort.InferenceSession(
+ model_filename,
+ sess_options=self.session_opts,
+ providers=["CPUExecutionProvider"],
+ )
+ self.init_states()
+
+ def init_states(self, batch_size: int = 1):
+ meta = self.model.get_modelmeta().custom_metadata_map
+ logging.info(f"meta={meta}")
+
+ model_type = meta["model_type"]
+ assert model_type == "zipformer2", model_type
+
+ decode_chunk_len = int(meta["decode_chunk_len"])
+ T = int(meta["T"])
+
+ num_encoder_layers = meta["num_encoder_layers"]
+ encoder_dims = meta["encoder_dims"]
+ cnn_module_kernels = meta["cnn_module_kernels"]
+ left_context_len = meta["left_context_len"]
+ query_head_dims = meta["query_head_dims"]
+ value_head_dims = meta["value_head_dims"]
+ num_heads = meta["num_heads"]
+
+ def to_int_list(s):
+ return list(map(int, s.split(",")))
+
+ num_encoder_layers = to_int_list(num_encoder_layers)
+ encoder_dims = to_int_list(encoder_dims)
+ cnn_module_kernels = to_int_list(cnn_module_kernels)
+ left_context_len = to_int_list(left_context_len)
+ query_head_dims = to_int_list(query_head_dims)
+ value_head_dims = to_int_list(value_head_dims)
+ num_heads = to_int_list(num_heads)
+
+ logging.info(f"decode_chunk_len: {decode_chunk_len}")
+ logging.info(f"T: {T}")
+ logging.info(f"num_encoder_layers: {num_encoder_layers}")
+ logging.info(f"encoder_dims: {encoder_dims}")
+ logging.info(f"cnn_module_kernels: {cnn_module_kernels}")
+ logging.info(f"left_context_len: {left_context_len}")
+ logging.info(f"query_head_dims: {query_head_dims}")
+ logging.info(f"value_head_dims: {value_head_dims}")
+ logging.info(f"num_heads: {num_heads}")
+
+ num_encoders = len(num_encoder_layers)
+
+ self.states = []
+ for i in range(num_encoders):
+ num_layers = num_encoder_layers[i]
+ key_dim = query_head_dims[i] * num_heads[i]
+ embed_dim = encoder_dims[i]
+ nonlin_attn_head_dim = 3 * embed_dim // 4
+ value_dim = value_head_dims[i] * num_heads[i]
+ conv_left_pad = cnn_module_kernels[i] // 2
+
+ for layer in range(num_layers):
+ cached_key = torch.zeros(
+ left_context_len[i], batch_size, key_dim
+ ).numpy()
+ cached_nonlin_attn = torch.zeros(
+ 1, batch_size, left_context_len[i], nonlin_attn_head_dim
+ ).numpy()
+ cached_val1 = torch.zeros(
+ left_context_len[i], batch_size, value_dim
+ ).numpy()
+ cached_val2 = torch.zeros(
+ left_context_len[i], batch_size, value_dim
+ ).numpy()
+ cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
+ cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
+ self.states += [
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ ]
+ embed_states = torch.zeros(batch_size, 128, 3, 19).numpy()
+ self.states.append(embed_states)
+ processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy()
+ self.states.append(processed_lens)
+
+ self.num_encoders = num_encoders
+
+ self.segment = T
+ self.offset = decode_chunk_len
+
+ def _build_model_input_output(
+ self,
+ x: torch.Tensor,
+ ) -> Tuple[Dict[str, np.ndarray], List[str]]:
+ model_input = {"x": x.numpy()}
+ model_output = ["log_probs"]
+
+ def build_inputs_outputs(tensors, i):
+ assert len(tensors) == 6, len(tensors)
+
+ # (downsample_left, batch_size, key_dim)
+ name = f"cached_key_{i}"
+ model_input[name] = tensors[0]
+ model_output.append(f"new_{name}")
+
+ # (1, batch_size, downsample_left, nonlin_attn_head_dim)
+ name = f"cached_nonlin_attn_{i}"
+ model_input[name] = tensors[1]
+ model_output.append(f"new_{name}")
+
+ # (downsample_left, batch_size, value_dim)
+ name = f"cached_val1_{i}"
+ model_input[name] = tensors[2]
+ model_output.append(f"new_{name}")
+
+ # (downsample_left, batch_size, value_dim)
+ name = f"cached_val2_{i}"
+ model_input[name] = tensors[3]
+ model_output.append(f"new_{name}")
+
+ # (batch_size, embed_dim, conv_left_pad)
+ name = f"cached_conv1_{i}"
+ model_input[name] = tensors[4]
+ model_output.append(f"new_{name}")
+
+ # (batch_size, embed_dim, conv_left_pad)
+ name = f"cached_conv2_{i}"
+ model_input[name] = tensors[5]
+ model_output.append(f"new_{name}")
+
+ for i in range(len(self.states[:-2]) // 6):
+ build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i)
+
+ # (batch_size, channels, left_pad, freq)
+ name = "embed_states"
+ embed_states = self.states[-2]
+ model_input[name] = embed_states
+ model_output.append(f"new_{name}")
+
+ # (batch_size,)
+ name = "processed_lens"
+ processed_lens = self.states[-1]
+ model_input[name] = processed_lens
+ model_output.append(f"new_{name}")
+
+ return model_input, model_output
+
+ def _update_states(self, states: List[np.ndarray]):
+ self.states = states
+
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x:
+ A 3-D tensor of shape (N, T, C)
+ Returns:
+ Return a 3-D tensor containing log_probs. Its shape is (N, T', vocab_size),
+ where T' is usually equal to ((T-7)//2 - 3)//2.
+ """
+ model_input, model_output_names = self._build_model_input_output(x)
+
+ out = self.model.run(model_output_names, model_input)
+
+ self._update_states(out[1:])
+
+ return torch.from_numpy(out[0])
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list of 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+ # We use only the first channel
+ ans.append(wave[0].contiguous())
+ return ans
+
+
+def create_streaming_feature_extractor() -> OnlineFeature:
+ """Create a CPU streaming feature extractor.
+
+ At present, we assume it returns a fbank feature extractor with
+ fixed options. In the future, we will support passing in the options
+ from outside.
+
+ Returns:
+ Return a CPU streaming feature extractor.
+ """
+ opts = FbankOptions()
+ opts.device = "cpu"
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = 16000
+ opts.mel_opts.num_bins = 80
+ return OnlineFbank(opts)
+
+
+def greedy_search(
+ log_probs: torch.Tensor,
+) -> List[int]:
+ """Greedy search for a single utterance.
+ Args:
+ log_probs:
+ A 3-D tensor of shape (1, T, vocab_size)
+ Returns:
+ Return the decoded result.
+ """
+ assert log_probs.ndim == 3, log_probs.shape
+ assert log_probs.shape[0] == 1, log_probs.shape
+
+ max_indexes = log_probs[0].argmax(dim=1)
+ unique_indexes = torch.unique_consecutive(max_indexes)
+
+ blank_id = 0
+ unique_indexes = unique_indexes[unique_indexes != blank_id]
+ return unique_indexes.tolist()
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ logging.info(vars(args))
+
+ model = OnnxModel(model_filename=args.model_filename)
+
+ sample_rate = 16000
+
+ logging.info("Constructing Fbank computer")
+ online_fbank = create_streaming_feature_extractor()
+
+ logging.info(f"Reading sound files: {args.sound_file}")
+ waves = read_sound_files(
+ filenames=[args.sound_file],
+ expected_sample_rate=sample_rate,
+ )[0]
+
+ tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
+ wave_samples = torch.cat([waves, tail_padding])
+
+ num_processed_frames = 0
+ segment = model.segment
+ offset = model.offset
+
+ hyp = []
+
+ chunk = int(1 * sample_rate) # 1 second
+ start = 0
+ while start < wave_samples.numel():
+ end = min(start + chunk, wave_samples.numel())
+ samples = wave_samples[start:end]
+ start += chunk
+
+ online_fbank.accept_waveform(
+ sampling_rate=sample_rate,
+ waveform=samples,
+ )
+
+ while online_fbank.num_frames_ready - num_processed_frames >= segment:
+ frames = []
+ for i in range(segment):
+ frames.append(online_fbank.get_frame(num_processed_frames + i))
+ num_processed_frames += offset
+ frames = torch.cat(frames, dim=0)
+ frames = frames.unsqueeze(0)
+ log_probs = model(frames)
+
+ hyp += greedy_search(log_probs)
+
+ # To handle byte-level BPE, we convert string tokens to utf-8 encoded bytes
+ id2token = {}
+ with open(args.tokens, encoding="utf-8") as f:
+ for line in f:
+ token, idx = line.split()
+ if token[:3] == "<0x" and token[-1] == ">":
+ token = int(token[1:-1], base=16)
+ assert 0 <= token < 256, token
+ token = token.to_bytes(1, byteorder="little")
+ else:
+ token = token.encode(encoding="utf-8")
+
+ id2token[int(idx)] = token
+
+ text = b""
+ for i in hyp:
+ text += id2token[i]
+ text = text.decode(encoding="utf-8")
+ text = text.replace("▁", " ").strip()
+
+ logging.info(args.sound_file)
+ logging.info(text)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
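Note: a toy, self-contained illustration of the greedy_search() logic above (take the arg-max per frame, collapse repeats, drop the blank id 0); the tensor is random and only demonstrates the mechanics:

import torch

log_probs = torch.randn(1, 10, 5).log_softmax(dim=-1)  # (N=1, T=10, vocab_size=5); id 0 is blank
max_indexes = log_probs[0].argmax(dim=1)
unique_indexes = torch.unique_consecutive(max_indexes)
hyp = unique_indexes[unique_indexes != 0].tolist()
print(max_indexes.tolist(), "->", hyp)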
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py
index e62491444..e7c4f40ee 100755
--- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py
+++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py
@@ -326,7 +326,7 @@ class OnnxModel:
A 3-D tensor of shape (N, T, C)
Returns:
Return a 3-D tensor of shape (N, T', joiner_dim) where
- T' is usually equal to ((T-7)//2+1)//2
+ T' is usually equal to ((T-7)//2-3)//2
"""
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
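Note: a worked example of the corrected formula, using the streaming settings that appear earlier in this series (chunk-size 16, hence decode_chunk_len = 32 and pad_length = 7 + 2 * 3 = 13):

T = 32 + 13                        # feature frames fed to the encoder per chunk
T_prime = ((T - 7) // 2 - 3) // 2  # ((45 - 7) // 2 - 3) // 2
print(T, T_prime)                  # prints: 45 8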
diff --git a/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming-ctc.py
new file mode 120000
index 000000000..652346001
--- /dev/null
+++ b/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming-ctc.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
\ No newline at end of file
diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming-ctc.py b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming-ctc.py
new file mode 120000
index 000000000..d623a8462
--- /dev/null
+++ b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming-ctc.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py
\ No newline at end of file
From 10a234709cb6b8aa5e99b1c18140b49db7b6faca Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Thu, 14 Dec 2023 11:26:37 +0800
Subject: [PATCH 020/123] bugs fixed (#1416)
---
egs/aishell4/ASR/prepare.sh | 2 +-
egs/alimeeting/ASR/prepare.sh | 3 ++-
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/egs/aishell4/ASR/prepare.sh b/egs/aishell4/ASR/prepare.sh
index 361cc26ab..e8d9eb7b9 100755
--- a/egs/aishell4/ASR/prepare.sh
+++ b/egs/aishell4/ASR/prepare.sh
@@ -79,7 +79,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Process aishell4"
if [ ! -f data/fbank/aishell4/.fbank.done ]; then
mkdir -p data/fbank/aishell4
- lhotse prepare aishell4 $dl_dir/aishell4 data/manifests/aishell4
+ ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
touch data/fbank/aishell4/.fbank.done
fi
fi
diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh
index 1709733c7..c8fed658d 100755
--- a/egs/alimeeting/ASR/prepare.sh
+++ b/egs/alimeeting/ASR/prepare.sh
@@ -7,6 +7,7 @@ set -eou pipefail
stage=-1
stop_stage=100
+perturb_speed=true
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
@@ -68,7 +69,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Process alimeeting"
if [ ! -f data/fbank/alimeeting/.fbank.done ]; then
mkdir -p data/fbank/alimeeting
- lhotse prepare ali-meeting $dl_dir/alimeeting data/manifests/alimeeting
+ ./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed}
fi
fi
From 702d4f59147a81ef7ba37c7863ae7bd258c743a9 Mon Sep 17 00:00:00 2001
From: TianHao Zhang <32243340+Zth9730@users.noreply.github.com>
Date: Thu, 21 Dec 2023 14:42:33 +0800
Subject: [PATCH 021/123] Update prepare.sh (#1422)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fix the bug in line 251:
1. Remove the extra spaces around "=" (with them, the shell tries to run a command named "new_vacab_size" instead of assigning the variable).
2. Correct the misspelling "new_vacab_size" to "new_vocab_size".
---
egs/libriheavy/ASR/prepare.sh | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/egs/libriheavy/ASR/prepare.sh b/egs/libriheavy/ASR/prepare.sh
index af7e3c5b0..b0736c98b 100755
--- a/egs/libriheavy/ASR/prepare.sh
+++ b/egs/libriheavy/ASR/prepare.sh
@@ -248,7 +248,7 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
| jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts
fi
for vocab_size in ${vocab_sizes[@]}; do
- new_vacab_size = $(($vocab_size + 256))
+ new_vocab_size=$(($vocab_size + 256))
lang_dir=data/lang_punc_bpe_${new_vocab_size}
mkdir -p $lang_dir
From 79a42148dbcd98c42586f8386d91f6f4bb8f9979 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sat, 23 Dec 2023 00:38:36 +0800
Subject: [PATCH 022/123] Add CI test to cover zipformer/train.py (#1424)
---
.github/scripts/docker/Dockerfile | 59 +++++++++++++++
.github/scripts/docker/run.sh | 60 +++++++++++++++
.github/workflows/build-cpu-docker.yml | 75 +++++++++++++++++++
.github/workflows/train-librispeech.yml | 56 ++++++++++++++
egs/gigaspeech/ASR/zipformer/my_profile.py | 1 +
egs/gigaspeech/ASR/zipformer/profile.py | 1 -
.../{profile.py => my_profile.py} | 2 +-
.../{profile.py => my_profile.py} | 2 +-
.../{profile.py => my_profile.py} | 2 +-
.../zipformer/{profile.py => my_profile.py} | 2 +-
egs/tedlium3/ASR/zipformer/my_profile.py | 1 +
egs/tedlium3/ASR/zipformer/profile.py | 1 -
12 files changed, 256 insertions(+), 6 deletions(-)
create mode 100644 .github/scripts/docker/Dockerfile
create mode 100755 .github/scripts/docker/run.sh
create mode 100644 .github/workflows/build-cpu-docker.yml
create mode 100644 .github/workflows/train-librispeech.yml
create mode 120000 egs/gigaspeech/ASR/zipformer/my_profile.py
delete mode 120000 egs/gigaspeech/ASR/zipformer/profile.py
rename egs/librispeech/ASR/pruned_transducer_stateless/{profile.py => my_profile.py} (98%)
rename egs/librispeech/ASR/pruned_transducer_stateless4/{profile.py => my_profile.py} (98%)
rename egs/librispeech/ASR/pruned_transducer_stateless7/{profile.py => my_profile.py} (98%)
rename egs/librispeech/ASR/zipformer/{profile.py => my_profile.py} (99%)
create mode 120000 egs/tedlium3/ASR/zipformer/my_profile.py
delete mode 120000 egs/tedlium3/ASR/zipformer/profile.py
diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile
new file mode 100644
index 000000000..55c3aa1b9
--- /dev/null
+++ b/.github/scripts/docker/Dockerfile
@@ -0,0 +1,59 @@
+ARG PYTHON_VERSION=3.8
+FROM python:${PYTHON_VERSION}
+
+ARG TORCHAUDIO_VERSION="0.13.0"
+ARG TORCH_VERSION="1.13.0"
+ARG K2_VERSION="1.24.4.dev20231220"
+ARG KALDIFEAT_VERSION="1.25.3.dev20231221"
+
+ARG _K2_VERSION="${K2_VERSION}+cpu.torch${TORCH_VERSION}"
+ARG _KALDIFEAT_VERSION="${KALDIFEAT_VERSION}+cpu.torch${TORCH_VERSION}"
+
+RUN apt-get update -y && \
+ apt-get install -qq -y \
+ ffmpeg \
+ git \
+ git-lfs \
+ less \
+ vim \
+ && \
+ apt-get clean && \
+ rm -rf /var/cache/apt/archives /var/lib/apt/lists
+
+
+LABEL authors="Fangjun Kuang "
+LABEL k2_version=${_K2_VERSION}
+LABEL kaldifeat_version=${_KALDIFEAT_VERSION}
+LABEL github_repo="https://github.com/k2-fsa/icefall"
+
+# Install dependencies
+RUN pip install --no-cache-dir \
+ torch==${TORCH_VERSION} torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/cpu/torch_stable.html \
+ k2==${_K2_VERSION} -f https://k2-fsa.github.io/k2/cpu.html \
+ git+https://github.com/lhotse-speech/lhotse \
+ kaldifeat==${_KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cpu.html \
+ kaldi_native_io \
+ kaldialign \
+ kaldifst \
+ kaldilm \
+ sentencepiece>=0.1.96 \
+ tensorboard \
+ typeguard \
+ dill \
+ onnx \
+ onnxruntime \
+ onnxmltools \
+ six \
+ multi_quantization \
+ typeguard \
+ numpy \
+ pytest \
+ graphviz
+
+# RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
+# cd /workspace/icefall && \
+# pip install --no-cache-dir -r requirements.txt
+#
+# ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
+#
+# WORKDIR /workspace/icefall
diff --git a/.github/scripts/docker/run.sh b/.github/scripts/docker/run.sh
new file mode 100755
index 000000000..aeb80b330
--- /dev/null
+++ b/.github/scripts/docker/run.sh
@@ -0,0 +1,60 @@
+#!/usr/bin/env bash
+set -ex
+
+cd /icefall
+export PYTHONPATH=/icefall:$PYTHONPATH
+python3 -c "import torch; print(torch.__file__)"
+python3 -c "import torchaudio; print(torchaudio.__version__)"
+python3 -c "import icefall; print(icefall.__file__)"
+
+cd egs/librispeech/ASR
+
+# We don't download the LM file since it is so large that it will
+# cause OOM error for CI later.
+mkdir -p download/lm
+pushd download/lm
+wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt
+wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt
+wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz
+ls -lh
+gunzip librispeech-lm-norm.txt.gz
+
+ls -lh
+popd
+
+pushd download/
+wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/LibriSpeech.tar.bz2
+tar xf LibriSpeech.tar.bz2
+rm LibriSpeech.tar.bz2
+
+cd LibriSpeech
+ln -s train-clean-100 train-clean-360
+ln -s train-clean-100 train-other-500
+popd
+
+mkdir -p data/manifests
+
+lhotse prepare librispeech -j 2 -p dev-clean -p dev-other -p test-clean -p test-other -p train-clean-100 download/LibriSpeech data/manifests
+ls -lh data/manifests
+
+./local/compute_fbank_librispeech.py --dataset "dev-clean dev-other test-clean test-other train-clean-100" --perturb-speed False
+ls -lh data/fbank
+
+./prepare.sh --stage 5 --stop-stage 6
+
+./zipformer/train.py \
+ --world-size 1 \
+ --num-epochs 1 \
+ --start-epoch 1 \
+ --use-fp16 0 \
+ --exp-dir zipformer/exp-small \
+ --causal 0 \
+ --num-encoder-layers 1,1,1,1,1,1 \
+ --feedforward-dim 64,96,96,96,96,96 \
+ --encoder-dim 32,64,64,64,64,64 \
+ --encoder-unmasked-dim 32,32,32,32,32,32 \
+ --base-lr 0.04 \
+ --full-libri 0 \
+ --enable-musan 0 \
+ --max-duration 30 \
+ --print-diagnostics 1
diff --git a/.github/workflows/build-cpu-docker.yml b/.github/workflows/build-cpu-docker.yml
new file mode 100644
index 000000000..f931f7d09
--- /dev/null
+++ b/.github/workflows/build-cpu-docker.yml
@@ -0,0 +1,75 @@
+name: build-cpu-docker
+on:
+ workflow_dispatch:
+
+concurrency:
+ group: build-cpu-docker-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ build-cpu-docker:
+ name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
+ runs-on: ${{ matrix.os }}
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ubuntu-latest]
+ python-version: ["3.8", "3.9", "3.10"]
+ torch-version: ["1.13.0", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
+ k2-version: ["1.24.4.dev20231220"]
+ kaldifeat-version: ["1.25.3.dev20231221"]
+ version: ["1.0"]
+
+ steps:
+ # refer to https://github.com/actions/checkout
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Free space
+ shell: bash
+ run: |
+ df -h
+ rm -rf /opt/hostedtoolcache
+ df -h
+
+ - name: 'Login to GitHub Container Registry'
+ uses: docker/login-action@v2
+ with:
+ registry: ghcr.io
+ username: ${{ github.actor }}
+ password: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Build docker Image
+ shell: bash
+ run: |
+ cd .github/scripts/docker
+ torch_version=${{ matrix.torch-version }}
+ if [[ $torch_version == 1.13.0 ]]; then
+ torchaudio_version=0.13.0
+ elif [[ $torch_version == 2.0.0 ]]; then
+ torchaudio_version=2.0.1
+ elif [[ $torch_version == 2.0.1 ]]; then
+ torchaudio_version=2.0.2
+ else
+ torchaudio_version=$torch_version
+ fi
+ echo "torch_version: $torch_version"
+ echo "torchaudio_version: $torchaudio_version"
+
+ version=${{ matrix.version }}
+
+ tag=ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v$version
+ echo "tag: $tag"
+
+ docker build \
+ -t $tag \
+ --build-arg PYTHON_VERSION=${{ matrix.python-version }} \
+ --build-arg TORCH_VERSION=$torch_version \
+ --build-arg TORCHAUDIO_VERSION=$torchaudio_version \
+ --build-arg K2_VERSION=${{ matrix.k2-version }} \
+ --build-arg KALDIFEAT_VERSION=${{ matrix.kaldifeat-version }} \
+ .
+
+ docker image ls
+ docker push $tag
diff --git a/.github/workflows/train-librispeech.yml b/.github/workflows/train-librispeech.yml
new file mode 100644
index 000000000..7c9a28f03
--- /dev/null
+++ b/.github/workflows/train-librispeech.yml
@@ -0,0 +1,56 @@
+name: train librispeech
+on:
+ push:
+ branches:
+ - master
+
+ pull_request:
+ branches:
+ - master
+
+ workflow_dispatch:
+
+concurrency:
+ group: train-librispeech-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ train-librispeech:
+ name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
+ runs-on: ${{ matrix.os }}
+ strategy:
+ fail-fast: false
+ matrix:
+ os: [ubuntu-latest]
+ python-version: ["3.8", "3.9", "3.10"]
+ torch-version: ["1.13.0", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
+ k2-version: ["1.24.4.dev20231220"]
+ kaldifeat-version: ["1.25.3.dev20231221"]
+ version: ["1.0"]
+
+ steps:
+ # refer to https://github.com/actions/checkout
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Free space
+ shell: bash
+ run: |
+ df -h
+ rm -rf /opt/hostedtoolcache
+ df -h
+ echo "pwd: $PWD"
+ echo "github.workspace ${{ github.workspace }}"
+
+ - name: Run the build process with Docker
+ uses: addnab/docker-run-action@v3
+ with:
+ image: ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
+ options: |
+ --volume ${{ github.workspace }}/:/icefall
+ shell: bash
+ run: |
+ ls -lh /icefall
+
+ /icefall/.github/scripts/docker/run.sh
diff --git a/egs/gigaspeech/ASR/zipformer/my_profile.py b/egs/gigaspeech/ASR/zipformer/my_profile.py
new file mode 120000
index 000000000..3a90b2628
--- /dev/null
+++ b/egs/gigaspeech/ASR/zipformer/my_profile.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/my_profile.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/zipformer/profile.py b/egs/gigaspeech/ASR/zipformer/profile.py
deleted file mode 120000
index c93adbd14..000000000
--- a/egs/gigaspeech/ASR/zipformer/profile.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/zipformer/profile.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/profile.py b/egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py
similarity index 98%
rename from egs/librispeech/ASR/pruned_transducer_stateless/profile.py
rename to egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py
index 09e4a7af4..b844ba613 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/profile.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py
@@ -17,7 +17,7 @@
# limitations under the License.
"""
-Usage: ./pruned_transducer_stateless/profile.py
+Usage: ./pruned_transducer_stateless/my_profile.py
"""
import argparse
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/profile.py b/egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py
similarity index 98%
rename from egs/librispeech/ASR/pruned_transducer_stateless4/profile.py
rename to egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py
index 252bdf060..4bf773918 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/profile.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py
@@ -17,7 +17,7 @@
# limitations under the License.
"""
-Usage: ./pruned_transducer_stateless4/profile.py
+Usage: ./pruned_transducer_stateless4/my_profile.py
"""
import argparse
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/profile.py b/egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py
similarity index 98%
rename from egs/librispeech/ASR/pruned_transducer_stateless7/profile.py
rename to egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py
index 0d308e966..5a068b3b6 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/profile.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py
@@ -17,7 +17,7 @@
# limitations under the License.
"""
-Usage: ./pruned_transducer_stateless7/profile.py
+Usage: ./pruned_transducer_stateless7/my_profile.py
"""
import argparse
diff --git a/egs/librispeech/ASR/zipformer/profile.py b/egs/librispeech/ASR/zipformer/my_profile.py
similarity index 99%
rename from egs/librispeech/ASR/zipformer/profile.py
rename to egs/librispeech/ASR/zipformer/my_profile.py
index 57f44a90a..ca20956fb 100755
--- a/egs/librispeech/ASR/zipformer/profile.py
+++ b/egs/librispeech/ASR/zipformer/my_profile.py
@@ -17,7 +17,7 @@
# limitations under the License.
"""
-Usage: ./zipformer/profile.py
+Usage: ./zipformer/my_profile.py
"""
import argparse
diff --git a/egs/tedlium3/ASR/zipformer/my_profile.py b/egs/tedlium3/ASR/zipformer/my_profile.py
new file mode 120000
index 000000000..3a90b2628
--- /dev/null
+++ b/egs/tedlium3/ASR/zipformer/my_profile.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/my_profile.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/zipformer/profile.py b/egs/tedlium3/ASR/zipformer/profile.py
deleted file mode 120000
index c93adbd14..000000000
--- a/egs/tedlium3/ASR/zipformer/profile.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/zipformer/profile.py
\ No newline at end of file
From e5bb1ae86cc750a51626e9afcb973cc03fa72f86 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sun, 24 Dec 2023 13:40:33 +0800
Subject: [PATCH 023/123] Use the CPU docker in CI to simplify the test code
(#1427)
---
.github/scripts/docker/Dockerfile | 21 ++--
.github/workflows/build-cpu-docker.yml | 8 +-
.github/workflows/test.yml | 139 +++++++++---------------
.github/workflows/train-librispeech.yml | 8 +-
4 files changed, 72 insertions(+), 104 deletions(-)
diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile
index 55c3aa1b9..bbf978d26 100644
--- a/.github/scripts/docker/Dockerfile
+++ b/.github/scripts/docker/Dockerfile
@@ -14,6 +14,7 @@ RUN apt-get update -y && \
ffmpeg \
git \
git-lfs \
+ graphviz \
less \
vim \
&& \
@@ -32,23 +33,23 @@ RUN pip install --no-cache-dir \
k2==${_K2_VERSION} -f https://k2-fsa.github.io/k2/cpu.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${_KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cpu.html \
+ dill \
+ graphviz \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
- sentencepiece>=0.1.96 \
- tensorboard \
- typeguard \
- dill \
- onnx \
- onnxruntime \
- onnxmltools \
- six \
+ matplotlib \
multi_quantization \
- typeguard \
numpy \
+ onnx \
+ onnxmltools \
+ onnxruntime \
pytest \
- graphviz
+ sentencepiece>=0.1.96 \
+ six \
+ tensorboard \
+ typeguard
# RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
# cd /workspace/icefall && \
diff --git a/.github/workflows/build-cpu-docker.yml b/.github/workflows/build-cpu-docker.yml
index f931f7d09..b26cd2095 100644
--- a/.github/workflows/build-cpu-docker.yml
+++ b/.github/workflows/build-cpu-docker.yml
@@ -15,10 +15,10 @@ jobs:
matrix:
os: [ubuntu-latest]
python-version: ["3.8", "3.9", "3.10"]
- torch-version: ["1.13.0", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
+ torch-version: ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
k2-version: ["1.24.4.dev20231220"]
kaldifeat-version: ["1.25.3.dev20231221"]
- version: ["1.0"]
+ version: ["1.1"]
steps:
# refer to https://github.com/actions/checkout
@@ -45,8 +45,12 @@ jobs:
run: |
cd .github/scripts/docker
torch_version=${{ matrix.torch-version }}
+
+ # see https://pytorch.org/audio/stable/installation.html#compatibility-matrix
if [[ $torch_version == 1.13.0 ]]; then
torchaudio_version=0.13.0
+ elif [[ $torch_version == 1.13.1 ]]; then
+ torchaudio_version=0.13.1
elif [[ $torch_version == 2.0.0 ]]; then
torchaudio_version=2.0.1
elif [[ $torch_version == 2.0.1 ]]; then
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 363556bb7..b3fd6f133 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -1,129 +1,94 @@
-# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
name: test
on:
push:
branches:
- master
+
pull_request:
branches:
- master
+ workflow_dispatch:
+
concurrency:
group: test-${{ github.ref }}
cancel-in-progress: true
jobs:
test:
+ name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
runs-on: ${{ matrix.os }}
strategy:
+ fail-fast: false
matrix:
os: [ubuntu-latest]
- python-version: ["3.8"]
- torch: ["1.13.0"]
- torchaudio: ["0.13.0"]
- k2-version: ["1.24.3.dev20230719"]
-
- fail-fast: false
+ python-version: ["3.8", "3.9", "3.10"]
+ torch-version: ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
+ version: ["1.1"]
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
with:
fetch-depth: 0
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v1
- with:
- python-version: ${{ matrix.python-version }}
-
- - name: Install libnsdfile and libsox
- if: startsWith(matrix.os, 'ubuntu')
- run: |
- sudo apt update
- sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg
- sudo apt install -q -y --fix-missing libsox-dev libsox-fmt-all
-
- - name: Install Python dependencies
- run: |
- python3 -m pip install --upgrade pip pytest
- # numpy 1.20.x does not support python 3.6
- pip install numpy==1.19
- pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
- pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
-
- pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.github.io/k2/cpu.html
- pip install git+https://github.com/lhotse-speech/lhotse
- # icefall requirements
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- pip install kaldifst
- pip install onnxruntime matplotlib
- pip install -r requirements.txt
-
- - name: Install graphviz
- if: startsWith(matrix.os, 'ubuntu')
+ - name: Free space
shell: bash
run: |
- python3 -m pip install -qq graphviz
- sudo apt-get -qq install graphviz
+ df -h
+ rm -rf /opt/hostedtoolcache
+ df -h
+ echo "pwd: $PWD"
+ echo "github.workspace ${{ github.workspace }}"
- name: Run tests
- if: startsWith(matrix.os, 'ubuntu')
- run: |
- ls -lh
- export PYTHONPATH=$PWD:$PWD/lhotse:$PYTHONPATH
- echo $PYTHONPATH
- pytest -v -s ./test
- # runt tests for conformer ctc
- cd egs/librispeech/ASR/conformer_ctc
- pytest -v -s
+ uses: addnab/docker-run-action@v3
+ with:
+ image: ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
+ options: |
+ --volume ${{ github.workspace }}/:/icefall
+ shell: bash
+ run: |
+ export PYTHONPATH=/icefall:$PYTHONPATH
+ cd /icefall
+ git config --global --add safe.directory /icefall
- cd ../pruned_transducer_stateless
- pytest -v -s
+ pytest -v -s ./test
- cd ../pruned_transducer_stateless2
- pytest -v -s
+ # runt tests for conformer ctc
+ cd egs/librispeech/ASR/conformer_ctc
+ pytest -v -s
- cd ../pruned_transducer_stateless3
- pytest -v -s
+ cd ../pruned_transducer_stateless
+ pytest -v -s
- cd ../pruned_transducer_stateless4
- pytest -v -s
+ cd ../pruned_transducer_stateless2
+ pytest -v -s
- echo $PYTHONPATH
- cd ../pruned_transducer_stateless7
- pytest -v -s
+ cd ../pruned_transducer_stateless3
+ pytest -v -s
- cd ../transducer_stateless
- pytest -v -s
+ cd ../pruned_transducer_stateless4
+ pytest -v -s
- # cd ../transducer
- # pytest -v -s
+ echo $PYTHONPATH
+ cd ../pruned_transducer_stateless7
+ pytest -v -s
- cd ../transducer_stateless2
- pytest -v -s
+ cd ../transducer_stateless
+ pytest -v -s
- cd ../transducer_lstm
- pytest -v -s
+ # cd ../transducer
+ # pytest -v -s
- cd ../zipformer
- pytest -v -s
+ cd ../transducer_stateless2
+ pytest -v -s
+
+ cd ../transducer_lstm
+ pytest -v -s
+
+ cd ../zipformer
+ pytest -v -s
- uses: actions/upload-artifact@v2
with:
diff --git a/.github/workflows/train-librispeech.yml b/.github/workflows/train-librispeech.yml
index 7c9a28f03..53a2d5843 100644
--- a/.github/workflows/train-librispeech.yml
+++ b/.github/workflows/train-librispeech.yml
@@ -23,10 +23,8 @@ jobs:
matrix:
os: [ubuntu-latest]
python-version: ["3.8", "3.9", "3.10"]
- torch-version: ["1.13.0", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
- k2-version: ["1.24.4.dev20231220"]
- kaldifeat-version: ["1.25.3.dev20231221"]
- version: ["1.0"]
+ torch-version: ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
+ version: ["1.1"]
steps:
# refer to https://github.com/actions/checkout
@@ -43,7 +41,7 @@ jobs:
echo "pwd: $PWD"
echo "github.workspace ${{ github.workspace }}"
- - name: Run the build process with Docker
+ - name: Test zipformer/train.py with LibriSpeech
uses: addnab/docker-run-action@v3
with:
image: ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
From c855a58cfd8628dfe4ef2ffc0ac169d84a8ac0c5 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Mon, 25 Dec 2023 19:41:09 +0800
Subject: [PATCH 024/123] Generate the dependency matrix by code for GitHub
Actions (#1431)
---
.github/scripts/docker/Dockerfile | 2 +
.../scripts/docker/generate_build_matrix.py | 79 ++++++++
.../{docker => librispeech/ASR}/run.sh | 11 +-
.github/scripts/yesno/ASR/run.sh | 86 +++++++++
.github/workflows/build-cpu-docker.yml | 42 ++--
.github/workflows/run-yesno-recipe.yml | 182 ++++--------------
.github/workflows/test.yml | 27 ++-
.github/workflows/train-librispeech.yml | 33 +++-
8 files changed, 279 insertions(+), 183 deletions(-)
create mode 100755 .github/scripts/docker/generate_build_matrix.py
rename .github/scripts/{docker => librispeech/ASR}/run.sh (87%)
create mode 100755 .github/scripts/yesno/ASR/run.sh
diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile
index bbf978d26..f75d74854 100644
--- a/.github/scripts/docker/Dockerfile
+++ b/.github/scripts/docker/Dockerfile
@@ -31,10 +31,12 @@ LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN pip install --no-cache-dir \
torch==${TORCH_VERSION} torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/cpu/torch_stable.html \
k2==${_K2_VERSION} -f https://k2-fsa.github.io/k2/cpu.html \
+ \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${_KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cpu.html \
dill \
graphviz \
+ kaldi-decoder \
kaldi_native_io \
kaldialign \
kaldifst \
diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py
new file mode 100755
index 000000000..4e494d810
--- /dev/null
+++ b/.github/scripts/docker/generate_build_matrix.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+
+
+import json
+
+
+def version_gt(a, b):
+    a_major, a_minor = list(map(int, a.split(".")))[:2]
+    b_major, b_minor = list(map(int, b.split(".")))[:2]
+ if a_major > b_major:
+ return True
+
+ if a_major == b_major and a_minor > b_minor:
+ return True
+
+ return False
+
+
+def version_ge(a, b):
+    a_major, a_minor = list(map(int, a.split(".")))[:2]
+    b_major, b_minor = list(map(int, b.split(".")))[:2]
+ if a_major > b_major:
+ return True
+
+ if a_major == b_major and a_minor >= b_minor:
+ return True
+
+ return False
+
+
+def get_torchaudio_version(torch_version):
+ if torch_version == "1.13.0":
+ return "0.13.0"
+ elif torch_version == "1.13.1":
+ return "0.13.1"
+ elif torch_version == "2.0.0":
+ return "2.0.1"
+ elif torch_version == "2.0.1":
+ return "2.0.2"
+ else:
+ return torch_version
+
+
+def get_matrix():
+ k2_version = "1.24.4.dev20231220"
+ kaldifeat_version = "1.25.3.dev20231221"
+ version = "1.1"
+ python_version = ["3.8", "3.9", "3.10", "3.11"]
+ torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
+
+ matrix = []
+ for p in python_version:
+ for t in torch_version:
+ # torchaudio <= 1.13.x supports only python <= 3.10
+
+ if version_gt(p, "3.10") and not version_gt(t, "2.0"):
+ continue
+
+ matrix.append(
+ {
+ "k2-version": k2_version,
+ "kaldifeat-version": kaldifeat_version,
+ "version": version,
+ "python-version": p,
+ "torch-version": t,
+ "torchaudio-version": get_torchaudio_version(t),
+ }
+ )
+ return matrix
+
+
+def main():
+ matrix = get_matrix()
+ print(json.dumps({"include": matrix}))
+
+
+if __name__ == "__main__":
+ main()
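
A small illustrative snippet (not part of the patch; it assumes you run it from the repository root) that executes the generator above and inspects the matrix the workflows consume via fromJson():

    import json
    import subprocess

    out = subprocess.check_output(
        ["python3", ".github/scripts/docker/generate_build_matrix.py"]
    )
    matrix = json.loads(out)["include"]
    print(len(matrix), "python/torch combinations")
    # Each entry carries the keys built in get_matrix(): python-version,
    # torch-version, torchaudio-version, k2-version, kaldifeat-version, version.
    print(matrix[0])
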
diff --git a/.github/scripts/docker/run.sh b/.github/scripts/librispeech/ASR/run.sh
similarity index 87%
rename from .github/scripts/docker/run.sh
rename to .github/scripts/librispeech/ASR/run.sh
index aeb80b330..641d59458 100755
--- a/.github/scripts/docker/run.sh
+++ b/.github/scripts/librispeech/ASR/run.sh
@@ -1,11 +1,12 @@
#!/usr/bin/env bash
+
set -ex
-cd /icefall
-export PYTHONPATH=/icefall:$PYTHONPATH
-python3 -c "import torch; print(torch.__file__)"
-python3 -c "import torchaudio; print(torchaudio.__version__)"
-python3 -c "import icefall; print(icefall.__file__)"
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
cd egs/librispeech/ASR
diff --git a/.github/scripts/yesno/ASR/run.sh b/.github/scripts/yesno/ASR/run.sh
new file mode 100755
index 000000000..05c8fbac9
--- /dev/null
+++ b/.github/scripts/yesno/ASR/run.sh
@@ -0,0 +1,86 @@
+#!/usr/bin/env bash
+
+set -ex
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/yesno/ASR
+
+log "data preparation"
+./prepare.sh
+
+log "training"
+python3 ./tdnn/train.py
+
+log "decoding"
+python3 ./tdnn/decode.py
+
+log "export to pretrained.pt"
+
+python3 ./tdnn/export.py --epoch 14 --avg 2
+
+python3 ./tdnn/pretrained.py \
+ --checkpoint ./tdnn/exp/pretrained.pt \
+ --HLG ./data/lang_phone/HLG.pt \
+ --words-file ./data/lang_phone/words.txt \
+ download/waves_yesno/0_0_0_1_0_0_0_1.wav \
+ download/waves_yesno/0_0_1_0_0_0_1_0.wav
+
+log "Test exporting to torchscript"
+python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1
+
+python3 ./tdnn/jit_pretrained.py \
+ --nn-model ./tdnn/exp/cpu_jit.pt \
+ --HLG ./data/lang_phone/HLG.pt \
+ --words-file ./data/lang_phone/words.txt \
+ download/waves_yesno/0_0_0_1_0_0_0_1.wav \
+ download/waves_yesno/0_0_1_0_0_0_1_0.wav
+
+log "Test exporting to onnx"
+python3 ./tdnn/export_onnx.py --epoch 14 --avg 2
+
+log "Test float32 model"
+python3 ./tdnn/onnx_pretrained.py \
+ --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \
+ --HLG ./data/lang_phone/HLG.pt \
+ --words-file ./data/lang_phone/words.txt \
+ download/waves_yesno/0_0_0_1_0_0_0_1.wav \
+ download/waves_yesno/0_0_1_0_0_0_1_0.wav
+
+log "Test int8 model"
+python3 ./tdnn/onnx_pretrained.py \
+ --nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \
+ --HLG ./data/lang_phone/HLG.pt \
+ --words-file ./data/lang_phone/words.txt \
+ download/waves_yesno/0_0_0_1_0_0_0_1.wav \
+ download/waves_yesno/0_0_1_0_0_0_1_0.wav
+
+log "Test decoding with H"
+python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1
+
+python3 ./tdnn/jit_pretrained_decode_with_H.py \
+ --nn-model ./tdnn/exp/cpu_jit.pt \
+ --H ./data/lang_phone/H.fst \
+ --tokens ./data/lang_phone/tokens.txt \
+ ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \
+ ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \
+ ./download/waves_yesno/0_0_1_0_0_1_1_1.wav
+
+log "Test decoding with HL"
+python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1
+
+python3 ./tdnn/jit_pretrained_decode_with_HL.py \
+ --nn-model ./tdnn/exp/cpu_jit.pt \
+ --HL ./data/lang_phone/HL.fst \
+ --words ./data/lang_phone/words.txt \
+ ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \
+ ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \
+ ./download/waves_yesno/0_0_1_0_0_1_1_1.wav
+
+log "Show generated files"
+ls -lh tdnn/exp
+ls -lh data/lang_phone
diff --git a/.github/workflows/build-cpu-docker.yml b/.github/workflows/build-cpu-docker.yml
index b26cd2095..c5d5aaeb6 100644
--- a/.github/workflows/build-cpu-docker.yml
+++ b/.github/workflows/build-cpu-docker.yml
@@ -7,18 +7,31 @@ concurrency:
cancel-in-progress: true
jobs:
+ generate_build_matrix:
+ if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
+ # see https://github.com/pytorch/pytorch/pull/50633
+ runs-on: ubuntu-latest
+ outputs:
+ matrix: ${{ steps.set-matrix.outputs.matrix }}
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ - name: Generating build matrix
+ id: set-matrix
+ run: |
+ # outputting for debugging purposes
+ python ./.github/scripts/docker/generate_build_matrix.py
+ MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+ echo "::set-output name=matrix::${MATRIX}"
build-cpu-docker:
+ needs: generate_build_matrix
name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
- runs-on: ${{ matrix.os }}
+ runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
- os: [ubuntu-latest]
- python-version: ["3.8", "3.9", "3.10"]
- torch-version: ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
- k2-version: ["1.24.4.dev20231220"]
- kaldifeat-version: ["1.25.3.dev20231221"]
- version: ["1.1"]
+ ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
# refer to https://github.com/actions/checkout
@@ -45,25 +58,14 @@ jobs:
run: |
cd .github/scripts/docker
torch_version=${{ matrix.torch-version }}
+ torchaudio_version=${{ matrix.torchaudio-version }}
- # see https://pytorch.org/audio/stable/installation.html#compatibility-matrix
- if [[ $torch_version == 1.13.0 ]]; then
- torchaudio_version=0.13.0
- elif [[ $torch_version == 1.13.1 ]]; then
- torchaudio_version=0.13.1
- elif [[ $torch_version == 2.0.0 ]]; then
- torchaudio_version=2.0.1
- elif [[ $torch_version == 2.0.1 ]]; then
- torchaudio_version=2.0.2
- else
- torchaudio_version=$torch_version
- fi
echo "torch_version: $torch_version"
echo "torchaudio_version: $torchaudio_version"
version=${{ matrix.version }}
- tag=ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v$version
+ tag=ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v$version
echo "tag: $tag"
docker build \
diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml
index 9ac848535..a99811815 100644
--- a/.github/workflows/run-yesno-recipe.yml
+++ b/.github/workflows/run-yesno-recipe.yml
@@ -20,166 +20,60 @@ on:
push:
branches:
- master
+ - refactor-ci
+
pull_request:
branches:
- master
+ workflow_dispatch:
+
concurrency:
group: run-yesno-recipe-${{ github.ref }}
cancel-in-progress: true
jobs:
+ generate_build_matrix:
+ if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
+ # see https://github.com/pytorch/pytorch/pull/50633
+ runs-on: ubuntu-latest
+ outputs:
+ matrix: ${{ steps.set-matrix.outputs.matrix }}
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ - name: Generating build matrix
+ id: set-matrix
+ run: |
+ # outputting for debugging purposes
+ python ./.github/scripts/docker/generate_build_matrix.py
+ MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+ echo "::set-output name=matrix::${MATRIX}"
run-yesno-recipe:
- runs-on: ${{ matrix.os }}
+ needs: generate_build_matrix
+ name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
+ runs-on: ubuntu-latest
strategy:
- matrix:
- # os: [ubuntu-latest, macos-10.15]
- # TODO: enable macOS for CPU testing
- os: [ubuntu-latest]
- python-version: [3.8]
fail-fast: false
+ matrix:
+ ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v4
with:
fetch-depth: 0
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
+ - name: Run the yesno recipe
+ uses: addnab/docker-run-action@v3
with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
+ image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
+ options: |
+ --volume ${{ github.workspace }}/:/icefall
+ shell: bash
+ run: |
+ export PYTHONPATH=/icefall:$PYTHONPATH
+ cd /icefall
+ git config --global --add safe.directory /icefall
- - name: Install libnsdfile and libsox
- if: startsWith(matrix.os, 'ubuntu')
- run: |
- sudo apt update
- sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg
- sudo apt install -q -y --fix-missing sox libsox-dev libsox-fmt-all
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- pip install --no-deps --force-reinstall k2==1.24.4.dev20231021+cpu.torch1.13.1 -f https://k2-fsa.github.io/k2/cpu.html
- pip install kaldifeat==1.25.1.dev20231022+cpu.torch1.13.1 -f https://csukuangfj.github.io/kaldifeat/cpu.html
-
- - name: Run yesno recipe
- shell: bash
- working-directory: ${{github.workspace}}
- run: |
- export PYTHONPATH=$PWD:$PYTHONPATH
- echo $PYTHONPATH
-
- cd egs/yesno/ASR
- ./prepare.sh
- python3 ./tdnn/train.py
- python3 ./tdnn/decode.py
-
- - name: Test exporting to pretrained.pt
- shell: bash
- working-directory: ${{github.workspace}}
- run: |
- export PYTHONPATH=$PWD:$PYTHONPATH
- echo $PYTHONPATH
-
- cd egs/yesno/ASR
- python3 ./tdnn/export.py --epoch 14 --avg 2
-
- python3 ./tdnn/pretrained.py \
- --checkpoint ./tdnn/exp/pretrained.pt \
- --HLG ./data/lang_phone/HLG.pt \
- --words-file ./data/lang_phone/words.txt \
- download/waves_yesno/0_0_0_1_0_0_0_1.wav \
- download/waves_yesno/0_0_1_0_0_0_1_0.wav
-
- - name: Test exporting to torchscript
- shell: bash
- working-directory: ${{github.workspace}}
- run: |
- export PYTHONPATH=$PWD:$PYTHONPATH
- echo $PYTHONPATH
-
- cd egs/yesno/ASR
- python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1
-
- python3 ./tdnn/jit_pretrained.py \
- --nn-model ./tdnn/exp/cpu_jit.pt \
- --HLG ./data/lang_phone/HLG.pt \
- --words-file ./data/lang_phone/words.txt \
- download/waves_yesno/0_0_0_1_0_0_0_1.wav \
- download/waves_yesno/0_0_1_0_0_0_1_0.wav
-
- - name: Test exporting to onnx
- shell: bash
- working-directory: ${{github.workspace}}
- run: |
- export PYTHONPATH=$PWD:$PYTHONPATH
- echo $PYTHONPATH
-
- cd egs/yesno/ASR
- python3 ./tdnn/export_onnx.py --epoch 14 --avg 2
-
- echo "Test float32 model"
- python3 ./tdnn/onnx_pretrained.py \
- --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \
- --HLG ./data/lang_phone/HLG.pt \
- --words-file ./data/lang_phone/words.txt \
- download/waves_yesno/0_0_0_1_0_0_0_1.wav \
- download/waves_yesno/0_0_1_0_0_0_1_0.wav
-
-
- echo "Test int8 model"
- python3 ./tdnn/onnx_pretrained.py \
- --nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \
- --HLG ./data/lang_phone/HLG.pt \
- --words-file ./data/lang_phone/words.txt \
- download/waves_yesno/0_0_0_1_0_0_0_1.wav \
- download/waves_yesno/0_0_1_0_0_0_1_0.wav
-
- - name: Test decoding with H
- shell: bash
- working-directory: ${{github.workspace}}
- run: |
- export PYTHONPATH=$PWD:$PYTHONPATH
- echo $PYTHONPATH
-
- cd egs/yesno/ASR
- python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1
-
- python3 ./tdnn/jit_pretrained_decode_with_H.py \
- --nn-model ./tdnn/exp/cpu_jit.pt \
- --H ./data/lang_phone/H.fst \
- --tokens ./data/lang_phone/tokens.txt \
- ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \
- ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \
- ./download/waves_yesno/0_0_1_0_0_1_1_1.wav
-
- - name: Test decoding with HL
- shell: bash
- working-directory: ${{github.workspace}}
- run: |
- export PYTHONPATH=$PWD:$PYTHONPATH
- echo $PYTHONPATH
-
- cd egs/yesno/ASR
- python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1
-
- python3 ./tdnn/jit_pretrained_decode_with_HL.py \
- --nn-model ./tdnn/exp/cpu_jit.pt \
- --HL ./data/lang_phone/HL.fst \
- --words ./data/lang_phone/words.txt \
- ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \
- ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \
- ./download/waves_yesno/0_0_1_0_0_1_1_1.wav
-
- - name: Show generated files
- shell: bash
- working-directory: ${{github.workspace}}
- run: |
- cd egs/yesno/ASR
- ls -lh tdnn/exp
- ls -lh data/lang_phone
+ .github/scripts/yesno/ASR/run.sh
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index b3fd6f133..659681b37 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -16,16 +16,31 @@ concurrency:
cancel-in-progress: true
jobs:
+ generate_build_matrix:
+ if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
+ # see https://github.com/pytorch/pytorch/pull/50633
+ runs-on: ubuntu-latest
+ outputs:
+ matrix: ${{ steps.set-matrix.outputs.matrix }}
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ - name: Generating build matrix
+ id: set-matrix
+ run: |
+ # outputting for debugging purposes
+ python ./.github/scripts/docker/generate_build_matrix.py
+ MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+ echo "::set-output name=matrix::${MATRIX}"
test:
+ needs: generate_build_matrix
name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
- runs-on: ${{ matrix.os }}
+ runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
- os: [ubuntu-latest]
- python-version: ["3.8", "3.9", "3.10"]
- torch-version: ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
- version: ["1.1"]
+ ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
@@ -44,7 +59,7 @@ jobs:
- name: Run tests
uses: addnab/docker-run-action@v3
with:
- image: ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
+ image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
options: |
--volume ${{ github.workspace }}/:/icefall
shell: bash
diff --git a/.github/workflows/train-librispeech.yml b/.github/workflows/train-librispeech.yml
index 53a2d5843..79002a881 100644
--- a/.github/workflows/train-librispeech.yml
+++ b/.github/workflows/train-librispeech.yml
@@ -15,16 +15,31 @@ concurrency:
cancel-in-progress: true
jobs:
+ generate_build_matrix:
+ if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
+ # see https://github.com/pytorch/pytorch/pull/50633
+ runs-on: ubuntu-latest
+ outputs:
+ matrix: ${{ steps.set-matrix.outputs.matrix }}
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ - name: Generating build matrix
+ id: set-matrix
+ run: |
+ # outputting for debugging purposes
+ python ./.github/scripts/docker/generate_build_matrix.py
+ MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+ echo "::set-output name=matrix::${MATRIX}"
train-librispeech:
+ needs: generate_build_matrix
name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
- runs-on: ${{ matrix.os }}
+ runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
- os: [ubuntu-latest]
- python-version: ["3.8", "3.9", "3.10"]
- torch-version: ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
- version: ["1.1"]
+ ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
# refer to https://github.com/actions/checkout
@@ -44,11 +59,13 @@ jobs:
- name: Test zipformer/train.py with LibriSpeech
uses: addnab/docker-run-action@v3
with:
- image: ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
+ image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
options: |
--volume ${{ github.workspace }}/:/icefall
shell: bash
run: |
- ls -lh /icefall
+ export PYTHONPATH=/icefall:$PYTHONPATH
+ cd /icefall
+ git config --global --add safe.directory /icefall
- /icefall/.github/scripts/docker/run.sh
+ .github/scripts/librispeech/ASR/run.sh
From ddd71313179a1565ea4d9e2e37546c3ef6b98d90 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ali=20Haznedaro=C4=9Flu?=
<53865510+ahazned@users.noreply.github.com>
Date: Mon, 25 Dec 2023 14:44:07 +0300
Subject: [PATCH 025/123] Update TTS export-onnx.py scripts for handling
variable token counts (#1430)
---
egs/ljspeech/TTS/vits/export-onnx.py | 6 +++++-
egs/vctk/TTS/vits/export-onnx.py | 6 +++++-
2 files changed, 10 insertions(+), 2 deletions(-)
diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py
index 36a9de27f..f82f9dbe9 100755
--- a/egs/ljspeech/TTS/vits/export-onnx.py
+++ b/egs/ljspeech/TTS/vits/export-onnx.py
@@ -149,6 +149,7 @@ class OnnxModel(nn.Module):
def export_model_onnx(
model: nn.Module,
model_filename: str,
+ vocab_size: int,
opset_version: int = 11,
) -> None:
"""Export the given generator model to ONNX format.
@@ -165,10 +166,12 @@ def export_model_onnx(
The VITS generator.
model_filename:
The filename to save the exported ONNX model.
+ vocab_size:
+ Number of tokens used in training.
opset_version:
The opset version to use.
"""
- tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
+ tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
noise_scale = torch.tensor([1], dtype=torch.float32)
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
@@ -244,6 +247,7 @@ def main():
export_model_onnx(
model,
model_filename,
+ params.vocab_size,
opset_version=opset_version,
)
logging.info(f"Exported generator to {model_filename}")
diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py
index 667ac284b..80d155626 100755
--- a/egs/vctk/TTS/vits/export-onnx.py
+++ b/egs/vctk/TTS/vits/export-onnx.py
@@ -159,6 +159,7 @@ class OnnxModel(nn.Module):
def export_model_onnx(
model: nn.Module,
model_filename: str,
+ vocab_size: int,
opset_version: int = 11,
) -> None:
"""Export the given generator model to ONNX format.
@@ -175,10 +176,12 @@ def export_model_onnx(
The VITS generator.
model_filename:
The filename to save the exported ONNX model.
+ vocab_size:
+ Number of tokens used in training.
opset_version:
The opset version to use.
"""
- tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
+ tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
noise_scale = torch.tensor([1], dtype=torch.float32)
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
@@ -261,6 +264,7 @@ def main():
export_model_onnx(
model,
model_filename,
+ params.vocab_size,
opset_version=opset_version,
)
logging.info(f"Exported generator to {model_filename}")
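
A tiny illustrative check (the vocab_size value is hypothetical) of why the change matters: hard-coding high=79 ties the traced dummy input to one particular token set, while sampling up to the model's actual vocab_size keeps the dummy ids inside the embedding table for any token count.

    import torch

    vocab_size = 178  # hypothetical token count; any trained model's value works
    tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
    assert int(tokens.max()) < vocab_size  # dummy ids never exceed the embedding table
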
From 835a92eba51a939c6b4a069a53cc1e3ddeabd9a5 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Mon, 25 Dec 2023 20:23:56 +0800
Subject: [PATCH 026/123] Add doc about how to use the CPU-only docker images
(#1432)
---
docs/source/docker/intro.rst | 46 ++++++++++++++++++++++++++++++++----
1 file changed, 41 insertions(+), 5 deletions(-)
diff --git a/docs/source/docker/intro.rst b/docs/source/docker/intro.rst
index 9ead0df00..cbd300d9b 100644
--- a/docs/source/docker/intro.rst
+++ b/docs/source/docker/intro.rst
@@ -20,7 +20,11 @@ We describe the following items in this section:
View available tags
===================
-You can use the following command to view available tags:
+CUDA-enabled docker images
+--------------------------
+
+You can use the following command to view available tags for CUDA-enabled
+docker images:
.. code-block:: bash
@@ -43,8 +47,25 @@ which will give you something like below:
Please select an appropriate combination of `torch`_ and CUDA.
-Download a docker image
-=======================
+CPU-only docker images
+----------------------
+
+To view CPU-only docker images, please visit `<https://github.com/k2-fsa/icefall/pkgs/container/icefall>`_
+for available tags.
+
+You can select different combinations of ``Python`` and ``torch``. For instance,
+to select ``Python 3.8`` and ``torch 2.1.2``, you can use the following tag
+
+.. code-block:: bash
+
+ cpu-py3.8-torch2.1.2-v1.1
+
+where ``v1.1`` is the current version of the docker image. You may see
+``ghcr.io/k2-fsa/icefall:cpu-py3.8-torch2.1.2-v1.2`` or some other versions.
+We recommend that you always use the latest version.
+
+Download a docker image (CUDA)
+==============================
Suppose that you select the tag ``torch1.13.0-cuda11.6``, you can use
the following command to download it:
@@ -53,6 +74,16 @@ the following command to download it:
sudo docker image pull k2fsa/icefall:torch1.13.0-cuda11.6
+Download a docker image (CPU)
+==============================
+
+Suppose that you select the tag ``cpu-py3.8-torch2.1.2-v1.1``, you can use
+the following command to download it:
+
+.. code-block:: bash
+
+ sudo docker pull ghcr.io/k2-fsa/icefall:cpu-py3.8-torch2.1.2-v1.1
+
Run a docker image with GPU
===========================
@@ -65,7 +96,7 @@ Run a docker image with CPU
.. code-block:: bash
- sudo docker run --rm -it k2fsa/icefall:torch1.13.0-cuda11.6 /bin/bash
+ sudo docker run --rm -it ghcr.io/k2-fsa/icefall:cpu-py3.8-torch2.1.2-v1.1 /bin/bash
Run yesno within a docker container
===================================
@@ -74,8 +105,13 @@ After starting the container, the following interface is presented:
.. code-block:: bash
+ # GPU-enabled docker
root@60c947eac59c:/workspace/icefall#
+ # CPU-only docker
+   root@60c947eac59c:# mkdir /workspace; cd /workspace; git clone https://github.com/k2-fsa/icefall
+ root@60c947eac59c:# export PYTHONPATH=/workspace/icefall:$PYTHONPATH
+
It shows the current user is ``root`` and the current working directory
is ``/workspace/icefall``.
@@ -107,7 +143,7 @@ to switch to the ``yesno`` recipe and run
.. hint::
- If you are running without GPU, it may report the following error:
+  If you are running a GPU-enabled docker image on a machine without a GPU, it may report the following error:
.. code-block:: bash
From db52fe2349df0e07e931accb0cf1e63fec389fb7 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 26 Dec 2023 20:29:43 +0800
Subject: [PATCH 027/123] Refactor CI test for aishell (#1435)
---
.github/scripts/aishell/ASR/run.sh | 274 ++++++++++++++++++
.github/scripts/docker/Dockerfile | 1 +
.../scripts/docker/generate_build_matrix.py | 2 +-
...pruned-transducer-stateless3-2022-06-20.sh | 87 ------
.../run-aishell-zipformer-2023-10-24.sh | 103 -------
...transducer-stateless-modified-2-aishell.sh | 48 ---
...d-transducer-stateless-modified-aishell.sh | 48 ---
.github/workflows/aishell.yml | 81 ++++++
.github/workflows/run-aishell-2022-06-20.yml | 123 --------
.../run-aishell-zipformer-2023-10-24.yml | 95 ------
...ransducer-stateless-modified-2-aishell.yml | 80 -----
...-transducer-stateless-modified-aishell.yml | 80 -----
.../{run-yesno-recipe.yml => yesno.yml} | 23 +-
13 files changed, 360 insertions(+), 685 deletions(-)
create mode 100755 .github/scripts/aishell/ASR/run.sh
delete mode 100755 .github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh
delete mode 100755 .github/scripts/run-aishell-zipformer-2023-10-24.sh
delete mode 100755 .github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh
delete mode 100755 .github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh
create mode 100644 .github/workflows/aishell.yml
delete mode 100644 .github/workflows/run-aishell-2022-06-20.yml
delete mode 100644 .github/workflows/run-aishell-zipformer-2023-10-24.yml
delete mode 100644 .github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml
delete mode 100644 .github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml
rename .github/workflows/{run-yesno-recipe.yml => yesno.yml} (69%)
diff --git a/.github/scripts/aishell/ASR/run.sh b/.github/scripts/aishell/ASR/run.sh
new file mode 100755
index 000000000..4d912fa76
--- /dev/null
+++ b/.github/scripts/aishell/ASR/run.sh
@@ -0,0 +1,274 @@
+#!/usr/bin/env bash
+
+set -ex
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/aishell/ASR
+
+function download_test_dev_manifests() {
+ git lfs install
+
+ fbank_url=https://huggingface.co/csukuangfj/aishell-test-dev-manifests
+ log "Downloading pre-commputed fbank from $fbank_url"
+
+ git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests
+ ln -s $PWD/aishell-test-dev-manifests/data .
+}
+
+function test_transducer_stateless3_2022_06_20() {
+ repo_url=https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20
+ log "Downloading pre-trained model from $repo_url"
+ git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ pushd $repo/exp
+ ln -s pretrained-epoch-29-avg-5-torch-1.10.0.pt pretrained.pt
+ popd
+
+ log "test greedy_search with pretrained.py"
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless3/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --lang-dir $repo/data/lang_char \
+ $repo/test_wavs/BAC009S0764W0121.wav \
+ $repo/test_wavs/BAC009S0764W0122.wav \
+ $repo/test_wavs/BAC009S0764W0123.wav
+ done
+
+ log "test beam search with pretrained.py"
+
+ for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless3/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --lang-dir $repo/data/lang_char \
+ $repo/test_wavs/BAC009S0764W0121.wav \
+ $repo/test_wavs/BAC009S0764W0122.wav \
+ $repo/test_wavs/BAC009S0764W0123.wav
+ done
+
+ echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
+ echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
+ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
+ mkdir -p pruned_transducer_stateless3/exp
+ ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt
+ ln -s $PWD/$repo/data/lang_char data/
+
+ ls -lh data
+ ls -lh pruned_transducer_stateless3/exp
+
+ log "Decoding test and dev"
+
+ # use a small value for decoding with CPU
+ max_duration=100
+
+ for method in greedy_search fast_beam_search modified_beam_search; do
+ log "Decoding with $method"
+
+ ./pruned_transducer_stateless3/decode.py \
+ --decoding-method $method \
+ --epoch 999 \
+ --avg 1 \
+ --max-duration $max_duration \
+ --exp-dir pruned_transducer_stateless3/exp
+ done
+
+ rm pruned_transducer_stateless3/exp/*.pt
+ fi
+
+ rm -rf $repo
+}
+
+function test_zipformer_large_2023_10_24() {
+ log "CI testing large model"
+ repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-large-2023-10-24/
+ log "Downloading pre-trained model from $repo_url"
+ git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ for method in modified_beam_search greedy_search fast_beam_search; do
+ log "$method"
+
+ ./zipformer/pretrained.py \
+ --method $method \
+ --context-size 1 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_char/tokens.txt \
+ --num-encoder-layers 2,2,4,5,4,2 \
+ --feedforward-dim 512,768,1536,2048,1536,768 \
+ --encoder-dim 192,256,512,768,512,256 \
+ --encoder-unmasked-dim 192,192,256,320,256,192 \
+ $repo/test_wavs/BAC009S0764W0121.wav \
+ $repo/test_wavs/BAC009S0764W0122.wav \
+ $repo/test_wavs/BAC009S0764W0123.wav
+ done
+ rm -rf $repo
+}
+
+function test_zipformer_2023_10_24() {
+ repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-2023-10-24/
+ log "Downloading pre-trained model from $repo_url"
+ git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+
+ for method in modified_beam_search greedy_search fast_beam_search; do
+ log "$method"
+
+ ./zipformer/pretrained.py \
+ --method $method \
+ --context-size 1 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_char/tokens.txt \
+ $repo/test_wavs/BAC009S0764W0121.wav \
+ $repo/test_wavs/BAC009S0764W0122.wav \
+ $repo/test_wavs/BAC009S0764W0123.wav
+ done
+ rm -rf $repo
+}
+
+function test_zipformer_small_2023_10_24() {
+ log "CI testing small model"
+ repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-small-2023-10-24/
+ log "Downloading pre-trained model from $repo_url"
+ git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+
+ for method in modified_beam_search greedy_search fast_beam_search; do
+ log "$method"
+
+ ./zipformer/pretrained.py \
+ --method $method \
+ --context-size 1 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_char/tokens.txt \
+ --num-encoder-layers 2,2,2,2,2,2 \
+ --feedforward-dim 512,768,768,768,768,768 \
+ --encoder-dim 192,256,256,256,256,256 \
+ --encoder-unmasked-dim 192,192,192,192,192,192 \
+ $repo/test_wavs/BAC009S0764W0121.wav \
+ $repo/test_wavs/BAC009S0764W0122.wav \
+ $repo/test_wavs/BAC009S0764W0123.wav
+ done
+ rm -rf $repo
+}
+
+function test_transducer_stateless_modified_2022_03_01() {
+ repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./transducer_stateless_modified/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --lang-dir $repo/data/lang_char \
+ $repo/test_wavs/BAC009S0764W0121.wav \
+ $repo/test_wavs/BAC009S0764W0122.wav \
+ $repo/test_wavs/BAC009S0764W0123.wav
+ done
+
+ for method in modified_beam_search beam_search; do
+ log "$method"
+
+ ./transducer_stateless_modified/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --lang-dir $repo/data/lang_char \
+ $repo/test_wavs/BAC009S0764W0121.wav \
+ $repo/test_wavs/BAC009S0764W0122.wav \
+ $repo/test_wavs/BAC009S0764W0123.wav
+ done
+ rm -rf $repo
+}
+
+function test_transducer_stateless_modified_2_2022_03_01() {
+ repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2-2022-03-01
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./transducer_stateless_modified-2/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --lang-dir $repo/data/lang_char \
+ $repo/test_wavs/BAC009S0764W0121.wav \
+ $repo/test_wavs/BAC009S0764W0122.wav \
+ $repo/test_wavs/BAC009S0764W0123.wav
+ done
+
+ for method in modified_beam_search beam_search; do
+ log "$method"
+
+ ./transducer_stateless_modified-2/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --lang-dir $repo/data/lang_char \
+ $repo/test_wavs/BAC009S0764W0121.wav \
+ $repo/test_wavs/BAC009S0764W0122.wav \
+ $repo/test_wavs/BAC009S0764W0123.wav
+ done
+ rm -rf $repo
+}
+
+download_test_dev_manifests
+test_transducer_stateless3_2022_06_20
+test_zipformer_large_2023_10_24
+test_zipformer_2023_10_24
+test_zipformer_small_2023_10_24
+test_transducer_stateless_modified_2022_03_01
+test_transducer_stateless_modified_2_2022_03_01
+
+ls -lh
diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile
index f75d74854..f6a088af1 100644
--- a/.github/scripts/docker/Dockerfile
+++ b/.github/scripts/docker/Dockerfile
@@ -16,6 +16,7 @@ RUN apt-get update -y && \
git-lfs \
graphviz \
less \
+ tree \
vim \
&& \
apt-get clean && \
diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py
index 4e494d810..bdde97647 100755
--- a/.github/scripts/docker/generate_build_matrix.py
+++ b/.github/scripts/docker/generate_build_matrix.py
@@ -45,7 +45,7 @@ def get_torchaudio_version(torch_version):
def get_matrix():
k2_version = "1.24.4.dev20231220"
kaldifeat_version = "1.25.3.dev20231221"
- version = "1.1"
+ version = "1.2"
python_version = ["3.8", "3.9", "3.10", "3.11"]
torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
diff --git a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh
deleted file mode 100755
index c3640cfde..000000000
--- a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh
+++ /dev/null
@@ -1,87 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/aishell/ASR
-
-git lfs install
-
-fbank_url=https://huggingface.co/csukuangfj/aishell-test-dev-manifests
-log "Downloading pre-commputed fbank from $fbank_url"
-
-git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests
-ln -s $PWD/aishell-test-dev-manifests/data .
-
-repo_url=https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20
-log "Downloading pre-trained model from $repo_url"
-git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-pushd $repo/exp
-ln -s pretrained-epoch-29-avg-5-torch-1.10.0.pt pretrained.pt
-popd
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./pruned_transducer_stateless3/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --lang-dir $repo/data/lang_char \
- $repo/test_wavs/BAC009S0764W0121.wav \
- $repo/test_wavs/BAC009S0764W0122.wav \
- $repo/test_wavs/BAC009S0764W0123.wav
-done
-
-for method in modified_beam_search beam_search fast_beam_search; do
- log "$method"
-
- ./pruned_transducer_stateless3/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --lang-dir $repo/data/lang_char \
- $repo/test_wavs/BAC009S0764W0121.wav \
- $repo/test_wavs/BAC009S0764W0122.wav \
- $repo/test_wavs/BAC009S0764W0123.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p pruned_transducer_stateless3/exp
- ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_char data/
-
- ls -lh data
- ls -lh pruned_transducer_stateless3/exp
-
- log "Decoding test and dev"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Decoding with $method"
-
- ./pruned_transducer_stateless3/decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --max-duration $max_duration \
- --exp-dir pruned_transducer_stateless3/exp
- done
-
- rm pruned_transducer_stateless3/exp/*.pt
-fi
diff --git a/.github/scripts/run-aishell-zipformer-2023-10-24.sh b/.github/scripts/run-aishell-zipformer-2023-10-24.sh
deleted file mode 100755
index 865e29799..000000000
--- a/.github/scripts/run-aishell-zipformer-2023-10-24.sh
+++ /dev/null
@@ -1,103 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/aishell/ASR
-
-git lfs install
-
-fbank_url=https://huggingface.co/csukuangfj/aishell-test-dev-manifests
-log "Downloading pre-commputed fbank from $fbank_url"
-
-git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests
-ln -s $PWD/aishell-test-dev-manifests/data .
-
-log "======================="
-log "CI testing large model"
-repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-large-2023-10-24/
-log "Downloading pre-trained model from $repo_url"
-git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-for method in modified_beam_search greedy_search fast_beam_search; do
- log "$method"
-
- ./zipformer/pretrained.py \
- --method $method \
- --context-size 1 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_char/tokens.txt \
- --num-encoder-layers 2,2,4,5,4,2 \
- --feedforward-dim 512,768,1536,2048,1536,768 \
- --encoder-dim 192,256,512,768,512,256 \
- --encoder-unmasked-dim 192,192,256,320,256,192 \
- $repo/test_wavs/BAC009S0764W0121.wav \
- $repo/test_wavs/BAC009S0764W0122.wav \
- $repo/test_wavs/BAC009S0764W0123.wav
-done
-
-log "======================="
-log "CI testing medium model"
-repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-2023-10-24/
-log "Downloading pre-trained model from $repo_url"
-git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-
-for method in modified_beam_search greedy_search fast_beam_search; do
- log "$method"
-
- ./zipformer/pretrained.py \
- --method $method \
- --context-size 1 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_char/tokens.txt \
- $repo/test_wavs/BAC009S0764W0121.wav \
- $repo/test_wavs/BAC009S0764W0122.wav \
- $repo/test_wavs/BAC009S0764W0123.wav
-done
-
-
-log "======================="
-log "CI testing small model"
-repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-small-2023-10-24/
-log "Downloading pre-trained model from $repo_url"
-git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-
-for method in modified_beam_search greedy_search fast_beam_search; do
- log "$method"
-
- ./zipformer/pretrained.py \
- --method $method \
- --context-size 1 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_char/tokens.txt \
- --num-encoder-layers 2,2,2,2,2,2 \
- --feedforward-dim 512,768,768,768,768,768 \
- --encoder-dim 192,256,256,256,256,256 \
- --encoder-unmasked-dim 192,192,192,192,192,192 \
- $repo/test_wavs/BAC009S0764W0121.wav \
- $repo/test_wavs/BAC009S0764W0122.wav \
- $repo/test_wavs/BAC009S0764W0123.wav
-done
-
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh
deleted file mode 100755
index 0644d9be0..000000000
--- a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh
+++ /dev/null
@@ -1,48 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/aishell/ASR
-
-repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2-2022-03-01
-
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./transducer_stateless_modified-2/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --lang-dir $repo/data/lang_char \
- $repo/test_wavs/BAC009S0764W0121.wav \
- $repo/test_wavs/BAC009S0764W0122.wav \
- $repo/test_wavs/BAC009S0764W0123.wav
-done
-
-for method in modified_beam_search beam_search; do
- log "$method"
-
- ./transducer_stateless_modified-2/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --lang-dir $repo/data/lang_char \
- $repo/test_wavs/BAC009S0764W0121.wav \
- $repo/test_wavs/BAC009S0764W0122.wav \
- $repo/test_wavs/BAC009S0764W0123.wav
-done
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh
deleted file mode 100755
index 79fb64311..000000000
--- a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh
+++ /dev/null
@@ -1,48 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/aishell/ASR
-
-repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01
-
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./transducer_stateless_modified/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --lang-dir $repo/data/lang_char \
- $repo/test_wavs/BAC009S0764W0121.wav \
- $repo/test_wavs/BAC009S0764W0122.wav \
- $repo/test_wavs/BAC009S0764W0123.wav
-done
-
-for method in modified_beam_search beam_search; do
- log "$method"
-
- ./transducer_stateless_modified/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --lang-dir $repo/data/lang_char \
- $repo/test_wavs/BAC009S0764W0121.wav \
- $repo/test_wavs/BAC009S0764W0122.wav \
- $repo/test_wavs/BAC009S0764W0123.wav
-done
diff --git a/.github/workflows/aishell.yml b/.github/workflows/aishell.yml
new file mode 100644
index 000000000..136e117bd
--- /dev/null
+++ b/.github/workflows/aishell.yml
@@ -0,0 +1,81 @@
+name: aishell
+
+on:
+ push:
+ branches:
+ - master
+
+ pull_request:
+ branches:
+ - master
+
+ workflow_dispatch:
+
+ schedule:
+ # minute (0-59)
+ # hour (0-23)
+ # day of the month (1-31)
+ # month (1-12)
+ # day of the week (0-6)
+ # nightly build at 15:50 UTC time every day
+ - cron: "50 15 * * *"
+
+concurrency:
+ group: aishell-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ generate_build_matrix:
+ if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'aishell' || github.event_name == 'schedule')
+
+ # see https://github.com/pytorch/pytorch/pull/50633
+ runs-on: ubuntu-latest
+ outputs:
+ matrix: ${{ steps.set-matrix.outputs.matrix }}
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ - name: Generating build matrix
+ id: set-matrix
+ run: |
+ # outputting for debugging purposes
+ python ./.github/scripts/docker/generate_build_matrix.py
+ MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
+ echo "::set-output name=matrix::${MATRIX}"
+ aishell:
+ needs: generate_build_matrix
+ name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
+
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Free space
+ shell: bash
+ run: |
+ df -h
+ rm -rf /opt/hostedtoolcache
+ df -h
+ echo "pwd: $PWD"
+ echo "github.workspace ${{ github.workspace }}"
+
+ - name: Run aishell tests
+ uses: addnab/docker-run-action@v3
+ with:
+ image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
+ options: |
+ --volume ${{ github.workspace }}/:/icefall
+ shell: bash
+ run: |
+ export PYTHONPATH=/icefall:$PYTHONPATH
+ cd /icefall
+ git config --global --add safe.directory /icefall
+
+ .github/scripts/aishell/ASR/run.sh
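Editor's note: the generate_build_matrix job above publishes whatever JSON generate_build_matrix.py prints, and the aishell job consumes it through fromJson, expecting each entry to expose python-version, torch-version and version. A rough, illustrative sketch of that handshake, assuming the usual include-list matrix form; the placeholder values are taken from the version lists shown in this patch, the real output comes from generate_build_matrix.py:

# Illustrative only: generate_build_matrix.py is the source of truth.
MATRIX='{"include": [{"python-version": "3.8", "torch-version": "1.13.0", "version": "1.2"}]}'
echo "::set-output name=matrix::${MATRIX}"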
diff --git a/.github/workflows/run-aishell-2022-06-20.yml b/.github/workflows/run-aishell-2022-06-20.yml
deleted file mode 100644
index 53fcb2c03..000000000
--- a/.github/workflows/run-aishell-2022-06-20.yml
+++ /dev/null
@@ -1,123 +0,0 @@
-# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-aishell-2022-06-20
-# pruned RNN-T + reworked model with random combiner
-# https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_aishell_2022_06_20-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_aishell_2022_06_20:
- if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh
-
- - name: Display decoding results for aishell pruned_transducer_stateless3
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/aishell/ASR/
- tree ./pruned_transducer_stateless3/exp
-
- cd pruned_transducer_stateless3
- echo "results for pruned_transducer_stateless3"
- echo "===greedy search==="
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2
-
- echo "===modified beam search==="
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2
-
- - name: Upload decoding results for aishell pruned_transducer_stateless3
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: aishell-torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-06-20
- path: egs/aishell/ASR/pruned_transducer_stateless3/exp/
diff --git a/.github/workflows/run-aishell-zipformer-2023-10-24.yml b/.github/workflows/run-aishell-zipformer-2023-10-24.yml
deleted file mode 100644
index f2fb44a5f..000000000
--- a/.github/workflows/run-aishell-zipformer-2023-10-24.yml
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright 2023 Zengrui Jin (Xiaomi Corp.)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-aishell-zipformer-2023-10-24
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_aishell_zipformer_2023_10_24-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_aishell_zipformer_2023_10_24:
- if: github.event.label.name == 'ready' || github.event.label.name == 'zipformer' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-aishell-zipformer-2023-10-24.sh
-
-
\ No newline at end of file
diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml
deleted file mode 100644
index ce6d6f92d..000000000
--- a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml
+++ /dev/null
@@ -1,80 +0,0 @@
-# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-pre-trained-transducer-stateless-modified-2-aishell
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
-concurrency:
- group: run_pre_trained_transducer_stateless_modified_2_aishell-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_pre_trained_transducer_stateless_modified_2_aishell:
- if: github.event.label.name == 'ready' || github.event_name == 'push'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Inference with pre-trained model
- shell: bash
- run: |
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
- .github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh
diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml
deleted file mode 100644
index f0cebd94a..000000000
--- a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml
+++ /dev/null
@@ -1,80 +0,0 @@
-# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-pre-trained-transducer-stateless-modified-aishell
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
-concurrency:
- group: run_pre_trained_transducer_stateless_modified_aishell-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_pre_trained_transducer_stateless_modified_aishell:
- if: github.event.label.name == 'ready' || github.event_name == 'push'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Inference with pre-trained model
- shell: bash
- run: |
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
- .github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh
diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/yesno.yml
similarity index 69%
rename from .github/workflows/run-yesno-recipe.yml
rename to .github/workflows/yesno.yml
index a99811815..182300dfa 100644
--- a/.github/workflows/run-yesno-recipe.yml
+++ b/.github/workflows/yesno.yml
@@ -1,26 +1,9 @@
-# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-yesno-recipe
+name: yesno
on:
push:
branches:
- master
- - refactor-ci
pull_request:
branches:
@@ -29,7 +12,7 @@ on:
workflow_dispatch:
concurrency:
- group: run-yesno-recipe-${{ github.ref }}
+ group: yesno-${{ github.ref }}
cancel-in-progress: true
jobs:
@@ -50,7 +33,7 @@ jobs:
python ./.github/scripts/docker/generate_build_matrix.py
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
echo "::set-output name=matrix::${MATRIX}"
- run-yesno-recipe:
+ yesno:
needs: generate_build_matrix
name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
runs-on: ubuntu-latest
From 140e6381ad4699ce919705f59240fad58c0b1bb6 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 27 Dec 2023 13:21:14 +0800
Subject: [PATCH 028/123] Refactor CI tests for librispeech (#1436)
---
.github/scripts/aishell/ASR/run.sh | 73 +-
.github/scripts/librispeech/ASR/run.sh | 1624 ++++++++++++++++-
...n-librispeech-conformer-ctc3-2022-11-28.sh | 122 --
...-pruned-transducer-stateless-2022-03-12.sh | 77 -
...pruned-transducer-stateless2-2022-04-29.sh | 86 -
...pruned-transducer-stateless3-2022-04-29.sh | 85 -
...pruned-transducer-stateless3-2022-05-13.sh | 123 --
...pruned-transducer-stateless5-2022-05-13.sh | 100 -
...pruned-transducer-stateless7-2022-11-11.sh | 106 --
...ed-transducer-stateless7-ctc-2022-12-01.sh | 150 --
...transducer-stateless7-ctc-bs-2023-01-29.sh | 147 --
...nsducer-stateless7-streaming-2022-12-29.sh | 148 --
...pruned-transducer-stateless8-2022-11-14.sh | 115 --
...pruned-transducer-stateless2-2022-06-26.sh | 101 -
...rispeech-streaming-zipformer-2023-05-18.sh | 116 --
...speech-transducer-stateless2-2022-04-19.sh | 77 -
.../run-librispeech-zipformer-2023-05-18.sh | 94 -
...un-librispeech-zipformer-ctc-2023-06-14.sh | 117 --
...un-librispeech-zipformer-mmi-2022-12-08.sh | 102 --
.github/scripts/run-pre-trained-ctc.sh | 240 ---
...d-transducer-stateless-librispeech-100h.sh | 77 -
...d-transducer-stateless-librispeech-960h.sh | 77 -
.../run-pre-trained-transducer-stateless.sh | 77 -
.github/scripts/run-pre-trained-transducer.sh | 33 -
.github/workflows/aishell.yml | 11 +-
...{train-librispeech.yml => librispeech.yml} | 6 +-
.../workflows/run-librispeech-2022-03-12.yml | 159 --
.../workflows/run-librispeech-2022-04-29.yml | 185 --
.../workflows/run-librispeech-2022-05-13.yml | 159 --
.../run-librispeech-2022-11-11-stateless7.yml | 159 --
.../run-librispeech-2022-11-14-stateless8.yml | 159 --
...-librispeech-2022-12-01-stateless7-ctc.yml | 163 --
...n-librispeech-2022-12-08-zipformer-mmi.yml | 167 --
...speech-2022-12-29-stateless7-streaming.yml | 172 --
...brispeech-2023-01-29-stateless7-ctc-bs.yml | 163 --
...-librispeech-conformer-ctc3-2022-11-28.yml | 155 --
...runed-transducer-stateless3-2022-05-13.yml | 157 --
...aming-transducer-stateless2-2022-06-26.yml | 159 --
...ispeech-streaming-zipformer-2023-05-18.yml | 174 --
...peech-transducer-stateless2-2022-04-19.yml | 159 --
.../run-librispeech-zipformer-2023-05-18.yml | 159 --
...n-librispeech-zipformer-ctc-2023-06-14.yml | 155 --
.github/workflows/run-pretrained-ctc.yml | 87 -
...-transducer-stateless-librispeech-100h.yml | 158 --
...r-stateless-librispeech-multi-datasets.yml | 158 --
.../run-pretrained-transducer-stateless.yml | 158 --
.../workflows/run-pretrained-transducer.yml | 80 -
47 files changed, 1658 insertions(+), 5671 deletions(-)
delete mode 100755 .github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh
delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh
delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh
delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh
delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh
delete mode 100755 .github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh
delete mode 100755 .github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh
delete mode 100755 .github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
delete mode 100755 .github/scripts/run-librispeech-zipformer-2023-05-18.sh
delete mode 100755 .github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh
delete mode 100755 .github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
delete mode 100755 .github/scripts/run-pre-trained-ctc.sh
delete mode 100755 .github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
delete mode 100755 .github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
delete mode 100755 .github/scripts/run-pre-trained-transducer-stateless.sh
delete mode 100755 .github/scripts/run-pre-trained-transducer.sh
rename .github/workflows/{train-librispeech.yml => librispeech.yml} (95%)
delete mode 100644 .github/workflows/run-librispeech-2022-03-12.yml
delete mode 100644 .github/workflows/run-librispeech-2022-04-29.yml
delete mode 100644 .github/workflows/run-librispeech-2022-05-13.yml
delete mode 100644 .github/workflows/run-librispeech-2022-11-11-stateless7.yml
delete mode 100644 .github/workflows/run-librispeech-2022-11-14-stateless8.yml
delete mode 100644 .github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml
delete mode 100644 .github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml
delete mode 100644 .github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml
delete mode 100644 .github/workflows/run-librispeech-2023-01-29-stateless7-ctc-bs.yml
delete mode 100644 .github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
delete mode 100644 .github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
delete mode 100644 .github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml
delete mode 100644 .github/workflows/run-librispeech-streaming-zipformer-2023-05-18.yml
delete mode 100644 .github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml
delete mode 100644 .github/workflows/run-librispeech-zipformer-2023-05-18.yml
delete mode 100644 .github/workflows/run-librispeech-zipformer-ctc-2023-06-14.yml
delete mode 100644 .github/workflows/run-pretrained-ctc.yml
delete mode 100644 .github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml
delete mode 100644 .github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml
delete mode 100644 .github/workflows/run-pretrained-transducer-stateless.yml
delete mode 100644 .github/workflows/run-pretrained-transducer.yml
diff --git a/.github/scripts/aishell/ASR/run.sh b/.github/scripts/aishell/ASR/run.sh
index 4d912fa76..f150b6337 100755
--- a/.github/scripts/aishell/ASR/run.sh
+++ b/.github/scripts/aishell/ASR/run.sh
@@ -263,6 +263,76 @@ function test_transducer_stateless_modified_2_2022_03_01() {
rm -rf $repo
}
+function test_conformer_ctc() {
+ repo_url=https://huggingface.co/csukuangfj/icefall_asr_aishell_conformer_ctc
+ log "Downloading pre-trained model from $repo_url"
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ repo=$(basename $repo_url)
+ pushd $repo
+
+ git lfs pull --include "exp/pretrained.pt"
+ git lfs pull --include "data/lang_char/H.fst"
+ git lfs pull --include "data/lang_char/HL.fst"
+ git lfs pull --include "data/lang_char/HLG.fst"
+
+ popd
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ log "CTC decoding"
+
+ log "Exporting model with torchscript"
+
+ pushd $repo/exp
+ ln -s pretrained.pt epoch-99.pt
+ popd
+
+ ./conformer_ctc/export.py \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir $repo/exp \
+ --tokens $repo/data/lang_char/tokens.txt \
+ --jit 1
+
+ ls -lh $repo/exp
+
+ ls -lh $repo/data/lang_char
+
+ log "Decoding with H on CPU with OpenFst"
+
+ ./conformer_ctc/jit_pretrained_decode_with_H.py \
+ --nn-model $repo/exp/cpu_jit.pt \
+ --H $repo/data/lang_char/H.fst \
+ --tokens $repo/data/lang_char/tokens.txt \
+ $repo/test_wavs/0.wav \
+ $repo/test_wavs/1.wav \
+ $repo/test_wavs/2.wav
+
+ log "Decoding with HL on CPU with OpenFst"
+
+ ./conformer_ctc/jit_pretrained_decode_with_HL.py \
+ --nn-model $repo/exp/cpu_jit.pt \
+ --HL $repo/data/lang_char/HL.fst \
+ --words $repo/data/lang_char/words.txt \
+ $repo/test_wavs/0.wav \
+ $repo/test_wavs/1.wav \
+ $repo/test_wavs/2.wav
+
+ log "Decoding with HLG on CPU with OpenFst"
+
+ ./conformer_ctc/jit_pretrained_decode_with_HLG.py \
+ --nn-model $repo/exp/cpu_jit.pt \
+ --HLG $repo/data/lang_char/HLG.fst \
+ --words $repo/data/lang_char/words.txt \
+ $repo/test_wavs/0.wav \
+ $repo/test_wavs/1.wav \
+ $repo/test_wavs/2.wav
+
+ rm -rf $repo
+}
+
download_test_dev_manifests
test_transducer_stateless3_2022_06_20
test_zipformer_large_2023_10_24
@@ -270,5 +340,4 @@ test_zipformer_2023_10_24
test_zipformer_small_2023_10_24
test_transducer_stateless_modified_2022_03_01
test_transducer_stateless_modified_2_2022_03_01
-
-ls -lh
+# test_conformer_ctc # fails for torch 1.13.x and torch 2.0.x
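Editor's note: test_conformer_ctc() above (and several of the librispeech tests below) avoids a full LFS clone and instead downloads only the files it needs. A minimal sketch of that pattern, reusing the conformer-ctc repo URL from the function above; which files to pull is up to the test:

# GIT_LFS_SKIP_SMUDGE=1 clones only LFS pointer files; `git lfs pull --include`
# then fetches just the listed artifacts, keeping CI disk and network usage small.
repo_url=https://huggingface.co/csukuangfj/icefall_asr_aishell_conformer_ctc
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $(basename $repo_url)
git lfs pull --include "exp/pretrained.pt"
git lfs pull --include "data/lang_char/HLG.fst"
popd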
diff --git a/.github/scripts/librispeech/ASR/run.sh b/.github/scripts/librispeech/ASR/run.sh
index 641d59458..7e9bd8a47 100755
--- a/.github/scripts/librispeech/ASR/run.sh
+++ b/.github/scripts/librispeech/ASR/run.sh
@@ -10,52 +10,1594 @@ log() {
cd egs/librispeech/ASR
-# We don't download the LM file since it is so large that it will
-# cause OOM error for CI later.
-mkdir -p download/lm
-pushd download/lm
-wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt
-wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt
-wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz
-ls -lh
-gunzip librispeech-lm-norm.txt.gz
+function prepare_data() {
+ # We don't download the LM file since it is so large that it will
+ # cause OOM error for CI later.
+ mkdir -p download/lm
+ pushd download/lm
+ wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt
+ wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt
+ wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz
+ ls -lh
+ gunzip librispeech-lm-norm.txt.gz
-ls -lh
-popd
+ ls -lh
+ popd
-pushd download/
-wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/LibriSpeech.tar.bz2
-tar xf LibriSpeech.tar.bz2
-rm LibriSpeech.tar.bz2
+ pushd download/
+ wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/LibriSpeech.tar.bz2
+ tar xf LibriSpeech.tar.bz2
+ rm LibriSpeech.tar.bz2
-cd LibriSpeech
-ln -s train-clean-100 train-clean-360
-ln -s train-other-500 train-other-500
-popd
+ cd LibriSpeech
+ ln -s train-clean-100 train-clean-360
+  ln -s train-clean-100 train-other-500
+ popd
-mkdir -p data/manifests
+ mkdir -p data/manifests
-lhotse prepare librispeech -j 2 -p dev-clean -p dev-other -p test-clean -p test-other -p train-clean-100 download/LibriSpeech data/manifests
-ls -lh data/manifests
+ lhotse prepare librispeech -j 2 -p dev-clean -p dev-other -p test-clean -p test-other -p train-clean-100 download/LibriSpeech data/manifests
+ ls -lh data/manifests
-./local/compute_fbank_librispeech.py --dataset "dev-clean dev-other test-clean test-other train-clean-100" --perturb-speed False
-ls -lh data/fbank
+ ./local/compute_fbank_librispeech.py --dataset "dev-clean dev-other test-clean test-other train-clean-100" --perturb-speed False
+ ls -lh data/fbank
-./prepare.sh --stage 5 --stop-stage 6
+ ./prepare.sh --stage 5 --stop-stage 6
+}
-./zipformer/train.py \
- --world-size 1 \
- --num-epochs 1 \
- --start-epoch 1 \
- --use-fp16 0 \
- --exp-dir zipformer/exp-small \
- --causal 0 \
- --num-encoder-layers 1,1,1,1,1,1 \
- --feedforward-dim 64,96,96,96,96,96 \
- --encoder-dim 32,64,64,64,64,64 \
- --encoder-unmasked-dim 32,32,32,32,32,32 \
- --base-lr 0.04 \
- --full-libri 0 \
- --enable-musan 0 \
- --max-duration 30 \
- --print-diagnostics 1
+function run_diagnostics() {
+ ./zipformer/train.py \
+ --world-size 1 \
+ --num-epochs 1 \
+ --start-epoch 1 \
+ --use-fp16 0 \
+ --exp-dir zipformer/exp-small \
+ --causal 0 \
+ --num-encoder-layers 1,1,1,1,1,1 \
+ --feedforward-dim 64,96,96,96,96,96 \
+ --encoder-dim 32,64,64,64,64,64 \
+ --encoder-unmasked-dim 32,32,32,32,32,32 \
+ --base-lr 0.04 \
+ --full-libri 0 \
+ --enable-musan 0 \
+ --max-duration 30 \
+ --print-diagnostics 1
+}
+
+function test_pruned_transducer_stateless_2022_03_12() {
+ repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for method in fast_beam_search modified_beam_search beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+ rm -rf $repo
+}
+
+function test_pruned_transducer_stateless2_2022_04_29() {
+ repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29
+
+ log "Downloading pre-trained model from $repo_url"
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ repo=$(basename $repo_url)
+
+ pushd $repo
+ git lfs pull --include "data/lang_bpe_500/bpe.model"
+ git lfs pull --include "exp/pretrained-epoch-38-avg-10.pt"
+ popd
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ pushd $repo/exp
+ ln -s pretrained-epoch-38-avg-10.pt pretrained.pt
+ popd
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless2/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless2/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+ rm -rf $repo
+}
+
+function test_pruned_transducer_stateless3_2022_04_29() {
+ repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-04-29
+
+ log "Downloading pre-trained model from $repo_url"
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ repo=$(basename $repo_url)
+ pushd $repo
+ git lfs pull --include "data/lang_bpe_500/bpe.model"
+ git lfs pull --include "exp/pretrained-epoch-25-avg-6.pt"
+ popd
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ pushd $repo/exp
+ ln -s pretrained-epoch-25-avg-6.pt pretrained.pt
+ popd
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless3/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless3/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+ rm -rf $repo
+}
+
+function test_pruned_transducer_stateless5_2022_05_13() {
+ repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ pushd $repo/exp
+ ln -s pretrained-epoch-39-avg-7.pt pretrained.pt
+ popd
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless5/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --num-encoder-layers 18 \
+ --dim-feedforward 2048 \
+ --nhead 8 \
+ --encoder-dim 512 \
+ --decoder-dim 512 \
+ --joiner-dim 512 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless5/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav \
+ --num-encoder-layers 18 \
+ --dim-feedforward 2048 \
+ --nhead 8 \
+ --encoder-dim 512 \
+ --decoder-dim 512 \
+ --joiner-dim 512
+ done
+ rm -rf $repo
+}
+
+function test_pruned_transducer_stateless7_2022_11_11() {
+ repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ pushd $repo/exp
+ git lfs pull --include "data/lang_bpe_500/bpe.model"
+ git lfs pull --include "exp/cpu_jit.pt"
+ git lfs pull --include "exp/pretrained.pt"
+ ln -s pretrained.pt epoch-99.pt
+ ls -lh *.pt
+ popd
+
+ log "Export to torchscript model"
+ ./pruned_transducer_stateless7/export.py \
+ --exp-dir $repo/exp \
+ --use-averaged-model false \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --epoch 99 \
+ --avg 1 \
+ --jit 1
+
+ ls -lh $repo/exp/*.pt
+
+ log "Decode with models exported by torch.jit.script()"
+
+ ./pruned_transducer_stateless7/jit_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --nn-model-filename $repo/exp/cpu_jit.pt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless7/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless7/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+ rm -rf $repo
+}
+
+function test_pruned_transducer_stateless8_2022_11_14() {
+ repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ pushd $repo/exp
+ git lfs pull --include "data/lang_bpe_500/bpe.model"
+ git lfs pull --include "exp/cpu_jit.pt"
+ git lfs pull --include "exp/pretrained.pt"
+ ln -s pretrained.pt epoch-99.pt
+ ls -lh *.pt
+ popd
+
+ log "Decode with models exported by torch.jit.script()"
+
+ ./pruned_transducer_stateless8/jit_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --nn-model-filename $repo/exp/cpu_jit.pt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+ log "Export to torchscript model"
+ ./pruned_transducer_stateless8/export.py \
+ --exp-dir $repo/exp \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --use-averaged-model false \
+ --epoch 99 \
+ --avg 1 \
+ --jit 1
+
+ ls -lh $repo/exp/*.pt
+
+ log "Decode with models exported by torch.jit.script()"
+
+ ./pruned_transducer_stateless8/jit_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --nn-model-filename $repo/exp/cpu_jit.pt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless8/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless8/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+ rm -rf $repo
+}
+
+function test_pruned_transducer_stateless7_ctc_2022_12_01() {
+ repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01
+
+ log "Downloading pre-trained model from $repo_url"
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ pushd $repo/exp
+ git lfs pull --include "data/lang_bpe_500/HLG.pt"
+ git lfs pull --include "data/lang_bpe_500/L.pt"
+ git lfs pull --include "data/lang_bpe_500/LG.pt"
+ git lfs pull --include "data/lang_bpe_500/Linv.pt"
+ git lfs pull --include "data/lang_bpe_500/bpe.model"
+ git lfs pull --include "data/lm/G_4_gram.pt"
+ git lfs pull --include "exp/cpu_jit.pt"
+ git lfs pull --include "exp/pretrained.pt"
+ ln -s pretrained.pt epoch-99.pt
+ ls -lh *.pt
+ popd
+
+ log "Export to torchscript model"
+ ./pruned_transducer_stateless7_ctc/export.py \
+ --exp-dir $repo/exp \
+ --use-averaged-model false \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --epoch 99 \
+ --avg 1 \
+ --jit 1
+
+ ls -lh $repo/exp/*.pt
+
+ log "Decode with models exported by torch.jit.script()"
+
+ ./pruned_transducer_stateless7_ctc/jit_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --nn-model-filename $repo/exp/cpu_jit.pt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+ for m in ctc-decoding 1best; do
+ ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
+ --model-filename $repo/exp/cpu_jit.pt \
+ --words-file $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --G $repo/data/lm/G_4_gram.pt \
+ --method $m \
+ --sample-rate 16000 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless7_ctc/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless7_ctc/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for m in ctc-decoding 1best; do
+ ./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
+ --checkpoint $repo/exp/pretrained.pt \
+ --words-file $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --G $repo/data/lm/G_4_gram.pt \
+ --method $m \
+ --sample-rate 16000 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+ rm -rf $repo
+}
+
+function test_zipformer_mmi_2022_12_08() {
+ repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08
+
+ log "Downloading pre-trained model from $repo_url"
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ pushd $repo/exp
+ git lfs pull --include "data/lang_bpe_500/3gram.pt"
+ git lfs pull --include "data/lang_bpe_500/4gram.pt"
+ git lfs pull --include "data/lang_bpe_500/L.pt"
+ git lfs pull --include "data/lang_bpe_500/LG.pt"
+ git lfs pull --include "data/lang_bpe_500/Linv.pt"
+ git lfs pull --include "data/lang_bpe_500/bpe.model"
+ git lfs pull --include "exp/cpu_jit.pt"
+ git lfs pull --include "exp/pretrained.pt"
+ ln -s pretrained.pt epoch-99.pt
+ ls -lh *.pt
+ popd
+
+ log "Export to torchscript model"
+ ./zipformer_mmi/export.py \
+ --exp-dir $repo/exp \
+ --use-averaged-model false \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --epoch 99 \
+ --avg 1 \
+ --jit 1
+
+ ls -lh $repo/exp/*.pt
+
+ log "Decode with models exported by torch.jit.script()"
+
+ ./zipformer_mmi/jit_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --nn-model-filename $repo/exp/cpu_jit.pt \
+ --lang-dir $repo/data/lang_bpe_500 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+ for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
+ log "$method"
+
+ ./zipformer_mmi/pretrained.py \
+ --method $method \
+ --checkpoint $repo/exp/pretrained.pt \
+ --lang-dir $repo/data/lang_bpe_500 \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+ rm -rf $repo
+}
+
+function test_pruned_transducer_stateless7_streaming_2022_12_29() {
+ repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ pushd $repo
+ git lfs pull --include "data/lang_bpe_500/bpe.model"
+ git lfs pull --include "exp/cpu_jit.pt"
+ git lfs pull --include "exp/pretrained.pt"
+ git lfs pull --include "exp/encoder_jit_trace.pt"
+ git lfs pull --include "exp/decoder_jit_trace.pt"
+ git lfs pull --include "exp/joiner_jit_trace.pt"
+ cd exp
+ ln -s pretrained.pt epoch-99.pt
+ ls -lh *.pt
+ popd
+
+ log "Export to torchscript model"
+ ./pruned_transducer_stateless7_streaming/export.py \
+ --exp-dir $repo/exp \
+ --use-averaged-model false \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --decode-chunk-len 32 \
+ --epoch 99 \
+ --avg 1 \
+ --jit 1
+
+ ls -lh $repo/exp/*.pt
+
+ log "Decode with models exported by torch.jit.script()"
+
+ ./pruned_transducer_stateless7_streaming/jit_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --nn-model-filename $repo/exp/cpu_jit.pt \
+ --decode-chunk-len 32 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+ log "Export to torchscript model by torch.jit.trace()"
+ ./pruned_transducer_stateless7_streaming/jit_trace_export.py \
+ --exp-dir $repo/exp \
+ --use-averaged-model false \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --decode-chunk-len 32 \
+ --epoch 99 \
+ --avg 1
+
+ log "Decode with models exported by torch.jit.trace()"
+
+ ./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --encoder-model-filename $repo/exp/encoder_jit_trace.pt \
+ --decoder-model-filename $repo/exp/decoder_jit_trace.pt \
+ --joiner-model-filename $repo/exp/joiner_jit_trace.pt \
+ --decode-chunk-len 32 \
+ $repo/test_wavs/1089-134686-0001.wav
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless7_streaming/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --decode-chunk-len 32 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless7_streaming/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --decode-chunk-len 32 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ rm -rf $repo
+}
+
+function test_pruned_transducer_stateless7_ctc_bs_2023_01_29() {
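+  # pruned_transducer_stateless7_ctc_bs: export to torchscript, run CTC decoding (ctc-decoding, 1best) on the jit model and the checkpoint, plus the transducer search methods.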
+ repo_url=https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
+
+ log "Downloading pre-trained model from $repo_url"
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ pushd $repo/exp
+ git lfs pull --include "data/lang_bpe_500/L.pt"
+ git lfs pull --include "data/lang_bpe_500/LG.pt"
+ git lfs pull --include "data/lang_bpe_500/HLG.pt"
+ git lfs pull --include "data/lang_bpe_500/Linv.pt"
+ git lfs pull --include "data/lang_bpe_500/bpe.model"
+ git lfs pull --include "exp/cpu_jit.pt"
+ git lfs pull --include "exp/pretrained.pt"
+ ln -s pretrained.pt epoch-99.pt
+ ls -lh *.pt
+ popd
+
+ log "Export to torchscript model"
+ ./pruned_transducer_stateless7_ctc_bs/export.py \
+ --exp-dir $repo/exp \
+ --use-averaged-model false \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --epoch 99 \
+ --avg 1 \
+ --jit 1
+
+ ls -lh $repo/exp/*.pt
+
+ log "Decode with models exported by torch.jit.script()"
+
+ ./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --nn-model-filename $repo/exp/cpu_jit.pt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+ for m in ctc-decoding 1best; do
+ ./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \
+ --model-filename $repo/exp/cpu_jit.pt \
+ --words-file $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --method $m \
+ --sample-rate 16000 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless7_ctc_bs/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless7_ctc_bs/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for m in ctc-decoding 1best; do
+ ./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \
+ --checkpoint $repo/exp/pretrained.pt \
+ --words-file $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --method $m \
+ --sample-rate 16000 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+ rm -rf $repo
+}
+
+function test_conformer_ctc3_2022_11_27() {
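+  # Conformer-CTC3: decode the downloaded jit_trace.pt, re-export with --jit-trace 1, and decode again with ctc-decoding and 1best.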
+ repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conformer-ctc3-2022-11-27
+
+ log "Downloading pre-trained model from $repo_url"
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ pushd $repo/exp
+ git lfs pull --include "data/lang_bpe_500/HLG.pt"
+ git lfs pull --include "data/lang_bpe_500/L.pt"
+ git lfs pull --include "data/lang_bpe_500/LG.pt"
+ git lfs pull --include "data/lang_bpe_500/Linv.pt"
+ git lfs pull --include "data/lang_bpe_500/bpe.model"
+ git lfs pull --include "data/lm/G_4_gram.pt"
+ git lfs pull --include "exp/jit_trace.pt"
+ git lfs pull --include "exp/pretrained.pt"
+ ln -s pretrained.pt epoch-99.pt
+ ls -lh *.pt
+ popd
+
+ log "Decode with models exported by torch.jit.trace()"
+
+ for m in ctc-decoding 1best; do
+ ./conformer_ctc3/jit_pretrained.py \
+ --model-filename $repo/exp/jit_trace.pt \
+ --words-file $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --G $repo/data/lm/G_4_gram.pt \
+ --method $m \
+ --sample-rate 16000 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ log "Export to torchscript model"
+
+ ./conformer_ctc3/export.py \
+ --exp-dir $repo/exp \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --jit-trace 1 \
+ --epoch 99 \
+ --avg 1 \
+ --use-averaged-model 0
+
+ ls -lh $repo/exp/*.pt
+
+ log "Decode with models exported by torch.jit.trace()"
+
+ for m in ctc-decoding 1best; do
+ ./conformer_ctc3/jit_pretrained.py \
+ --model-filename $repo/exp/jit_trace.pt \
+ --words-file $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.pt \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --G $repo/data/lm/G_4_gram.pt \
+ --method $m \
+ --sample-rate 16000 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for m in ctc-decoding 1best; do
+ ./conformer_ctc3/pretrained.py \
+ --checkpoint $repo/exp/pretrained.pt \
+ --words-file $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --G $repo/data/lm/G_4_gram.pt \
+ --method $m \
+ --sample-rate 16000 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+ rm -rf $repo
+}
+
+function test_lstm_transducer_stateless2_2022_09_03() {
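+  # LSTM transducer (stateless2): export encoder/decoder/joiner via torch.jit.trace(), decode the traced models, then run greedy and beam-search decoding.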
+ repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ git clone $repo_url
+ repo=$(basename $repo_url)
+ abs_repo=$(realpath $repo)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ pushd $repo/exp
+ ln -s pretrained-iter-468000-avg-16.pt pretrained.pt
+ ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt
+ popd
+
+ log "Test exporting with torch.jit.trace()"
+
+ ./lstm_transducer_stateless2/export.py \
+ --exp-dir $repo/exp \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --epoch 99 \
+ --avg 1 \
+ --use-averaged-model 0 \
+ --jit-trace 1
+
+ log "Decode with models exported by torch.jit.trace()"
+
+ ./lstm_transducer_stateless2/jit_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --encoder-model-filename $repo/exp/encoder_jit_trace.pt \
+ --decoder-model-filename $repo/exp/decoder_jit_trace.pt \
+ --joiner-model-filename $repo/exp/joiner_jit_trace.pt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./lstm_transducer_stateless2/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./lstm_transducer_stateless2/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+ rm -rf $repo
+}
+
+function test_pruned_transducer_stateless3_2022_05_13() {
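+  # pruned_transducer_stateless3: export via both torch.jit.script() and torch.jit.trace(), decode each export, then run greedy and beam-search decoding.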
+ repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ pushd $repo/exp
+ ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
+ ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt
+ popd
+
+ log "Export to torchscript model"
+ ./pruned_transducer_stateless3/export.py \
+ --exp-dir $repo/exp \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --epoch 99 \
+ --avg 1 \
+ --jit 1
+
+ ./pruned_transducer_stateless3/export.py \
+ --exp-dir $repo/exp \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --epoch 99 \
+ --avg 1 \
+ --jit-trace 1
+
+ ls -lh $repo/exp/*.pt
+
+ log "Decode with models exported by torch.jit.trace()"
+
+ ./pruned_transducer_stateless3/jit_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --encoder-model-filename $repo/exp/encoder_jit_trace.pt \
+ --decoder-model-filename $repo/exp/decoder_jit_trace.pt \
+ --joiner-model-filename $repo/exp/joiner_jit_trace.pt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+ log "Decode with models exported by torch.jit.script()"
+
+ ./pruned_transducer_stateless3/jit_pretrained.py \
+ --bpe-model $repo/data/lang_bpe_500/bpe.model \
+ --encoder-model-filename $repo/exp/encoder_jit_script.pt \
+ --decoder-model-filename $repo/exp/decoder_jit_script.pt \
+ --joiner-model-filename $repo/exp/joiner_jit_script.pt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless3/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless3/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ rm -rf $repo
+}
+
+function test_streaming_pruned_transducer_stateless2_20220625() {
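+  # Streaming pruned_transducer_stateless2: decode with simulated streaming (--simulate-streaming 1 --causal-convolution 1) using greedy and beam-search methods.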
+ repo_url=https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless2_20220625
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ pushd $repo/exp
+ ln -s pretrained-epoch-24-avg-10.pt pretrained.pt
+ popd
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./pruned_transducer_stateless2/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --simulate-streaming 1 \
+ --causal-convolution 1 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./pruned_transducer_stateless2/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --simulate-streaming 1 \
+ --causal-convolution 1 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+ rm -rf $repo
+}
+
+function test_streaming_zipformer_2023_05_17() {
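+  # Streaming zipformer: export a causal model (chunk size 16, left-context-frames 128) via torch.jit.script(), decode it in streaming mode, then decode the checkpoint.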
+ repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ pushd $repo/exp
+ git lfs pull --include "data/lang_bpe_500/bpe.model"
+ git lfs pull --include "data/lang_bpe_500/tokens.txt"
+ git lfs pull --include "exp/jit_script_chunk_16_left_128.pt"
+ git lfs pull --include "exp/pretrained.pt"
+ ln -s pretrained.pt epoch-99.pt
+ ls -lh *.pt
+ popd
+
+ log "Export to torchscript model"
+ ./zipformer/export.py \
+ --exp-dir $repo/exp \
+ --use-averaged-model false \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --epoch 99 \
+ --avg 1 \
+ --jit 1
+
+ ls -lh $repo/exp/*.pt
+
+ log "Decode with models exported by torch.jit.script()"
+
+ ./zipformer/jit_pretrained_streaming.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --nn-model-filename $repo/exp/jit_script_chunk_16_left_128.pt \
+ $repo/test_wavs/1089-134686-0001.wav
+
+ for method in greedy_search modified_beam_search fast_beam_search; do
+ log "$method"
+
+ ./zipformer/pretrained.py \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+ rm -rf $repo
+}
+
+function test_zipformer_2023_05_18() {
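+  # Non-streaming zipformer: export via torch.jit.script(), decode the exported model, then run greedy/modified/fast beam search on the checkpoint.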
+ repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ pushd $repo/exp
+ git lfs pull --include "data/lang_bpe_500/bpe.model"
+ git lfs pull --include "data/lang_bpe_500/tokens.txt"
+ git lfs pull --include "exp/jit_script.pt"
+ git lfs pull --include "exp/pretrained.pt"
+ ln -s pretrained.pt epoch-99.pt
+ ls -lh *.pt
+ popd
+
+ log "Export to torchscript model"
+ ./zipformer/export.py \
+ --exp-dir $repo/exp \
+ --use-averaged-model false \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --epoch 99 \
+ --avg 1 \
+ --jit 1
+
+ ls -lh $repo/exp/*.pt
+
+ log "Decode with models exported by torch.jit.script()"
+
+ ./zipformer/jit_pretrained.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --nn-model-filename $repo/exp/jit_script.pt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+ for method in greedy_search modified_beam_search fast_beam_search; do
+ log "$method"
+
+ ./zipformer/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+ rm -rf $repo
+}
+
+function test_transducer_stateless2_torchaudio_2022_04_19() {
+ repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless2-torchaudio-2022-04-19
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./transducer_stateless2/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for method in fast_beam_search modified_beam_search beam_search; do
+ log "$method"
+
+ ./transducer_stateless2/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+ rm -rf $repo
+}
+
+function test_zipformer_transducer_ctc_2023_06_13() {
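+  # Zipformer with transducer + CTC heads: export via torch.jit.script(), then run CTC decoding (ctc-decoding, 1best) on the exported model and the checkpoint.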
+ repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ pushd $repo/exp
+ git lfs pull --include "data/lang_bpe_500/bpe.model"
+ git lfs pull --include "data/lang_bpe_500/tokens.txt"
+ git lfs pull --include "data/lang_bpe_500/HLG.pt"
+ git lfs pull --include "data/lang_bpe_500/L.pt"
+ git lfs pull --include "data/lang_bpe_500/LG.pt"
+ git lfs pull --include "data/lang_bpe_500/Linv.pt"
+ git lfs pull --include "data/lm/G_4_gram.pt"
+ git lfs pull --include "exp/jit_script.pt"
+ git lfs pull --include "exp/pretrained.pt"
+ ln -s pretrained.pt epoch-99.pt
+ ls -lh *.pt
+ popd
+
+ log "Export to torchscript model"
+ ./zipformer/export.py \
+ --exp-dir $repo/exp \
+ --use-transducer 1 \
+ --use-ctc 1 \
+ --use-averaged-model false \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --epoch 99 \
+ --avg 1 \
+ --jit 1
+
+ ls -lh $repo/exp/*.pt
+
+ log "Decode with models exported by torch.jit.script()"
+
+ for method in ctc-decoding 1best; do
+ ./zipformer/jit_pretrained_ctc.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --model-filename $repo/exp/jit_script.pt \
+ --HLG $repo/data/lang_bpe_500/HLG.pt \
+ --words-file $repo/data/lang_bpe_500/words.txt \
+ --G $repo/data/lm/G_4_gram.pt \
+ --method $method \
+ --sample-rate 16000 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for method in ctc-decoding 1best; do
+ log "$method"
+
+ ./zipformer/pretrained_ctc.py \
+ --use-transducer 1 \
+ --use-ctc 1 \
+ --method $method \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.pt \
+ --G $repo/data/lm/G_4_gram.pt \
+ --words-file $repo/data/lang_bpe_500/words.txt \
+ --sample-rate 16000 \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+ rm -rf $repo
+}
+
+function test_100h_transducer_stateless_multi_datasets_bpe_500_2022_02_21() {
+ repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./transducer_stateless_multi_datasets/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./transducer_stateless_multi_datasets/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+ rm -rf $repo
+}
+
+function test_transducer_stateless_multi_datasets_bpe_500_2022_03_01() {
+ repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./transducer_stateless_multi_datasets/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for method in modified_beam_search beam_search fast_beam_search; do
+ log "$method"
+
+ ./transducer_stateless_multi_datasets/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+ rm -rf $repo
+}
+
+function test_transducer_stateless_bpe_500_2022_02_07() {
+ repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ for sym in 1 2 3; do
+ log "Greedy search with --max-sym-per-frame $sym"
+
+ ./transducer_stateless/pretrained.py \
+ --method greedy_search \
+ --max-sym-per-frame $sym \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+
+ for method in fast_beam_search modified_beam_search beam_search; do
+ log "$method"
+
+ ./transducer_stateless/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+ done
+ rm -rf $repo
+}
+
+function test_zipformer_ctc_en_2023_10_02() {
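+  # ONNX zipformer CTC model: run CTC greedy search and decoding with H, HL, and HLG FSTs.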
+ repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ log "CTC greedy search"
+
+ ./zipformer/onnx_pretrained_ctc.py \
+ --nn-model $repo/model.onnx \
+ --tokens $repo/tokens.txt \
+ $repo/test_wavs/0.wav \
+ $repo/test_wavs/1.wav \
+ $repo/test_wavs/2.wav
+
+ log "CTC H decoding"
+
+ ./zipformer/onnx_pretrained_ctc_H.py \
+ --nn-model $repo/model.onnx \
+ --tokens $repo/tokens.txt \
+ --H $repo/H.fst \
+ $repo/test_wavs/0.wav \
+ $repo/test_wavs/1.wav \
+ $repo/test_wavs/2.wav
+
+ log "CTC HL decoding"
+
+ ./zipformer/onnx_pretrained_ctc_HL.py \
+ --nn-model $repo/model.onnx \
+ --words $repo/words.txt \
+ --HL $repo/HL.fst \
+ $repo/test_wavs/0.wav \
+ $repo/test_wavs/1.wav \
+ $repo/test_wavs/2.wav
+
+ log "CTC HLG decoding"
+
+ ./zipformer/onnx_pretrained_ctc_HLG.py \
+ --nn-model $repo/model.onnx \
+ --words $repo/words.txt \
+ --HLG $repo/HLG.fst \
+ $repo/test_wavs/0.wav \
+ $repo/test_wavs/1.wav \
+ $repo/test_wavs/2.wav
+
+ rm -rf $repo
+}
+
+function test_conformer_ctc_jit_bpe_500_2021_11_09() {
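+  # Conformer-CTC: decode with ctc-decoding and HLG 1best, export to torchscript, build FSTs with prepare_lang_fst.py (HLG assumed to come from --ngram-G), and decode them on CPU with the OpenFst-based decoders.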
+ repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09
+ log "Downloading pre-trained model from $repo_url"
+ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+ repo=$(basename $repo_url)
+ pushd $repo
+
+ git lfs pull --include "exp/pretrained.pt"
+ git lfs pull --include "data/lang_bpe_500/HLG.pt"
+ git lfs pull --include "data/lang_bpe_500/L.pt"
+ git lfs pull --include "data/lang_bpe_500/L_disambig.pt"
+ git lfs pull --include "data/lang_bpe_500/Linv.pt"
+ git lfs pull --include "data/lang_bpe_500/bpe.model"
+ git lfs pull --include "data/lang_bpe_500/lexicon.txt"
+ git lfs pull --include "data/lang_bpe_500/lexicon_disambig.txt"
+ git lfs pull --include "data/lang_bpe_500/tokens.txt"
+ git lfs pull --include "data/lang_bpe_500/words.txt"
+ git lfs pull --include "data/lm/G_3_gram.fst.txt"
+
+ popd
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ log "CTC decoding"
+
+ ./conformer_ctc/pretrained.py \
+ --method ctc-decoding \
+ --num-classes 500 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+ log "HLG decoding"
+
+ ./conformer_ctc/pretrained.py \
+ --method 1best \
+ --num-classes 500 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --words-file $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.pt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+ log "CTC decoding on CPU with kaldi decoders using OpenFst"
+
+ log "Exporting model with torchscript"
+
+ pushd $repo/exp
+ ln -s pretrained.pt epoch-99.pt
+ popd
+
+ ./conformer_ctc/export.py \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir $repo/exp \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --jit 1
+
+ ls -lh $repo/exp
+
+
+ log "Generating H.fst, HL.fst"
+
+ ./local/prepare_lang_fst.py --lang-dir $repo/data/lang_bpe_500 --ngram-G $repo/data/lm/G_3_gram.fst.txt
+
+ ls -lh $repo/data/lang_bpe_500
+
+ log "Decoding with H on CPU with OpenFst"
+
+ ./conformer_ctc/jit_pretrained_decode_with_H.py \
+ --nn-model $repo/exp/cpu_jit.pt \
+ --H $repo/data/lang_bpe_500/H.fst \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+ log "Decoding with HL on CPU with OpenFst"
+
+ ./conformer_ctc/jit_pretrained_decode_with_HL.py \
+ --nn-model $repo/exp/cpu_jit.pt \
+ --HL $repo/data/lang_bpe_500/HL.fst \
+ --words $repo/data/lang_bpe_500/words.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+ log "Decoding with HLG on CPU with OpenFst"
+
+ ./conformer_ctc/jit_pretrained_decode_with_HLG.py \
+ --nn-model $repo/exp/cpu_jit.pt \
+ --HLG $repo/data/lang_bpe_500/HLG.fst \
+ --words $repo/data/lang_bpe_500/words.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+ rm -rf $repo
+}
+
+function test_transducer_bpe_500_2021_12_23() {
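+  # Plain transducer: beam-search decoding from the pretrained checkpoint.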
+ repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ git clone $repo_url
+ repo=$(basename $repo_url)
+
+ log "Display test files"
+ tree $repo/
+ ls -lh $repo/test_wavs/*.wav
+
+ log "Beam search decoding"
+
+ ./transducer/pretrained.py \
+ --method beam_search \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ $repo/test_wavs/1089-134686-0001.wav \
+ $repo/test_wavs/1221-135766-0001.wav \
+ $repo/test_wavs/1221-135766-0002.wav
+
+ rm -rf $repo
+}
+
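+# Entry point: prepare data, run diagnostics, then execute the model tests in sequence.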
+prepare_data
+run_diagnostics
+test_pruned_transducer_stateless_2022_03_12
+test_pruned_transducer_stateless2_2022_04_29
+test_pruned_transducer_stateless3_2022_04_29
+test_pruned_transducer_stateless5_2022_05_13
+test_pruned_transducer_stateless7_2022_11_11
+test_pruned_transducer_stateless8_2022_11_14
+test_pruned_transducer_stateless7_ctc_2022_12_01
+test_zipformer_mmi_2022_12_08
+test_pruned_transducer_stateless7_streaming_2022_12_29
+test_pruned_transducer_stateless7_ctc_bs_2023_01_29
+test_conformer_ctc3_2022_11_27
+test_lstm_transducer_stateless2_2022_09_03
+test_pruned_transducer_stateless3_2022_05_13
+test_streaming_pruned_transducer_stateless2_20220625
+test_streaming_zipformer_2023_05_17
+test_zipformer_2023_05_18
+test_transducer_stateless2_torchaudio_2022_04_19
+test_zipformer_transducer_ctc_2023_06_13
+test_100h_transducer_stateless_multi_datasets_bpe_500_2022_02_21
+test_transducer_stateless_multi_datasets_bpe_500_2022_03_01
+test_transducer_stateless_bpe_500_2022_02_07
+test_zipformer_ctc_en_2023_10_02
+# test_conformer_ctc_jit_bpe_500_2021_11_09 # fails for torch != 1.13.x and torch != 2.0.x
+test_transducer_bpe_500_2021_12_23
diff --git a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
deleted file mode 100755
index f6fe8c9b2..000000000
--- a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
+++ /dev/null
@@ -1,122 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conformer-ctc3-2022-11-27
-
-log "Downloading pre-trained model from $repo_url"
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-pushd $repo/exp
-git lfs pull --include "data/lang_bpe_500/HLG.pt"
-git lfs pull --include "data/lang_bpe_500/L.pt"
-git lfs pull --include "data/lang_bpe_500/LG.pt"
-git lfs pull --include "data/lang_bpe_500/Linv.pt"
-git lfs pull --include "data/lang_bpe_500/bpe.model"
-git lfs pull --include "data/lm/G_4_gram.pt"
-git lfs pull --include "exp/jit_trace.pt"
-git lfs pull --include "exp/pretrained.pt"
-ln -s pretrained.pt epoch-99.pt
-ls -lh *.pt
-popd
-
-log "Decode with models exported by torch.jit.trace()"
-
-for m in ctc-decoding 1best; do
- ./conformer_ctc3/jit_pretrained.py \
- --model-filename $repo/exp/jit_trace.pt \
- --words-file $repo/data/lang_bpe_500/words.txt \
- --HLG $repo/data/lang_bpe_500/HLG.pt \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --G $repo/data/lm/G_4_gram.pt \
- --method $m \
- --sample-rate 16000 \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-log "Export to torchscript model"
-
-./conformer_ctc3/export.py \
- --exp-dir $repo/exp \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --jit-trace 1 \
- --epoch 99 \
- --avg 1 \
- --use-averaged-model 0
-
-ls -lh $repo/exp/*.pt
-
-log "Decode with models exported by torch.jit.trace()"
-
-for m in ctc-decoding 1best; do
- ./conformer_ctc3/jit_pretrained.py \
- --model-filename $repo/exp/jit_trace.pt \
- --words-file $repo/data/lang_bpe_500/words.txt \
- --HLG $repo/data/lang_bpe_500/HLG.pt \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --G $repo/data/lm/G_4_gram.pt \
- --method $m \
- --sample-rate 16000 \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for m in ctc-decoding 1best; do
- ./conformer_ctc3/pretrained.py \
- --checkpoint $repo/exp/pretrained.pt \
- --words-file $repo/data/lang_bpe_500/words.txt \
- --HLG $repo/data/lang_bpe_500/HLG.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --G $repo/data/lm/G_4_gram.pt \
- --method $m \
- --sample-rate 16000 \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p conformer_ctc3/exp
- ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh conformer_ctc3/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in ctc-decoding 1best; do
- log "Decoding with $method"
- ./conformer_ctc3/decode.py \
- --epoch 999 \
- --avg 1 \
- --use-averaged-model 0 \
- --exp-dir conformer_ctc3/exp/ \
- --max-duration $max_duration \
- --decoding-method $method \
- --lm-dir data/lm
- done
-
- rm conformer_ctc3/exp/*.pt
-fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
deleted file mode 100755
index 412e3ad56..000000000
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
+++ /dev/null
@@ -1,77 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
-
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./pruned_transducer_stateless/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for method in fast_beam_search modified_beam_search beam_search; do
- log "$method"
-
- ./pruned_transducer_stateless/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p pruned_transducer_stateless/exp
- ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh pruned_transducer_stateless/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Decoding with $method"
-
- ./pruned_transducer_stateless/decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --max-duration $max_duration \
- --exp-dir pruned_transducer_stateless/exp
- done
-
- rm pruned_transducer_stateless/exp/*.pt
-fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
deleted file mode 100755
index 243b669ed..000000000
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
+++ /dev/null
@@ -1,86 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29
-
-log "Downloading pre-trained model from $repo_url"
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-
-pushd $repo
-git lfs pull --include "data/lang_bpe_500/bpe.model"
-git lfs pull --include "exp/pretrained-epoch-38-avg-10.pt"
-popd
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-pushd $repo/exp
-ln -s pretrained-epoch-38-avg-10.pt pretrained.pt
-popd
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./pruned_transducer_stateless2/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for method in modified_beam_search beam_search fast_beam_search; do
- log "$method"
-
- ./pruned_transducer_stateless2/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p pruned_transducer_stateless2/exp
- ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless2/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh pruned_transducer_stateless2/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Decoding with $method"
-
- ./pruned_transducer_stateless2/decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --max-duration $max_duration \
- --exp-dir pruned_transducer_stateless2/exp
- done
-
- rm pruned_transducer_stateless2/exp/*.pt
- rm -r data/lang_bpe_500
-fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
deleted file mode 100755
index 2d0f80304..000000000
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
+++ /dev/null
@@ -1,85 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-04-29
-
-log "Downloading pre-trained model from $repo_url"
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-pushd $repo
-git lfs pull --include "data/lang_bpe_500/bpe.model"
-git lfs pull --include "exp/pretrained-epoch-25-avg-6.pt"
-popd
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-pushd $repo/exp
-ln -s pretrained-epoch-25-avg-6.pt pretrained.pt
-popd
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./pruned_transducer_stateless3/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for method in modified_beam_search beam_search fast_beam_search; do
- log "$method"
-
- ./pruned_transducer_stateless3/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p pruned_transducer_stateless3/exp
- ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh pruned_transducer_stateless3/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Decoding with $method"
-
- ./pruned_transducer_stateless3/decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --max-duration $max_duration \
- --exp-dir pruned_transducer_stateless3/exp
- done
-
- rm pruned_transducer_stateless3/exp/*.pt
- rm -r data/lang_bpe_500
-fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
deleted file mode 100755
index 3d5814c48..000000000
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
+++ /dev/null
@@ -1,123 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
-
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-pushd $repo/exp
-ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
-ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt
-popd
-
-
-log "Export to torchscript model"
-./pruned_transducer_stateless3/export.py \
- --exp-dir $repo/exp \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --epoch 99 \
- --avg 1 \
- --jit 1
-
-./pruned_transducer_stateless3/export.py \
- --exp-dir $repo/exp \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --epoch 99 \
- --avg 1 \
- --jit-trace 1
-
-ls -lh $repo/exp/*.pt
-
-log "Decode with models exported by torch.jit.trace()"
-
-./pruned_transducer_stateless3/jit_pretrained.py \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --encoder-model-filename $repo/exp/encoder_jit_trace.pt \
- --decoder-model-filename $repo/exp/decoder_jit_trace.pt \
- --joiner-model-filename $repo/exp/joiner_jit_trace.pt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-
-log "Decode with models exported by torch.jit.script()"
-
-./pruned_transducer_stateless3/jit_pretrained.py \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --encoder-model-filename $repo/exp/encoder_jit_script.pt \
- --decoder-model-filename $repo/exp/decoder_jit_script.pt \
- --joiner-model-filename $repo/exp/joiner_jit_script.pt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./pruned_transducer_stateless3/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for method in modified_beam_search beam_search fast_beam_search; do
- log "$method"
-
- ./pruned_transducer_stateless3/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p pruned_transducer_stateless3/exp
- ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh pruned_transducer_stateless3/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Decoding with $method"
-
- ./pruned_transducer_stateless3/decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --max-duration $max_duration \
- --exp-dir pruned_transducer_stateless3/exp
- done
-
- rm pruned_transducer_stateless3/exp/*.pt
-fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh
deleted file mode 100755
index 3d2442d54..000000000
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh
+++ /dev/null
@@ -1,100 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13
-
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-pushd $repo/exp
-ln -s pretrained-epoch-39-avg-7.pt pretrained.pt
-popd
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./pruned_transducer_stateless5/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --num-encoder-layers 18 \
- --dim-feedforward 2048 \
- --nhead 8 \
- --encoder-dim 512 \
- --decoder-dim 512 \
- --joiner-dim 512 \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for method in modified_beam_search beam_search fast_beam_search; do
- log "$method"
-
- ./pruned_transducer_stateless5/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav \
- --num-encoder-layers 18 \
- --dim-feedforward 2048 \
- --nhead 8 \
- --encoder-dim 512 \
- --decoder-dim 512 \
- --joiner-dim 512
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p pruned_transducer_stateless5/exp
- ln -s $PWD/$repo/exp/pretrained-epoch-39-avg-7.pt pruned_transducer_stateless5/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh pruned_transducer_stateless5/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Decoding with $method"
-
- ./pruned_transducer_stateless5/decode.py \
- --decoding-method $method \
- --use-averaged-model 0 \
- --epoch 999 \
- --avg 1 \
- --max-duration $max_duration \
- --exp-dir pruned_transducer_stateless5/exp \
- --num-encoder-layers 18 \
- --dim-feedforward 2048 \
- --nhead 8 \
- --encoder-dim 512 \
- --decoder-dim 512 \
- --joiner-dim 512
- done
-
- rm pruned_transducer_stateless5/exp/*.pt
-fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh
deleted file mode 100755
index 961dde4f4..000000000
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh
+++ /dev/null
@@ -1,106 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
-
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-pushd $repo/exp
-git lfs pull --include "data/lang_bpe_500/bpe.model"
-git lfs pull --include "exp/cpu_jit.pt"
-git lfs pull --include "exp/pretrained.pt"
-ln -s pretrained.pt epoch-99.pt
-ls -lh *.pt
-popd
-
-log "Export to torchscript model"
-./pruned_transducer_stateless7/export.py \
- --exp-dir $repo/exp \
- --use-averaged-model false \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --epoch 99 \
- --avg 1 \
- --jit 1
-
-ls -lh $repo/exp/*.pt
-
-log "Decode with models exported by torch.jit.script()"
-
-./pruned_transducer_stateless7/jit_pretrained.py \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --nn-model-filename $repo/exp/cpu_jit.pt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./pruned_transducer_stateless7/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for method in modified_beam_search beam_search fast_beam_search; do
- log "$method"
-
- ./pruned_transducer_stateless7/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p pruned_transducer_stateless7/exp
- ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh pruned_transducer_stateless7/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Decoding with $method"
-
- ./pruned_transducer_stateless7/decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --use-averaged-model 0 \
- --max-duration $max_duration \
- --exp-dir pruned_transducer_stateless7/exp
- done
-
- rm pruned_transducer_stateless7/exp/*.pt
-fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
deleted file mode 100755
index ba7139efb..000000000
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
+++ /dev/null
@@ -1,150 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01
-
-log "Downloading pre-trained model from $repo_url"
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-pushd $repo/exp
-git lfs pull --include "data/lang_bpe_500/HLG.pt"
-git lfs pull --include "data/lang_bpe_500/L.pt"
-git lfs pull --include "data/lang_bpe_500/LG.pt"
-git lfs pull --include "data/lang_bpe_500/Linv.pt"
-git lfs pull --include "data/lang_bpe_500/bpe.model"
-git lfs pull --include "data/lm/G_4_gram.pt"
-git lfs pull --include "exp/cpu_jit.pt"
-git lfs pull --include "exp/pretrained.pt"
-ln -s pretrained.pt epoch-99.pt
-ls -lh *.pt
-popd
-
-log "Export to torchscript model"
-./pruned_transducer_stateless7_ctc/export.py \
- --exp-dir $repo/exp \
- --use-averaged-model false \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --epoch 99 \
- --avg 1 \
- --jit 1
-
-ls -lh $repo/exp/*.pt
-
-log "Decode with models exported by torch.jit.script()"
-
-./pruned_transducer_stateless7_ctc/jit_pretrained.py \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --nn-model-filename $repo/exp/cpu_jit.pt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-
-for m in ctc-decoding 1best; do
- ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
- --model-filename $repo/exp/cpu_jit.pt \
- --words-file $repo/data/lang_bpe_500/words.txt \
- --HLG $repo/data/lang_bpe_500/HLG.pt \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --G $repo/data/lm/G_4_gram.pt \
- --method $m \
- --sample-rate 16000 \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./pruned_transducer_stateless7_ctc/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for method in modified_beam_search beam_search fast_beam_search; do
- log "$method"
-
- ./pruned_transducer_stateless7_ctc/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for m in ctc-decoding 1best; do
- ./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
- --checkpoint $repo/exp/pretrained.pt \
- --words-file $repo/data/lang_bpe_500/words.txt \
- --HLG $repo/data/lang_bpe_500/HLG.pt \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --G $repo/data/lm/G_4_gram.pt \
- --method $m \
- --sample-rate 16000 \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p pruned_transducer_stateless7_ctc/exp
- ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_ctc/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh pruned_transducer_stateless7_ctc/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Decoding with $method"
-
- ./pruned_transducer_stateless7_ctc/decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --use-averaged-model 0 \
- --max-duration $max_duration \
- --exp-dir pruned_transducer_stateless7_ctc/exp
- done
-
- for m in ctc-decoding 1best; do
- ./pruned_transducer_stateless7_ctc/ctc_decode.py \
- --epoch 999 \
- --avg 1 \
- --exp-dir ./pruned_transducer_stateless7_ctc/exp \
- --max-duration $max_duration \
- --use-averaged-model 0 \
- --decoding-method $m \
- --hlg-scale 0.6 \
- --lm-dir data/lm
- done
-
- rm pruned_transducer_stateless7_ctc/exp/*.pt
-fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh
deleted file mode 100755
index 1ecbc4798..000000000
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh
+++ /dev/null
@@ -1,147 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
-
-log "Downloading pre-trained model from $repo_url"
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-pushd $repo/exp
-git lfs pull --include "data/lang_bpe_500/L.pt"
-git lfs pull --include "data/lang_bpe_500/LG.pt"
-git lfs pull --include "data/lang_bpe_500/HLG.pt"
-git lfs pull --include "data/lang_bpe_500/Linv.pt"
-git lfs pull --include "data/lang_bpe_500/bpe.model"
-git lfs pull --include "exp/cpu_jit.pt"
-git lfs pull --include "exp/pretrained.pt"
-ln -s pretrained.pt epoch-99.pt
-ls -lh *.pt
-popd
-
-log "Export to torchscript model"
-./pruned_transducer_stateless7_ctc_bs/export.py \
- --exp-dir $repo/exp \
- --use-averaged-model false \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --epoch 99 \
- --avg 1 \
- --jit 1
-
-ls -lh $repo/exp/*.pt
-
-log "Decode with models exported by torch.jit.script()"
-
-./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --nn-model-filename $repo/exp/cpu_jit.pt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-
-for m in ctc-decoding 1best; do
- ./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \
- --model-filename $repo/exp/cpu_jit.pt \
- --words-file $repo/data/lang_bpe_500/words.txt \
- --HLG $repo/data/lang_bpe_500/HLG.pt \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --method $m \
- --sample-rate 16000 \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./pruned_transducer_stateless7_ctc_bs/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for method in modified_beam_search beam_search fast_beam_search; do
- log "$method"
-
- ./pruned_transducer_stateless7_ctc_bs/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for m in ctc-decoding 1best; do
- ./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \
- --checkpoint $repo/exp/pretrained.pt \
- --words-file $repo/data/lang_bpe_500/words.txt \
- --HLG $repo/data/lang_bpe_500/HLG.pt \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --method $m \
- --sample-rate 16000 \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p pruned_transducer_stateless7_ctc_bs/exp
- ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_ctc_bs/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh pruned_transducer_stateless7_ctc_bs/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Decoding with $method"
-
- ./pruned_transducer_stateless7_ctc_bs/decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --use-averaged-model 0 \
- --max-duration $max_duration \
- --exp-dir pruned_transducer_stateless7_ctc_bs/exp
- done
-
- for m in ctc-decoding 1best; do
- ./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \
- --epoch 999 \
- --avg 1 \
- --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
- --max-duration $max_duration \
- --use-averaged-model 0 \
- --decoding-method $m \
- --hlg-scale 0.6
- done
-
- rm pruned_transducer_stateless7_ctc_bs/exp/*.pt
-fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh
deleted file mode 100755
index 37b192a57..000000000
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh
+++ /dev/null
@@ -1,148 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
-
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-pushd $repo
-git lfs pull --include "data/lang_bpe_500/bpe.model"
-git lfs pull --include "exp/cpu_jit.pt"
-git lfs pull --include "exp/pretrained.pt"
-git lfs pull --include "exp/encoder_jit_trace.pt"
-git lfs pull --include "exp/decoder_jit_trace.pt"
-git lfs pull --include "exp/joiner_jit_trace.pt"
-cd exp
-ln -s pretrained.pt epoch-99.pt
-ls -lh *.pt
-popd
-
-log "Export to torchscript model"
-./pruned_transducer_stateless7_streaming/export.py \
- --exp-dir $repo/exp \
- --use-averaged-model false \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --decode-chunk-len 32 \
- --epoch 99 \
- --avg 1 \
- --jit 1
-
-ls -lh $repo/exp/*.pt
-
-log "Decode with models exported by torch.jit.script()"
-
-./pruned_transducer_stateless7_streaming/jit_pretrained.py \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --nn-model-filename $repo/exp/cpu_jit.pt \
- --decode-chunk-len 32 \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-
-log "Export to torchscript model by torch.jit.trace()"
-./pruned_transducer_stateless7_streaming/jit_trace_export.py \
- --exp-dir $repo/exp \
- --use-averaged-model false \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --decode-chunk-len 32 \
- --epoch 99 \
- --avg 1
-
-log "Decode with models exported by torch.jit.trace()"
-
-./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --encoder-model-filename $repo/exp/encoder_jit_trace.pt \
- --decoder-model-filename $repo/exp/decoder_jit_trace.pt \
- --joiner-model-filename $repo/exp/joiner_jit_trace.pt \
- --decode-chunk-len 32 \
- $repo/test_wavs/1089-134686-0001.wav
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./pruned_transducer_stateless7_streaming/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --decode-chunk-len 32 \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for method in modified_beam_search beam_search fast_beam_search; do
- log "$method"
-
- ./pruned_transducer_stateless7_streaming/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --decode-chunk-len 32 \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p pruned_transducer_stateless7_streaming/exp
- ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_streaming/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh pruned_transducer_stateless7_streaming/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
- num_decode_stream=200
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "decoding with $method"
-
- ./pruned_transducer_stateless7_streaming/decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --use-averaged-model 0 \
- --max-duration $max_duration \
- --decode-chunk-len 32 \
- --exp-dir pruned_transducer_stateless7_streaming/exp
- done
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Decoding with $method"
-
- ./pruned_transducer_stateless7_streaming/streaming_decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --use-averaged-model 0 \
- --decode-chunk-len 32 \
-      --num-decode-streams $num_decode_stream \
- --exp-dir pruned_transducer_stateless7_streaming/exp
- done
-
- rm pruned_transducer_stateless7_streaming/exp/*.pt
-fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh
deleted file mode 100755
index 4f2bfac24..000000000
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh
+++ /dev/null
@@ -1,115 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14
-
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-pushd $repo/exp
-git lfs pull --include "data/lang_bpe_500/bpe.model"
-git lfs pull --include "exp/cpu_jit.pt"
-git lfs pull --include "exp/pretrained.pt"
-ln -s pretrained.pt epoch-99.pt
-ls -lh *.pt
-popd
-
-log "Decode with models exported by torch.jit.script()"
-
-./pruned_transducer_stateless8/jit_pretrained.py \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --nn-model-filename $repo/exp/cpu_jit.pt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-
-log "Export to torchscript model"
-./pruned_transducer_stateless8/export.py \
- --exp-dir $repo/exp \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --use-averaged-model false \
- --epoch 99 \
- --avg 1 \
- --jit 1
-
-ls -lh $repo/exp/*.pt
-
-log "Decode with models exported by torch.jit.script()"
-
-./pruned_transducer_stateless8/jit_pretrained.py \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --nn-model-filename $repo/exp/cpu_jit.pt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./pruned_transducer_stateless8/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for method in modified_beam_search beam_search fast_beam_search; do
- log "$method"
-
- ./pruned_transducer_stateless8/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p pruned_transducer_stateless8/exp
- ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless8/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh pruned_transducer_stateless8/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Decoding with $method"
-
- ./pruned_transducer_stateless8/decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --use-averaged-model 0 \
- --max-duration $max_duration \
- --exp-dir pruned_transducer_stateless8/exp
- done
-
- rm pruned_transducer_stateless8/exp/*.pt
-fi
diff --git a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh
deleted file mode 100755
index 5cbdad16d..000000000
--- a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh
+++ /dev/null
@@ -1,101 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless2_20220625
-
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-pushd $repo/exp
-ln -s pretrained-epoch-24-avg-10.pt pretrained.pt
-popd
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./pruned_transducer_stateless2/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --simulate-streaming 1 \
- --causal-convolution 1 \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for method in modified_beam_search beam_search fast_beam_search; do
- log "$method"
-
- ./pruned_transducer_stateless2/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --simulate-streaming 1 \
- --causal-convolution 1 \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p pruned_transducer_stateless2/exp
- ln -s $PWD/$repo/exp/pretrained-epoch-24-avg-10.pt pruned_transducer_stateless2/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh pruned_transducer_stateless2/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Simulate streaming decoding with $method"
-
- ./pruned_transducer_stateless2/decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --max-duration $max_duration \
- --exp-dir pruned_transducer_stateless2/exp \
- --simulate-streaming 1 \
- --causal-convolution 1
- done
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Real streaming decoding with $method"
-
- ./pruned_transducer_stateless2/streaming_decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --num-decode-streams 100 \
- --exp-dir pruned_transducer_stateless2/exp \
- --left-context 32 \
- --decode-chunk-size 8 \
- --right-context 0
- done
-
- rm pruned_transducer_stateless2/exp/*.pt
-fi
diff --git a/.github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh b/.github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh
deleted file mode 100755
index f4e2124b1..000000000
--- a/.github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh
+++ /dev/null
@@ -1,116 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
-
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-pushd $repo/exp
-git lfs pull --include "data/lang_bpe_500/bpe.model"
-git lfs pull --include "data/lang_bpe_500/tokens.txt"
-git lfs pull --include "exp/jit_script_chunk_16_left_128.pt"
-git lfs pull --include "exp/pretrained.pt"
-ln -s pretrained.pt epoch-99.pt
-ls -lh *.pt
-popd
-
-log "Export to torchscript model"
-./zipformer/export.py \
- --exp-dir $repo/exp \
- --use-averaged-model false \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --causal 1 \
- --chunk-size 16 \
- --left-context-frames 128 \
- --epoch 99 \
- --avg 1 \
- --jit 1
-
-ls -lh $repo/exp/*.pt
-
-log "Decode with models exported by torch.jit.script()"
-
-./zipformer/jit_pretrained_streaming.py \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --nn-model-filename $repo/exp/jit_script_chunk_16_left_128.pt \
- $repo/test_wavs/1089-134686-0001.wav
-
-for method in greedy_search modified_beam_search fast_beam_search; do
- log "$method"
-
- ./zipformer/pretrained.py \
- --causal 1 \
- --chunk-size 16 \
- --left-context-frames 128 \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p zipformer/exp
- ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh zipformer/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Simulated streaming decoding with $method"
-
- ./zipformer/decode.py \
- --causal 1 \
- --chunk-size 16 \
- --left-context-frames 128 \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --use-averaged-model 0 \
- --max-duration $max_duration \
- --exp-dir zipformer/exp
- done
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Chunk-wise streaming decoding with $method"
-
- ./zipformer/streaming_decode.py \
- --causal 1 \
- --chunk-size 16 \
- --left-context-frames 128 \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --use-averaged-model 0 \
- --max-duration $max_duration \
- --exp-dir zipformer/exp
- done
-
- rm zipformer/exp/*.pt
-fi
diff --git a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
deleted file mode 100755
index ff77855a2..000000000
--- a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
+++ /dev/null
@@ -1,77 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless2-torchaudio-2022-04-19
-
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./transducer_stateless2/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for method in fast_beam_search modified_beam_search beam_search; do
- log "$method"
-
- ./transducer_stateless2/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p transducer_stateless2/exp
- ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless2/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh transducer_stateless2/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Decoding with $method"
-
- ./transducer_stateless2/decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --max-duration $max_duration \
- --exp-dir transducer_stateless2/exp
- done
-
- rm transducer_stateless2/exp/*.pt
-fi
diff --git a/.github/scripts/run-librispeech-zipformer-2023-05-18.sh b/.github/scripts/run-librispeech-zipformer-2023-05-18.sh
deleted file mode 100755
index fb1a0149d..000000000
--- a/.github/scripts/run-librispeech-zipformer-2023-05-18.sh
+++ /dev/null
@@ -1,94 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
-
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-pushd $repo/exp
-git lfs pull --include "data/lang_bpe_500/bpe.model"
-git lfs pull --include "data/lang_bpe_500/tokens.txt"
-git lfs pull --include "exp/jit_script.pt"
-git lfs pull --include "exp/pretrained.pt"
-ln -s pretrained.pt epoch-99.pt
-ls -lh *.pt
-popd
-
-log "Export to torchscript model"
-./zipformer/export.py \
- --exp-dir $repo/exp \
- --use-averaged-model false \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --epoch 99 \
- --avg 1 \
- --jit 1
-
-ls -lh $repo/exp/*.pt
-
-log "Decode with models exported by torch.jit.script()"
-
-./zipformer/jit_pretrained.py \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --nn-model-filename $repo/exp/jit_script.pt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-
-for method in greedy_search modified_beam_search fast_beam_search; do
- log "$method"
-
- ./zipformer/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p zipformer/exp
- ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh zipformer/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Decoding with $method"
-
- ./zipformer/decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --use-averaged-model 0 \
- --max-duration $max_duration \
- --exp-dir zipformer/exp
- done
-
- rm zipformer/exp/*.pt
-fi
diff --git a/.github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh b/.github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh
deleted file mode 100755
index 0026d2109..000000000
--- a/.github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh
+++ /dev/null
@@ -1,117 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13
-
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-pushd $repo/exp
-git lfs pull --include "data/lang_bpe_500/bpe.model"
-git lfs pull --include "data/lang_bpe_500/tokens.txt"
-git lfs pull --include "data/lang_bpe_500/HLG.pt"
-git lfs pull --include "data/lang_bpe_500/L.pt"
-git lfs pull --include "data/lang_bpe_500/LG.pt"
-git lfs pull --include "data/lang_bpe_500/Linv.pt"
-git lfs pull --include "data/lm/G_4_gram.pt"
-git lfs pull --include "exp/jit_script.pt"
-git lfs pull --include "exp/pretrained.pt"
-ln -s pretrained.pt epoch-99.pt
-ls -lh *.pt
-popd
-
-log "Export to torchscript model"
-./zipformer/export.py \
- --exp-dir $repo/exp \
- --use-transducer 1 \
- --use-ctc 1 \
- --use-averaged-model false \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --epoch 99 \
- --avg 1 \
- --jit 1
-
-ls -lh $repo/exp/*.pt
-
-log "Decode with models exported by torch.jit.script()"
-
-for method in ctc-decoding 1best; do
- ./zipformer/jit_pretrained_ctc.py \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --model-filename $repo/exp/jit_script.pt \
- --HLG $repo/data/lang_bpe_500/HLG.pt \
- --words-file $repo/data/lang_bpe_500/words.txt \
- --G $repo/data/lm/G_4_gram.pt \
- --method $method \
- --sample-rate 16000 \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for method in ctc-decoding 1best; do
- log "$method"
-
- ./zipformer/pretrained_ctc.py \
- --use-transducer 1 \
- --use-ctc 1 \
- --method $method \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --HLG $repo/data/lang_bpe_500/HLG.pt \
- --G $repo/data/lm/G_4_gram.pt \
- --words-file $repo/data/lang_bpe_500/words.txt \
- --sample-rate 16000 \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p zipformer/exp
- ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh zipformer/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in ctc-decoding 1best; do
- log "Decoding with $method"
-
- ./zipformer/ctc_decode.py \
- --use-transducer 1 \
- --use-ctc 1 \
- --decoding-method $method \
- --nbest-scale 1.0 \
- --hlg-scale 0.6 \
- --epoch 999 \
- --avg 1 \
- --use-averaged-model 0 \
- --max-duration $max_duration \
- --exp-dir zipformer/exp
- done
-
- rm zipformer/exp/*.pt
-fi
diff --git a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
deleted file mode 100755
index c59921055..000000000
--- a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
+++ /dev/null
@@ -1,102 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08
-
-log "Downloading pre-trained model from $repo_url"
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-pushd $repo/exp
-git lfs pull --include "data/lang_bpe_500/3gram.pt"
-git lfs pull --include "data/lang_bpe_500/4gram.pt"
-git lfs pull --include "data/lang_bpe_500/L.pt"
-git lfs pull --include "data/lang_bpe_500/LG.pt"
-git lfs pull --include "data/lang_bpe_500/Linv.pt"
-git lfs pull --include "data/lang_bpe_500/bpe.model"
-git lfs pull --include "exp/cpu_jit.pt"
-git lfs pull --include "exp/pretrained.pt"
-ln -s pretrained.pt epoch-99.pt
-ls -lh *.pt
-popd
-
-log "Export to torchscript model"
-./zipformer_mmi/export.py \
- --exp-dir $repo/exp \
- --use-averaged-model false \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --epoch 99 \
- --avg 1 \
- --jit 1
-
-ls -lh $repo/exp/*.pt
-
-log "Decode with models exported by torch.jit.script()"
-
-./zipformer_mmi/jit_pretrained.py \
- --bpe-model $repo/data/lang_bpe_500/bpe.model \
- --nn-model-filename $repo/exp/cpu_jit.pt \
- --lang-dir $repo/data/lang_bpe_500 \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-
-for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
- log "$method"
-
- ./zipformer_mmi/pretrained.py \
- --method $method \
- --checkpoint $repo/exp/pretrained.pt \
- --lang-dir $repo/data/lang_bpe_500 \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p zipformer_mmi/exp
- ln -s $PWD/$repo/exp/pretrained.pt zipformer_mmi/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh zipformer_mmi/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
- log "Decoding with $method"
-
- ./zipformer_mmi/decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --use-averaged-model 0 \
- --nbest-scale 1.2 \
- --hp-scale 1.0 \
- --max-duration $max_duration \
- --lang-dir $repo/data/lang_bpe_500 \
- --exp-dir zipformer_mmi/exp
- done
-
- rm zipformer_mmi/exp/*.pt
-fi
diff --git a/.github/scripts/run-pre-trained-ctc.sh b/.github/scripts/run-pre-trained-ctc.sh
deleted file mode 100755
index 7d6449c9a..000000000
--- a/.github/scripts/run-pre-trained-ctc.sh
+++ /dev/null
@@ -1,240 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-pushd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-log "CTC greedy search"
-
-./zipformer/onnx_pretrained_ctc.py \
- --nn-model $repo/model.onnx \
- --tokens $repo/tokens.txt \
- $repo/test_wavs/0.wav \
- $repo/test_wavs/1.wav \
- $repo/test_wavs/2.wav
-
-log "CTC H decoding"
-
-./zipformer/onnx_pretrained_ctc_H.py \
- --nn-model $repo/model.onnx \
- --tokens $repo/tokens.txt \
- --H $repo/H.fst \
- $repo/test_wavs/0.wav \
- $repo/test_wavs/1.wav \
- $repo/test_wavs/2.wav
-
-log "CTC HL decoding"
-
-./zipformer/onnx_pretrained_ctc_HL.py \
- --nn-model $repo/model.onnx \
- --words $repo/words.txt \
- --HL $repo/HL.fst \
- $repo/test_wavs/0.wav \
- $repo/test_wavs/1.wav \
- $repo/test_wavs/2.wav
-
-log "CTC HLG decoding"
-
-./zipformer/onnx_pretrained_ctc_HLG.py \
- --nn-model $repo/model.onnx \
- --words $repo/words.txt \
- --HLG $repo/HLG.fst \
- $repo/test_wavs/0.wav \
- $repo/test_wavs/1.wav \
- $repo/test_wavs/2.wav
-
-rm -rf $repo
-
-repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09
-log "Downloading pre-trained model from $repo_url"
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-pushd $repo
-
-git lfs pull --include "exp/pretrained.pt"
-git lfs pull --include "data/lang_bpe_500/HLG.pt"
-git lfs pull --include "data/lang_bpe_500/L.pt"
-git lfs pull --include "data/lang_bpe_500/L_disambig.pt"
-git lfs pull --include "data/lang_bpe_500/Linv.pt"
-git lfs pull --include "data/lang_bpe_500/bpe.model"
-git lfs pull --include "data/lang_bpe_500/lexicon.txt"
-git lfs pull --include "data/lang_bpe_500/lexicon_disambig.txt"
-git lfs pull --include "data/lang_bpe_500/tokens.txt"
-git lfs pull --include "data/lang_bpe_500/words.txt"
-git lfs pull --include "data/lm/G_3_gram.fst.txt"
-
-popd
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-log "CTC decoding"
-
-./conformer_ctc/pretrained.py \
- --method ctc-decoding \
- --num-classes 500 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-
-log "HLG decoding"
-
-./conformer_ctc/pretrained.py \
- --method 1best \
- --num-classes 500 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --words-file $repo/data/lang_bpe_500/words.txt \
- --HLG $repo/data/lang_bpe_500/HLG.pt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-
-log "CTC decoding on CPU with kaldi decoders using OpenFst"
-
-log "Exporting model with torchscript"
-
-pushd $repo/exp
-ln -s pretrained.pt epoch-99.pt
-popd
-
-./conformer_ctc/export.py \
- --epoch 99 \
- --avg 1 \
- --exp-dir $repo/exp \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- --jit 1
-
-ls -lh $repo/exp
-
-
-log "Generating H.fst, HL.fst"
-
-./local/prepare_lang_fst.py --lang-dir $repo/data/lang_bpe_500 --ngram-G $repo/data/lm/G_3_gram.fst.txt
-
-ls -lh $repo/data/lang_bpe_500
-
-log "Decoding with H on CPU with OpenFst"
-
-./conformer_ctc/jit_pretrained_decode_with_H.py \
- --nn-model $repo/exp/cpu_jit.pt \
- --H $repo/data/lang_bpe_500/H.fst \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-
-log "Decoding with HL on CPU with OpenFst"
-
-./conformer_ctc/jit_pretrained_decode_with_HL.py \
- --nn-model $repo/exp/cpu_jit.pt \
- --HL $repo/data/lang_bpe_500/HL.fst \
- --words $repo/data/lang_bpe_500/words.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-
-log "Decoding with HLG on CPU with OpenFst"
-
-./conformer_ctc/jit_pretrained_decode_with_HLG.py \
- --nn-model $repo/exp/cpu_jit.pt \
- --HLG $repo/data/lang_bpe_500/HLG.fst \
- --words $repo/data/lang_bpe_500/words.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-
-rm -rf $repo
-
-popd
-
-log "Test aishell"
-
-pushd egs/aishell/ASR
-
-repo_url=https://huggingface.co/csukuangfj/icefall_asr_aishell_conformer_ctc
-log "Downloading pre-trained model from $repo_url"
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-pushd $repo
-
-git lfs pull --include "exp/pretrained.pt"
-git lfs pull --include "data/lang_char/H.fst"
-git lfs pull --include "data/lang_char/HL.fst"
-git lfs pull --include "data/lang_char/HLG.fst"
-
-popd
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-log "CTC decoding"
-
-log "Exporting model with torchscript"
-
-pushd $repo/exp
-ln -s pretrained.pt epoch-99.pt
-popd
-
-./conformer_ctc/export.py \
- --epoch 99 \
- --avg 1 \
- --exp-dir $repo/exp \
- --tokens $repo/data/lang_char/tokens.txt \
- --jit 1
-
-ls -lh $repo/exp
-
-ls -lh $repo/data/lang_char
-
-log "Decoding with H on CPU with OpenFst"
-
-./conformer_ctc/jit_pretrained_decode_with_H.py \
- --nn-model $repo/exp/cpu_jit.pt \
- --H $repo/data/lang_char/H.fst \
- --tokens $repo/data/lang_char/tokens.txt \
- $repo/test_wavs/0.wav \
- $repo/test_wavs/1.wav \
- $repo/test_wavs/2.wav
-
-log "Decoding with HL on CPU with OpenFst"
-
-./conformer_ctc/jit_pretrained_decode_with_HL.py \
- --nn-model $repo/exp/cpu_jit.pt \
- --HL $repo/data/lang_char/HL.fst \
- --words $repo/data/lang_char/words.txt \
- $repo/test_wavs/0.wav \
- $repo/test_wavs/1.wav \
- $repo/test_wavs/2.wav
-
-log "Decoding with HLG on CPU with OpenFst"
-
-./conformer_ctc/jit_pretrained_decode_with_HLG.py \
- --nn-model $repo/exp/cpu_jit.pt \
- --HLG $repo/data/lang_char/HLG.fst \
- --words $repo/data/lang_char/words.txt \
- $repo/test_wavs/0.wav \
- $repo/test_wavs/1.wav \
- $repo/test_wavs/2.wav
-
-rm -rf $repo
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
deleted file mode 100755
index 7b686328d..000000000
--- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
+++ /dev/null
@@ -1,77 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21
-
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./transducer_stateless_multi_datasets/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for method in modified_beam_search beam_search fast_beam_search; do
- log "$method"
-
- ./transducer_stateless_multi_datasets/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p transducer_stateless_multi_datasets/exp
- ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless_multi_datasets/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh transducer_stateless_multi_datasets/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Decoding with $method"
-
- ./transducer_stateless_multi_datasets/decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --max-duration $max_duration \
- --exp-dir transducer_stateless_multi_datasets/exp
- done
-
- rm transducer_stateless_multi_datasets/exp/*.pt
-fi
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
deleted file mode 100755
index a8eeeb514..000000000
--- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
+++ /dev/null
@@ -1,77 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01
-
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./transducer_stateless_multi_datasets/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for method in modified_beam_search beam_search fast_beam_search; do
- log "$method"
-
- ./transducer_stateless_multi_datasets/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p transducer_stateless_multi_datasets/exp
- ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless_multi_datasets/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh transducer_stateless_multi_datasets/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Decoding with $method"
-
- ./transducer_stateless_multi_datasets/decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --max-duration $max_duration \
- --exp-dir transducer_stateless_multi_datasets/exp
- done
-
- rm transducer_stateless_multi_datasets/exp/*.pt
-fi
diff --git a/.github/scripts/run-pre-trained-transducer-stateless.sh b/.github/scripts/run-pre-trained-transducer-stateless.sh
deleted file mode 100755
index 2e2360435..000000000
--- a/.github/scripts/run-pre-trained-transducer-stateless.sh
+++ /dev/null
@@ -1,77 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07
-
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-for sym in 1 2 3; do
- log "Greedy search with --max-sym-per-frame $sym"
-
- ./transducer_stateless/pretrained.py \
- --method greedy_search \
- --max-sym-per-frame $sym \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-for method in fast_beam_search modified_beam_search beam_search; do
- log "$method"
-
- ./transducer_stateless/pretrained.py \
- --method $method \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
-done
-
-echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
-echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
-if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
- mkdir -p transducer_stateless/exp
- ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless/exp/epoch-999.pt
- ln -s $PWD/$repo/data/lang_bpe_500 data/
-
- ls -lh data
- ls -lh transducer_stateless/exp
-
- log "Decoding test-clean and test-other"
-
- # use a small value for decoding with CPU
- max_duration=100
-
- for method in greedy_search fast_beam_search modified_beam_search; do
- log "Decoding with $method"
-
- ./transducer_stateless/decode.py \
- --decoding-method $method \
- --epoch 999 \
- --avg 1 \
- --max-duration $max_duration \
- --exp-dir transducer_stateless/exp
- done
-
- rm transducer_stateless/exp/*.pt
-fi
diff --git a/.github/scripts/run-pre-trained-transducer.sh b/.github/scripts/run-pre-trained-transducer.sh
deleted file mode 100755
index b865f8d13..000000000
--- a/.github/scripts/run-pre-trained-transducer.sh
+++ /dev/null
@@ -1,33 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23
-
-log "Downloading pre-trained model from $repo_url"
-git lfs install
-git clone $repo_url
-repo=$(basename $repo_url)
-
-log "Display test files"
-tree $repo/
-ls -lh $repo/test_wavs/*.wav
-
-log "Beam search decoding"
-
-./transducer/pretrained.py \
- --method beam_search \
- --beam-size 4 \
- --checkpoint $repo/exp/pretrained.pt \
- --tokens $repo/data/lang_bpe_500/tokens.txt \
- $repo/test_wavs/1089-134686-0001.wav \
- $repo/test_wavs/1221-135766-0001.wav \
- $repo/test_wavs/1221-135766-0002.wav
diff --git a/.github/workflows/aishell.yml b/.github/workflows/aishell.yml
index 136e117bd..8b0599fca 100644
--- a/.github/workflows/aishell.yml
+++ b/.github/workflows/aishell.yml
@@ -11,22 +11,13 @@ on:
workflow_dispatch:
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
concurrency:
group: aishell-${{ github.ref }}
cancel-in-progress: true
jobs:
generate_build_matrix:
- if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'aishell' || github.event_name == 'schedule')
+ if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'aishell')
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
diff --git a/.github/workflows/train-librispeech.yml b/.github/workflows/librispeech.yml
similarity index 95%
rename from .github/workflows/train-librispeech.yml
rename to .github/workflows/librispeech.yml
index 79002a881..6e087b10a 100644
--- a/.github/workflows/train-librispeech.yml
+++ b/.github/workflows/librispeech.yml
@@ -1,4 +1,4 @@
-name: train librispeech
+name: librispeech
on:
push:
branches:
@@ -11,7 +11,7 @@ on:
workflow_dispatch:
concurrency:
- group: train-librispeech-${{ github.ref }}
+ group: librispeech-${{ github.ref }}
cancel-in-progress: true
jobs:
@@ -32,7 +32,7 @@ jobs:
python ./.github/scripts/docker/generate_build_matrix.py
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py)
echo "::set-output name=matrix::${MATRIX}"
- train-librispeech:
+ librispeech:
needs: generate_build_matrix
name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
runs-on: ubuntu-latest
diff --git a/.github/workflows/run-librispeech-2022-03-12.yml b/.github/workflows/run-librispeech-2022-03-12.yml
deleted file mode 100644
index f092e3c80..000000000
--- a/.github/workflows/run-librispeech-2022-03-12.yml
+++ /dev/null
@@ -1,159 +0,0 @@
-# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-librispeech-2022-03-12
-# stateless transducer + k2 pruned rnnt-loss
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_librispeech_2022_03_12-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_librispeech_2022_03_12:
- if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
-
- - name: Display decoding results for pruned_transducer_stateless
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR/
- tree ./pruned_transducer_stateless/exp
-
- cd pruned_transducer_stateless
- echo "results for pruned_transducer_stateless"
- echo "===greedy search==="
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===modified beam search==="
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- - name: Upload decoding results for pruned_transducer_stateless
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless-2022-03-12
- path: egs/librispeech/ASR/pruned_transducer_stateless/exp/
diff --git a/.github/workflows/run-librispeech-2022-04-29.yml b/.github/workflows/run-librispeech-2022-04-29.yml
deleted file mode 100644
index f8f4d9977..000000000
--- a/.github/workflows/run-librispeech-2022-04-29.yml
+++ /dev/null
@@ -1,185 +0,0 @@
-# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-librispeech-2022-04-29
-# stateless pruned transducer (reworked model) + giga speech
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_librispeech_2022_04_29-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_librispeech_2022_04_29:
- if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
-
- .github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
-
- - name: Display decoding results for pruned_transducer_stateless2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR
- tree pruned_transducer_stateless2/exp
- cd pruned_transducer_stateless2/exp
- echo "===greedy search==="
- find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===modified beam search==="
- find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- - name: Display decoding results for pruned_transducer_stateless3
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR
- tree pruned_transducer_stateless3/exp
- cd pruned_transducer_stateless3/exp
- echo "===greedy search==="
- find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===modified beam search==="
- find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- - name: Upload decoding results for pruned_transducer_stateless2
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless2-2022-04-29
- path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/
-
- - name: Upload decoding results for pruned_transducer_stateless3
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-04-29
- path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/
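Editor's note: every workflow removed in this patch follows the same cache-then-install scaffold seen above — an actions/cache step is given an id, and the expensive install or compute step runs only when that cache misses. A minimal sketch of the gating pattern, assuming a hypothetical step id (the deleted files use ids such as my-cache and libri-test-clean-and-test-other-fbank); the cache key and script path are taken directly from the deleted steps:

    - name: Cache kaldifeat
      id: kaldifeat-cache            # hypothetical id; the deleted workflows call this "my-cache"
      uses: actions/cache@v2
      with:
        path: ~/tmp/kaldifeat
        key: cache-tmp-${{ matrix.python-version }}-2023-05-22

    - name: Install kaldifeat
      # Rebuild only on a cache miss; a hit restores ~/tmp/kaldifeat as-is.
      if: steps.kaldifeat-cache.outputs.cache-hit != 'true'
      shell: bash
      run: |
        .github/scripts/install-kaldifeat.sh
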
diff --git a/.github/workflows/run-librispeech-2022-05-13.yml b/.github/workflows/run-librispeech-2022-05-13.yml
deleted file mode 100644
index dc20185da..000000000
--- a/.github/workflows/run-librispeech-2022-05-13.yml
+++ /dev/null
@@ -1,159 +0,0 @@
-# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-librispeech-2022-05-13
-# stateless transducer + k2 pruned rnnt-loss + deeper model
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_librispeech_2022_05_13-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_librispeech_2022_05_13:
- if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh
-
- - name: Display decoding results for librispeech pruned_transducer_stateless5
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR/
- tree ./pruned_transducer_stateless5/exp
-
- cd pruned_transducer_stateless5
- echo "results for pruned_transducer_stateless5"
- echo "===greedy search==="
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===modified beam search==="
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- - name: Upload decoding results for librispeech pruned_transducer_stateless5
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless5-2022-05-13
- path: egs/librispeech/ASR/pruned_transducer_stateless5/exp/
diff --git a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml
deleted file mode 100644
index 7e378c9a1..000000000
--- a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml
+++ /dev/null
@@ -1,159 +0,0 @@
-# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-librispeech-2022-11-11-stateless7
-# zipformer
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_librispeech_2022_11_11_zipformer-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_librispeech_2022_11_11_zipformer:
- if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh
-
- - name: Display decoding results for librispeech pruned_transducer_stateless7
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR/
- tree ./pruned_transducer_stateless7/exp
-
- cd pruned_transducer_stateless7
- echo "results for pruned_transducer_stateless7"
- echo "===greedy search==="
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===modified beam search==="
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- - name: Upload decoding results for librispeech pruned_transducer_stateless7
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-2022-11-11
- path: egs/librispeech/ASR/pruned_transducer_stateless7/exp/
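Editor's note: the "Display decoding results" steps above repeat the same find/grep/sort pipeline once per decoding method and per test set. A hedged sketch of how the same output could be produced with a loop — the method names, exp/ layout, and grep pattern come from the deleted steps, but the loop itself is an editorial illustration, not part of the original workflow:

    - name: Display decoding results (looped sketch)
      if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
      shell: bash
      run: |
        cd egs/librispeech/ASR/pruned_transducer_stateless7
        for method in greedy_search fast_beam_search modified_beam_search; do
          echo "===${method}==="
          for subset in test-clean test-other; do
            # Same pipeline as the deleted steps: pick out the best-WER line from each decoding log.
            find exp/${method} -name "log-*" -exec grep -n --color "best for ${subset}" {} + | sort -n -k2
          done
        done
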
diff --git a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml
deleted file mode 100644
index a2c1a0ad6..000000000
--- a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml
+++ /dev/null
@@ -1,159 +0,0 @@
-# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-librispeech-2022-11-14-stateless8
-# zipformer
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_librispeech_2022_11_14_zipformer_stateless8-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_librispeech_2022_11_14_zipformer_stateless8:
- if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh
-
- - name: Display decoding results for librispeech pruned_transducer_stateless8
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR/
- tree ./pruned_transducer_stateless8/exp
-
- cd pruned_transducer_stateless8
- echo "results for pruned_transducer_stateless8"
- echo "===greedy search==="
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===modified beam search==="
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- - name: Upload decoding results for librispeech pruned_transducer_stateless8
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless8-2022-11-14
- path: egs/librispeech/ASR/pruned_transducer_stateless8/exp/
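Editor's note: the artifact-upload steps in these deleted workflows interpolate ${{ matrix.torch }} into the artifact name, but the strategy matrix only defines os and python-version, so that segment expands to an empty string (GitHub Actions renders undefined context values as empty). A sketch of the upload step using only the matrix fields that actually exist; the shortened name is an editorial suggestion, not part of the original files:

    - name: Upload decoding results for pruned_transducer_stateless8
      uses: actions/upload-artifact@v2
      if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
      with:
        # python-version and os are the only axes defined in the matrix above.
        name: python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless8-2022-11-14
        path: egs/librispeech/ASR/pruned_transducer_stateless8/exp/
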
diff --git a/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml
deleted file mode 100644
index 500ab1736..000000000
--- a/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml
+++ /dev/null
@@ -1,163 +0,0 @@
-# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-librispeech-2022-12-01-stateless7-ctc
-# zipformer
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-jobs:
- run_librispeech_2022_11_11_zipformer:
- if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
-
- - name: Display decoding results for librispeech pruned_transducer_stateless7_ctc
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR/
- tree ./pruned_transducer_stateless7_ctc/exp
-
- cd pruned_transducer_stateless7_ctc
- echo "results for pruned_transducer_stateless7_ctc"
- echo "===greedy search==="
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===modified beam search==="
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===ctc decoding==="
- find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===1best==="
- find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- - name: Upload decoding results for librispeech pruned_transducer_stateless7_ctc
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-ctc-2022-12-01
- path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc/exp/
diff --git a/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml b/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml
deleted file mode 100644
index 1a7f9f594..000000000
--- a/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml
+++ /dev/null
@@ -1,167 +0,0 @@
-# Copyright 2022 Zengwei Yao
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-librispeech-2022-12-08-zipformer-mmi
-# zipformer
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_librispeech_2022_12_08_zipformer-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_librispeech_2022_12_08_zipformer:
- if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
-
- - name: Display decoding results for librispeech zipformer-mmi
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR/
- tree ./zipformer-mmi/exp
-
- cd zipformer-mmi
- echo "results for zipformer-mmi"
- echo "===1best==="
- find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===nbest==="
- find exp/nbest -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/nbest -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===nbest-rescoring-LG==="
- find exp/nbest-rescoring-LG -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/nbest-rescoring-LG -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===nbest-rescoring-3-gram==="
- find exp/nbest-rescoring-3-gram -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/nbest-rescoring-3-gram -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===nbest-rescoring-4-gram==="
- find exp/nbest-rescoring-4-gram -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/nbest-rescoring-4-gram -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- - name: Upload decoding results for librispeech zipformer-mmi
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer_mmi-2022-12-08
- path: egs/librispeech/ASR/zipformer_mmi/exp/
diff --git a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml
deleted file mode 100644
index 68014e20c..000000000
--- a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml
+++ /dev/null
@@ -1,172 +0,0 @@
-# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-librispeech-2022-12-29-stateless7-streaming
-# zipformer
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_librispeech_2022_12_29_zipformer_streaming-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_librispeech_2022_12_29_zipformer_streaming:
- if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh
-
- - name: Display decoding results for librispeech pruned_transducer_stateless7_streaming
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR/
- tree ./pruned_transducer_stateless7_streaming/exp
-
- cd pruned_transducer_stateless7_streaming
- echo "results for pruned_transducer_stateless7_streaming"
- echo "===greedy search==="
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===modified beam search==="
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===streaming greedy search==="
- find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===streaming fast_beam_search==="
- find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===streaming modified beam search==="
- find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
-
- - name: Upload decoding results for librispeech pruned_transducer_stateless7_streaming
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-streaming-2022-12-29
- path: egs/librispeech/ASR/pruned_transducer_stateless7_streaming/exp/
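Editor's note: by this point the scaffold is identical across every deleted recipe workflow; only the recipe script, the exp/ directory, and the label trigger change. Purely as an illustration of why these files can be dropped wholesale — not something this patch adds — a single matrix job could cover several recipes. A hypothetical sketch, reusing the script names that appear in the deleted files:

jobs:
  run_librispeech_recipe:
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: [3.8]
        # Each entry names one of the per-recipe scripts invoked by the deleted workflows.
        recipe:
          - run-librispeech-pruned-transducer-stateless7-2022-11-11.sh
          - run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh
          - run-librispeech-zipformer-mmi-2022-12-08.sh
    steps:
      # ... shared checkout / pip / kaldifeat / LibriSpeech fbank caching steps as above ...
      - name: Inference with pre-trained model
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          .github/scripts/${{ matrix.recipe }}
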
diff --git a/.github/workflows/run-librispeech-2023-01-29-stateless7-ctc-bs.yml b/.github/workflows/run-librispeech-2023-01-29-stateless7-ctc-bs.yml
deleted file mode 100644
index 821abc25d..000000000
--- a/.github/workflows/run-librispeech-2023-01-29-stateless7-ctc-bs.yml
+++ /dev/null
@@ -1,163 +0,0 @@
-# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-librispeech-2023-01-29-stateless7-ctc-bs
-# zipformer
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-jobs:
- run_librispeech_2023_01_29_zipformer_ctc_bs:
- if: github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh
-
- - name: Display decoding results for librispeech pruned_transducer_stateless7_ctc_bs
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR/
- tree ./pruned_transducer_stateless7_ctc_bs/exp
-
- cd pruned_transducer_stateless7_ctc_bs
- echo "results for pruned_transducer_stateless7_ctc_bs"
- echo "===greedy search==="
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===modified beam search==="
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===ctc decoding==="
- find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===1best==="
- find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- - name: Upload decoding results for librispeech pruned_transducer_stateless7_ctc_bs
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-ctc-bs-2023-01-29
- path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/exp/
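Editor's note: unlike most of the workflows deleted above, this stateless7-ctc-bs file (and the stateless7-ctc one earlier) defines no concurrency block, so overlapping runs on the same ref are never cancelled. For comparison, the pattern the sibling files use is shown below; the group name here is an assumption that follows their naming convention, not text from the deleted file:

concurrency:
  # One group per workflow and per ref, so a new push cancels the run still in flight.
  group: run_librispeech_2023_01_29_zipformer_ctc_bs-${{ github.ref }}
  cancel-in-progress: true
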
diff --git a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
deleted file mode 100644
index 905515dc4..000000000
--- a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml
+++ /dev/null
@@ -1,155 +0,0 @@
-# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-librispeech-conformer-ctc3-2022-11-28
-# zipformer
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_librispeech_2022_11_28_conformer_ctc3-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_librispeech_2022_11_28_conformer_ctc3:
- if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
-
- - name: Display decoding results for librispeech conformer_ctc3
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR/
- tree ./conformer_ctc3/exp
-
- cd conformer_ctc3
- echo "results for conformer_ctc3"
- echo "===ctc-decoding==="
- find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===1best==="
- find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- - name: Upload decoding results for librispeech conformer_ctc3
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-conformer_ctc3-2022-11-28
- path: egs/librispeech/ASR/conformer_ctc3/exp/
diff --git a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
deleted file mode 100644
index 3fb0920bc..000000000
--- a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml
+++ /dev/null
@@ -1,157 +0,0 @@
-# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-librispeech-pruned-transducer-stateless3-2022-05-13
-# stateless pruned transducer (reworked model) + giga speech
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_librispeech_pruned_transducer_stateless3_2022_05_13-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_librispeech_pruned_transducer_stateless3_2022_05_13:
- if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
-
- - name: Display decoding results for pruned_transducer_stateless3
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR
- tree pruned_transducer_stateless3/exp
- cd pruned_transducer_stateless3/exp
- echo "===greedy search==="
- find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===modified beam search==="
- find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- - name: Upload decoding results for pruned_transducer_stateless3
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-04-29
- path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/
diff --git a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml
deleted file mode 100644
index 67a6f6fc4..000000000
--- a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml
+++ /dev/null
@@ -1,159 +0,0 @@
-# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-librispeech-streaming-2022-06-26
-# streaming conformer stateless transducer2
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_librispeech_streaming_2022_06_26-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_librispeech_streaming_2022_06_26:
- if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh
-
- - name: Display decoding results
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR/
- tree ./pruned_transducer_stateless2/exp
-
- cd pruned_transducer_stateless2
- echo "results for pruned_transducer_stateless2"
- echo "===greedy search==="
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===modified_beam_search==="
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- - name: Upload decoding results for pruned_transducer_stateless2
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless2-2022-06-26
- path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/
diff --git a/.github/workflows/run-librispeech-streaming-zipformer-2023-05-18.yml b/.github/workflows/run-librispeech-streaming-zipformer-2023-05-18.yml
deleted file mode 100644
index 5145fb43c..000000000
--- a/.github/workflows/run-librispeech-streaming-zipformer-2023-05-18.yml
+++ /dev/null
@@ -1,174 +0,0 @@
-# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-librispeech-streaming-zipformer-2023-05-18
-# zipformer
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_librispeech_2023_05_18_streaming_zipformer-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_librispeech_2023_05_18_streaming_zipformer:
- if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh
-
- - name: Display decoding results for librispeech zipformer
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR/
- tree ./zipformer/exp
-
- cd zipformer
-
- echo "results for zipformer, simulated streaming decoding"
- echo "===greedy search==="
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===modified beam search==="
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "results for zipformer, chunk-wise streaming decoding"
- echo "===greedy search==="
- find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===modified beam search==="
- find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
-
- - name: Upload decoding results for librispeech zipformer
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11
- path: egs/librispeech/ASR/zipformer/exp/
diff --git a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml
deleted file mode 100644
index 35ca08a31..000000000
--- a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml
+++ /dev/null
@@ -1,159 +0,0 @@
-# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-librispeech-2022-04-19
-# stateless transducer + torchaudio rnn-t loss
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_librispeech_2022_04_19-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_librispeech_2022_04_19:
- if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
-
- - name: Display decoding results
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR/
- tree ./transducer_stateless2/exp
-
- cd transducer_stateless2
- echo "results for transducer_stateless2"
- echo "===greedy search==="
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===modified_beam_search==="
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- - name: Upload decoding results for transducer_stateless2
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless2-2022-04-19
- path: egs/librispeech/ASR/transducer_stateless2/exp/
diff --git a/.github/workflows/run-librispeech-zipformer-2023-05-18.yml b/.github/workflows/run-librispeech-zipformer-2023-05-18.yml
deleted file mode 100644
index e9d235ad1..000000000
--- a/.github/workflows/run-librispeech-zipformer-2023-05-18.yml
+++ /dev/null
@@ -1,159 +0,0 @@
-# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-librispeech-zipformer-2023-05-18
-# zipformer
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_librispeech_2023_05_18_zipformer-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_librispeech_2023_05_18_zipformer:
- if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-librispeech-zipformer-2023-05-18.sh
-
- - name: Display decoding results for librispeech zipformer
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR/
- tree ./zipformer/exp
-
- cd zipformer
- echo "results for zipformer"
- echo "===greedy search==="
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===modified beam search==="
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- - name: Upload decoding results for librispeech zipformer
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11
- path: egs/librispeech/ASR/zipformer/exp/
diff --git a/.github/workflows/run-librispeech-zipformer-ctc-2023-06-14.yml b/.github/workflows/run-librispeech-zipformer-ctc-2023-06-14.yml
deleted file mode 100644
index 48f0b1532..000000000
--- a/.github/workflows/run-librispeech-zipformer-ctc-2023-06-14.yml
+++ /dev/null
@@ -1,155 +0,0 @@
-# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-librispeech-zipformer-ctc-2023-06-14
-# zipformer
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_librispeech_2023_06_14_zipformer-ctc-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_librispeech_2023_06_14_zipformer_ctc:
- if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh
-
- - name: Display decoding results for librispeech zipformer
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR/
- tree ./zipformer/exp
-
- cd zipformer
- echo "results for zipformer"
- echo "===ctc-decoding==="
- find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===1best==="
- find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- - name: Upload decoding results for librispeech zipformer
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11
- path: egs/librispeech/ASR/zipformer/exp/
diff --git a/.github/workflows/run-pretrained-ctc.yml b/.github/workflows/run-pretrained-ctc.yml
deleted file mode 100644
index 074a63dfc..000000000
--- a/.github/workflows/run-pretrained-ctc.yml
+++ /dev/null
@@ -1,87 +0,0 @@
-# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-pre-trained-ctc
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- workflow_dispatch:
- inputs:
- test-run:
- description: 'Test (y/n)?'
- required: true
- default: 'y'
-
-concurrency:
- group: run_pre_trained_ctc-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_pre_trained_ctc:
- if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.inputs.test-run == 'y' || github.event.label.name == 'ctc'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Inference with pre-trained model
- shell: bash
- run: |
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
- .github/scripts/run-pre-trained-ctc.sh
diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml
deleted file mode 100644
index f8caee8e5..000000000
--- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml
+++ /dev/null
@@ -1,158 +0,0 @@
-# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-pre-trained-trandsucer-stateless-multi-datasets-librispeech-100h
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h:
- if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
-
- - name: Display decoding results for transducer_stateless_multi_datasets
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR/
- tree ./transducer_stateless_multi_datasets/exp
-
- cd transducer_stateless_multi_datasets
- echo "results for transducer_stateless_multi_datasets"
- echo "===greedy search==="
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===modified beam search==="
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- - name: Upload decoding results for transducer_stateless_multi_datasets
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless_multi_datasets-100h-2022-02-21
- path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/
diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml
deleted file mode 100644
index 7c3910eb8..000000000
--- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml
+++ /dev/null
@@ -1,158 +0,0 @@
-# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-pre-trained-trandsucer-stateless-multi-datasets-librispeech-960h
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h:
- if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
-
- - name: Display decoding results for transducer_stateless_multi_datasets
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR/
- tree ./transducer_stateless_multi_datasets/exp
-
- cd transducer_stateless_multi_datasets
- echo "results for transducer_stateless_multi_datasets"
- echo "===greedy search==="
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===modified beam search==="
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- - name: Upload decoding results for transducer_stateless_multi_datasets
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless_multi_datasets-100h-2022-03-01
- path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/
diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml
deleted file mode 100644
index 1b69b97bf..000000000
--- a/.github/workflows/run-pretrained-transducer-stateless.yml
+++ /dev/null
@@ -1,158 +0,0 @@
-# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-pre-trained-transducer-stateless
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
- schedule:
- # minute (0-59)
- # hour (0-23)
- # day of the month (1-31)
- # month (1-12)
- # day of the week (0-6)
- # nightly build at 15:50 UTC time every day
- - cron: "50 15 * * *"
-
-concurrency:
- group: run_pre_trained_transducer_stateless-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_pre_trained_transducer_stateless:
- if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/install-kaldifeat.sh
-
- - name: Cache LibriSpeech test-clean and test-other datasets
- id: libri-test-clean-and-test-other-data
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/download
- key: cache-libri-test-clean-and-test-other
-
- - name: Download LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
-
- - name: Prepare manifests for LibriSpeech test-clean and test-other
- shell: bash
- run: |
- .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
-
- - name: Cache LibriSpeech test-clean and test-other fbank features
- id: libri-test-clean-and-test-other-fbank
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/fbank-libri
- key: cache-libri-fbank-test-clean-and-test-other-v2
-
- - name: Compute fbank for LibriSpeech test-clean and test-other
- if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
- shell: bash
- run: |
- .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
-
- - name: Inference with pre-trained model
- shell: bash
- env:
- GITHUB_EVENT_NAME: ${{ github.event_name }}
- GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
- run: |
- mkdir -p egs/librispeech/ASR/data
- ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
- ls -lh egs/librispeech/ASR/data/*
-
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
-
- .github/scripts/run-pre-trained-transducer-stateless.sh
-
- - name: Display decoding results for transducer_stateless
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- shell: bash
- run: |
- cd egs/librispeech/ASR/
- tree ./transducer_stateless/exp
-
- cd transducer_stateless
- echo "results for transducer_stateless"
- echo "===greedy search==="
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===fast_beam_search==="
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- echo "===modified beam search==="
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2
- find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2
-
- - name: Upload decoding results for transducer_stateless
- uses: actions/upload-artifact@v2
- if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
- with:
- name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless-2022-02-07
- path: egs/librispeech/ASR/transducer_stateless/exp/
diff --git a/.github/workflows/run-pretrained-transducer.yml b/.github/workflows/run-pretrained-transducer.yml
deleted file mode 100644
index 91d87f1c9..000000000
--- a/.github/workflows/run-pretrained-transducer.yml
+++ /dev/null
@@ -1,80 +0,0 @@
-# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
-
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: run-pre-trained-transducer
-
-on:
- push:
- branches:
- - master
- pull_request:
- types: [labeled]
-
-concurrency:
- group: run_pre_trained_transducer-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run_pre_trained_transducer:
- if: github.event.label.name == 'ready' || github.event_name == 'push'
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
- python-version: [3.8]
-
- fail-fast: false
-
- steps:
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0
-
- - name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- cache: 'pip'
- cache-dependency-path: '**/requirements-ci.txt'
-
- - name: Install Python dependencies
- run: |
- grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
- pip uninstall -y protobuf
- pip install --no-binary protobuf protobuf==3.20.*
-
- - name: Cache kaldifeat
- id: my-cache
- uses: actions/cache@v2
- with:
- path: |
- ~/tmp/kaldifeat
- key: cache-tmp-${{ matrix.python-version }}-2023-05-22
-
- - name: Install kaldifeat
- if: steps.my-cache.outputs.cache-hit != 'true'
- shell: bash
- run: |
- make -j2 _kaldifeat
-
- - name: Inference with pre-trained model
- shell: bash
- run: |
- sudo apt-get -qq install git-lfs tree
- export PYTHONPATH=$PWD:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
- export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
- .github/scripts/run-pre-trained-transducer.sh
From f42258caf8a1c4d19428d98b808986522f630843 Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Sat, 30 Dec 2023 13:03:26 +0800
Subject: [PATCH 029/123] Update compute_fbank_commonvoice_splits.py (#1437)
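Two small fixes, shown in the hunk below: the audio-duration mismatch tolerance is widened from 10 ms to 50 ms, and the split index used in the output filename is no longer offset by one. A minimal sketch of the renumbering (start, stop and num_digits are hypothetical values here; the real script takes them from its arguments):

# Hypothetical illustration only: with the old code the first split file was
# numbered "00000001"; after this patch it is "00000000".
num_digits = 8
start, stop = 0, 3

old = [f"{i + 1}".zfill(num_digits) for i in range(start, stop)]
new = [f"{i}".zfill(num_digits) for i in range(start, stop)]

print(old)  # ['00000001', '00000002', '00000003']
print(new)  # ['00000000', '00000001', '00000002']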
---
egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py
index 0564f6ec6..f31b45aa5 100755
--- a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py
+++ b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py
@@ -109,10 +109,10 @@ def compute_fbank_commonvoice_splits(args):
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
logging.info(f"device: {device}")
- set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance
+ set_audio_duration_mismatch_tolerance(0.05) # 50ms tolerance
set_caching_enabled(False)
for i in range(start, stop):
- idx = f"{i + 1}".zfill(num_digits)
+ idx = f"{i}".zfill(num_digits)
logging.info(f"Processing {idx}/{num_splits}")
cuts_path = output_dir / f"cv-{language}_cuts_{subset}.{idx}.jsonl.gz"
From 8136ad775b6cd02bf2ecc60d65e8641b709c2d41 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 4 Jan 2024 13:59:32 +0800
Subject: [PATCH 030/123] Use high_freq -400 in computing fbank features.
(#1447)
See also https://github.com/k2-fsa/sherpa-onnx/issues/514
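Each script in the list below gains a single line. A minimal sketch (not the verbatim diff) of what the kaldifeat feature setup in these pretrained/decode scripts typically looks like with the new option; the surrounding option values are assumptions for illustration:

# Sketch of the kaldifeat fbank setup with the newly added high_freq line.
import torch
import kaldifeat

opts = kaldifeat.FbankOptions()
opts.device = torch.device("cpu")
opts.frame_opts.samp_freq = 16000
opts.frame_opts.dither = 0
opts.mel_opts.num_bins = 80
opts.mel_opts.high_freq = -400  # the added line

fbank = kaldifeat.Fbank(opts)
features = fbank(torch.randn(16000))  # one second of dummy 16 kHz audio

In Kaldi-style mel options a negative high_freq is an offset from the Nyquist frequency, so -400 keeps the top mel filter 400 Hz below Nyquist; see the linked sherpa-onnx issue for the motivation.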
---
.../ASR/pruned_transducer_stateless2/pretrained.py | 1 +
egs/aishell/ASR/conformer_ctc/pretrained.py | 1 +
egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py | 1 +
egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py | 1 +
egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py | 1 +
egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py | 1 +
.../ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py | 1 +
egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py | 1 +
.../pruned_transducer_stateless7_streaming/streaming_decode.py | 1 +
egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py | 1 +
egs/aishell/ASR/transducer_stateless/pretrained.py | 1 +
egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py | 1 +
egs/aishell/ASR/transducer_stateless_modified/pretrained.py | 1 +
egs/aishell/ASR/zipformer/streaming_decode.py | 1 +
egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py | 1 +
egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py | 1 +
egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py | 1 +
.../ASR/pruned_transducer_stateless7/onnx_pretrained.py | 1 +
egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py | 1 +
.../pruned_transducer_stateless7_streaming/streaming_decode.py | 1 +
.../jit_trace_pretrained.py | 1 +
egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py | 1 +
.../pruned_transducer_stateless7_streaming/streaming_decode.py | 1 +
egs/gigaspeech/ASR/zipformer/streaming_decode.py | 1 +
egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py | 1 +
.../ASR/conformer_ctc/jit_pretrained_decode_with_H.py | 1 +
.../ASR/conformer_ctc/jit_pretrained_decode_with_HL.py | 1 +
.../ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py | 1 +
egs/librispeech/ASR/conformer_ctc/pretrained.py | 1 +
egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py | 1 +
egs/librispeech/ASR/conformer_ctc3/pretrained.py | 1 +
.../ASR/conv_emformer_transducer_stateless/streaming_decode.py | 1 +
.../ASR/conv_emformer_transducer_stateless2/jit_pretrained.py | 1 +
.../ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py | 1 +
.../conv_emformer_transducer_stateless2/streaming-ncnn-decode.py | 1 +
.../ASR/conv_emformer_transducer_stateless2/streaming_decode.py | 1 +
egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py | 1 +
egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py | 1 +
.../ASR/lstm_transducer_stateless/streaming_decode.py | 1 +
egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py | 1 +
egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py | 1 +
.../ASR/lstm_transducer_stateless2/onnx_pretrained.py | 1 +
egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py | 1 +
.../ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py | 1 +
.../ASR/lstm_transducer_stateless2/streaming-onnx-decode.py | 1 +
egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py | 1 +
egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py | 1 +
.../ASR/lstm_transducer_stateless3/streaming_decode.py | 1 +
egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py | 1 +
.../ASR/pruned_transducer_stateless/streaming_decode.py | 1 +
egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py | 1 +
.../ASR/pruned_transducer_stateless2/streaming_decode.py | 1 +
.../ASR/pruned_transducer_stateless3/jit_pretrained.py | 1 +
.../ASR/pruned_transducer_stateless3/onnx_pretrained.py | 1 +
egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py | 1 +
.../ASR/pruned_transducer_stateless3/streaming_decode.py | 1 +
.../ASR/pruned_transducer_stateless4/streaming_decode.py | 1 +
.../pruned_transducer_stateless5/onnx_pretrained-streaming.py | 1 +
egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py | 1 +
.../ASR/pruned_transducer_stateless5/streaming_decode.py | 1 +
.../ASR/pruned_transducer_stateless7/jit_pretrained.py | 1 +
egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py | 1 +
.../ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py | 1 +
.../ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py | 1 +
.../ASR/pruned_transducer_stateless7_ctc/pretrained.py | 1 +
.../ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py | 1 +
.../ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py | 1 +
.../pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py | 1 +
.../ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py | 1 +
.../ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py | 1 +
.../ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py | 1 +
.../ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py | 1 +
.../jit_trace_pretrained.py | 1 +
.../pruned_transducer_stateless7_streaming/onnx_pretrained.py | 1 +
.../ASR/pruned_transducer_stateless7_streaming/pretrained.py | 1 +
.../streaming-ncnn-decode.py | 1 +
.../pruned_transducer_stateless7_streaming/streaming_decode.py | 1 +
.../ASR/pruned_transducer_stateless8/jit_pretrained.py | 1 +
egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py | 1 +
egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py | 1 +
egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py | 1 +
egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py | 1 +
egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py | 1 +
egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py | 1 +
egs/librispeech/ASR/transducer/pretrained.py | 1 +
egs/librispeech/ASR/transducer_stateless/pretrained.py | 1 +
egs/librispeech/ASR/transducer_stateless2/pretrained.py | 1 +
.../ASR/transducer_stateless_multi_datasets/pretrained.py | 1 +
egs/librispeech/ASR/zipformer/jit_pretrained.py | 1 +
egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py | 1 +
egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py | 1 +
egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py | 1 +
egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py | 1 +
egs/librispeech/ASR/zipformer/onnx_pretrained.py | 1 +
egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py | 1 +
egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py | 1 +
egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py | 1 +
egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py | 1 +
egs/librispeech/ASR/zipformer/pretrained.py | 1 +
egs/librispeech/ASR/zipformer/pretrained_ctc.py | 1 +
egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py | 1 +
egs/librispeech/ASR/zipformer_mmi/pretrained.py | 1 +
egs/mgb2/ASR/conformer_ctc/pretrained.py | 1 +
egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py | 1 +
egs/multi_zh-hans/ASR/zipformer/pretrained.py | 1 +
egs/multi_zh_en/ASR/zipformer/pretrained.py | 1 +
egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py | 1 +
.../ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py | 1 +
.../ASR/pruned_transducer_stateless7_bbpe/pretrained.py | 1 +
egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py | 1 +
egs/tedlium3/ASR/transducer_stateless/pretrained.py | 1 +
egs/timit/ASR/tdnn_ligru_ctc/pretrained.py | 1 +
egs/timit/ASR/tdnn_lstm_ctc/pretrained.py | 1 +
.../ASR/pruned_transducer_stateless2/jit_pretrained.py | 1 +
egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py | 1 +
.../pruned_transducer_stateless5/onnx_pretrained-streaming.py | 1 +
.../ASR/pruned_transducer_stateless5/onnx_pretrained.py | 1 +
egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py | 1 +
.../ASR/pruned_transducer_stateless5/streaming_decode.py | 1 +
egs/wenetspeech/ASR/zipformer/streaming_decode.py | 1 +
egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py | 1 +
egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py | 1 +
egs/yesno/ASR/tdnn/jit_pretrained.py | 1 +
egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py | 1 +
egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py | 1 +
egs/yesno/ASR/tdnn/onnx_pretrained.py | 1 +
egs/yesno/ASR/tdnn/pretrained.py | 1 +
127 files changed, 127 insertions(+)
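Note: every hunk below applies the same one-line change to the kaldifeat fbank setup in the pretrained/streaming decoding scripts. A minimal standalone sketch of the resulting configuration follows (values mirror the 16 kHz / 80-bin scripts; option lines not shown in the hunks are omitted):

    import kaldifeat

    # Sketch of the fbank options each hunk below patches.
    opts = kaldifeat.FbankOptions()
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = 16000
    opts.mel_opts.num_bins = 80
    # Line added by this patch: in Kaldi-style mel options a
    # non-positive high_freq is an offset from the Nyquist frequency,
    # so -400 means (samp_freq / 2) - 400 = 7600 Hz at 16 kHz.
    opts.mel_opts.high_freq = -400

    fbank = kaldifeat.Fbank(opts)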
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
index 75c316eaf..17729e02e 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
@@ -242,6 +242,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py
index 66d583396..af1171a6f 100755
--- a/egs/aishell/ASR/conformer_ctc/pretrained.py
+++ b/egs/aishell/ASR/conformer_ctc/pretrained.py
@@ -261,6 +261,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
index 82c10f129..c4aa98358 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py
@@ -240,6 +240,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
index ead393e6e..69fe3a40b 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py
@@ -241,6 +241,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py
index e61190649..5143f2cae 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py
@@ -230,6 +230,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py
index a92182e8d..8e8e971eb 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py
@@ -369,6 +369,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py
index 0c43bf74b..8fb7ac278 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py
@@ -227,6 +227,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py
index ea5bda4db..12004315b 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py
@@ -250,6 +250,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
index 6b4f183cf..aa0e07c83 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
@@ -317,6 +317,7 @@ def decode_dataset(
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
log_interval = 50
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
index 7e7213501..9754b4939 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
@@ -158,6 +158,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py
index 40f430e13..540e7b61b 100755
--- a/egs/aishell/ASR/transducer_stateless/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless/pretrained.py
@@ -258,6 +258,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
index 5d8ca2e11..4a4e9237c 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
@@ -238,6 +238,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
index 9e4459247..66a91709e 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
@@ -238,6 +238,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/aishell/ASR/zipformer/streaming_decode.py b/egs/aishell/ASR/zipformer/streaming_decode.py
index c3820447a..f54ffbd3c 100755
--- a/egs/aishell/ASR/zipformer/streaming_decode.py
+++ b/egs/aishell/ASR/zipformer/streaming_decode.py
@@ -572,6 +572,7 @@ def decode_dataset(
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
log_interval = 100
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
index bc3ae7abf..f04632388 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py
@@ -239,6 +239,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
index ee898c303..e8b7f71b7 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py
@@ -251,6 +251,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
index f5a0dd8c8..a738bb3fb 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py
@@ -242,6 +242,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py
index cf6ddfa36..52fed7331 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py
@@ -370,6 +370,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py
index a22d1b4ba..b6e2451e8 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py
@@ -260,6 +260,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
index dbe65d0a7..018736d26 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
@@ -320,6 +320,7 @@ def decode_dataset(
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
log_interval = 50
diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py
index d84cf04a3..58ee99e6a 100644
--- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py
+++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py
@@ -177,6 +177,7 @@ def create_streaming_feature_extractor(sample_rate) -> OnlineFeature:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return OnlineFbank(opts)
diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py
index 932026868..66fbae378 100644
--- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py
+++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py
@@ -252,6 +252,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
index 9700dd89e..7252665a7 100755
--- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
+++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
@@ -337,6 +337,7 @@ def decode_dataset(
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
log_interval = 50
diff --git a/egs/gigaspeech/ASR/zipformer/streaming_decode.py b/egs/gigaspeech/ASR/zipformer/streaming_decode.py
index a76788859..09df2935c 100755
--- a/egs/gigaspeech/ASR/zipformer/streaming_decode.py
+++ b/egs/gigaspeech/ASR/zipformer/streaming_decode.py
@@ -553,6 +553,7 @@ def decode_dataset(
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
log_interval = 100
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py b/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py
index 48fd2612a..458109a3f 100644
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py
@@ -264,6 +264,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py
index 4bdec9e11..e9acf7e0b 100755
--- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py
+++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py
@@ -195,6 +195,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py
index d5a1dba3c..5753aa5d3 100755
--- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py
+++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py
@@ -192,6 +192,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py
index 216677a23..b6e3333ce 100755
--- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py
+++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py
@@ -191,6 +191,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py
index df3e4d819..38b60fcb9 100755
--- a/egs/librispeech/ASR/conformer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -283,6 +283,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
index 76db46cc8..19b26361e 100755
--- a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py
@@ -271,6 +271,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/conformer_ctc3/pretrained.py b/egs/librispeech/ASR/conformer_ctc3/pretrained.py
index c37b99cce..a0cdfcf03 100755
--- a/egs/librispeech/ASR/conformer_ctc3/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc3/pretrained.py
@@ -302,6 +302,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
index e5a7c7116..9b8b4cce2 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
@@ -623,6 +623,7 @@ def create_streaming_feature_extractor() -> Fbank:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return Fbank(opts)
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py
index 1fe358c79..58f587c91 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py
@@ -184,6 +184,7 @@ def create_streaming_feature_extractor(sample_rate) -> OnlineFeature:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return OnlineFbank(opts)
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py
index a6c69d54f..c8aae04e8 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py
@@ -326,6 +326,7 @@ def create_streaming_feature_extractor() -> OnlineFeature:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return OnlineFbank(opts)
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py
index 74da9e6c8..1047100fc 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py
@@ -276,6 +276,7 @@ def create_streaming_feature_extractor() -> OnlineFeature:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return OnlineFbank(opts)
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py
index f5d894a7b..aaed7d31f 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py
@@ -623,6 +623,7 @@ def create_streaming_feature_extractor() -> Fbank:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return Fbank(opts)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py
index c07956243..5350a54da 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py
@@ -266,6 +266,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
index 119fcf1fd..42c3a5d7f 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py
@@ -251,6 +251,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
index f989d9bc0..03472e2c3 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
@@ -615,6 +615,7 @@ def create_streaming_feature_extractor() -> Fbank:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return Fbank(opts)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py
index 728b09104..f4ec17221 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py
@@ -267,6 +267,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py
index 3eeaa5397..5bab70fb0 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py
@@ -255,6 +255,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py
index 06159e56a..06397965d 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py
@@ -298,6 +298,7 @@ def create_streaming_feature_extractor() -> OnlineFeature:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return OnlineFbank(opts)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py
index 5d6d97320..dcff088e2 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py
@@ -254,6 +254,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py
index cbbc77928..6166049ae 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py
@@ -217,6 +217,7 @@ def create_streaming_feature_extractor() -> OnlineFeature:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return OnlineFbank(opts)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py
index 487fc2114..df9f6cf3f 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py
@@ -344,6 +344,7 @@ def create_streaming_feature_extractor() -> OnlineFeature:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return OnlineFbank(opts)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py
index 237591a36..d9e7f3578 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py
@@ -266,6 +266,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py
index 29a0d4d1a..e39637bd8 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py
@@ -252,6 +252,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py
index c737e3611..c425b1f46 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py
@@ -615,6 +615,7 @@ def create_streaming_feature_extractor() -> Fbank:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py
index 02f9f1b03..e06404619 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py
@@ -277,6 +277,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
index f4b01fd06..8586c66d6 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
@@ -334,6 +334,7 @@ def decode_dataset(
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
log_interval = 100
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
index 029f55ba0..6923f4d40 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -278,6 +278,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
index 9c4a13606..d17c3467a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
@@ -336,6 +336,7 @@ def decode_dataset(
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
log_interval = 50
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py
index 0669284b3..6d09de6bd 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py
@@ -285,6 +285,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
index de3e03da6..8d12eae28 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py
@@ -368,6 +368,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
index abda4e2d4..05e6a6fba 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
@@ -287,6 +287,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
index e7c1affc2..5e1acd735 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
@@ -337,6 +337,7 @@ def decode_dataset(
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
log_interval = 50
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
index e966aa4b1..229b52e5b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
@@ -353,6 +353,7 @@ def decode_dataset(
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
log_interval = 50
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py
index 6e290e799..2432c6010 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py
@@ -326,6 +326,7 @@ def create_streaming_feature_extractor() -> OnlineFeature:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return OnlineFbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py
index 304fa8693..a9ce75a7b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py
@@ -251,6 +251,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index f65f47fc2..8478a65fb 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -353,6 +353,7 @@ def decode_dataset(
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
log_interval = 50
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py
index 5af6dae25..88a05e09d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py
@@ -225,6 +225,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py
index 86c922cda..4bf11ac24 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py
@@ -260,6 +260,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py
index 280b95984..83dc29324 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py
@@ -224,6 +224,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py
index d50d231d5..d1b7eec65 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py
@@ -280,6 +280,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py
index 78e0fa778..323ba2642 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py
@@ -260,6 +260,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
index 904c1deae..1e638aa7d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py
@@ -298,6 +298,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py
index da2c6a39a..a39fdee54 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py
@@ -224,6 +224,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py
index 653c25e06..80604ef4a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py
@@ -280,6 +280,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py
index 494a34d97..0ff110370 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py
@@ -381,6 +381,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py
index 5d240cf30..a82f3562b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py
@@ -260,6 +260,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py
index 914107526..b98756a54 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py
@@ -298,6 +298,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py
index c8301b2da..7116b10fb 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py
@@ -231,6 +231,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py
index f2ac1914d..d714670cf 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py
@@ -186,6 +186,7 @@ def create_streaming_feature_extractor(sample_rate) -> OnlineFeature:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return OnlineFbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
index 04861ea37..298d1889b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
@@ -382,6 +382,7 @@ def create_streaming_feature_extractor() -> OnlineFeature:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return OnlineFbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py
index bc42e8d05..aa2dd17fb 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py
@@ -260,6 +260,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py
index 883fdcbdd..999f7e0b4 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py
@@ -335,6 +335,7 @@ def create_streaming_feature_extractor() -> OnlineFeature:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return OnlineFbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
index a0f54b6e1..e27fb4e63 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
@@ -320,6 +320,7 @@ def decode_dataset(
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
log_interval = 50
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py
index 129497d5a..3ce2953c3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py
@@ -225,6 +225,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py
index 64b38c9d5..c29b8d8c9 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py
@@ -260,6 +260,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
index fde724866..b3dfab64a 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py
@@ -196,6 +196,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py
index 3888d3544..0cd876551 100755
--- a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py
@@ -224,6 +224,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py
index 6f2cbaabd..92dea3aa1 100755
--- a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py
@@ -280,6 +280,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py
index 981039b8f..5c6956324 100755
--- a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py
@@ -262,6 +262,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py
index a06d6d684..7698ada79 100755
--- a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py
@@ -298,6 +298,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py
index c2413f5de..4d9bbf4b1 100755
--- a/egs/librispeech/ASR/transducer/pretrained.py
+++ b/egs/librispeech/ASR/transducer/pretrained.py
@@ -235,6 +235,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py
index 5898dd0f5..3b86e319e 100755
--- a/egs/librispeech/ASR/transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py
@@ -247,6 +247,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
index b69b347ef..2de4182f1 100755
--- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py
@@ -247,6 +247,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
index 4f29d6f1f..83094ea51 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py
@@ -247,6 +247,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained.py b/egs/librispeech/ASR/zipformer/jit_pretrained.py
index a41fbc1c9..52dfd3fb6 100755
--- a/egs/librispeech/ASR/zipformer/jit_pretrained.py
+++ b/egs/librispeech/ASR/zipformer/jit_pretrained.py
@@ -222,6 +222,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py
index 660a4bfc6..fcd07ae34 100755
--- a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py
+++ b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py
@@ -285,6 +285,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py b/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py
index d4ceacefd..eade5a854 100755
--- a/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py
+++ b/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py
@@ -167,6 +167,7 @@ def create_streaming_feature_extractor(sample_rate) -> OnlineFeature:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return OnlineFbank(opts)
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py
index 44546cae5..dd47c0eb6 100755
--- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py
+++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py
@@ -318,6 +318,7 @@ def create_streaming_feature_extractor() -> OnlineFeature:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return OnlineFbank(opts)
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py
index e7c4f40ee..e011c4b24 100755
--- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py
+++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py
@@ -413,6 +413,7 @@ def create_streaming_feature_extractor() -> OnlineFeature:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return OnlineFbank(opts)
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained.py b/egs/librispeech/ASR/zipformer/onnx_pretrained.py
index 334376093..662392b5f 100755
--- a/egs/librispeech/ASR/zipformer/onnx_pretrained.py
+++ b/egs/librispeech/ASR/zipformer/onnx_pretrained.py
@@ -369,6 +369,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py
index eb5cee9cd..ecca758f2 100755
--- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py
+++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py
@@ -161,6 +161,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py
index 683a7dc20..a77c3bf2a 100755
--- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py
+++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py
@@ -225,6 +225,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
logging.info(f"Loading H from {args.H}")
H = kaldifst.StdVectorFst.read(args.H)
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py
index 0b94bfa65..6ef944514 100755
--- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py
+++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py
@@ -223,6 +223,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
logging.info(f"Loading HL from {args.HL}")
HL = kaldifst.StdVectorFst.read(args.HL)
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py
index 93569142a..ccb3107ea 100755
--- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py
+++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py
@@ -223,6 +223,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
logging.info(f"Loading HLG from {args.HLG}")
HLG = kaldifst.StdVectorFst.read(args.HLG)
diff --git a/egs/librispeech/ASR/zipformer/pretrained.py b/egs/librispeech/ASR/zipformer/pretrained.py
index 3104b6084..de0652893 100755
--- a/egs/librispeech/ASR/zipformer/pretrained.py
+++ b/egs/librispeech/ASR/zipformer/pretrained.py
@@ -303,6 +303,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/zipformer/pretrained_ctc.py b/egs/librispeech/ASR/zipformer/pretrained_ctc.py
index 9dff2e6fc..408d13576 100755
--- a/egs/librispeech/ASR/zipformer/pretrained_ctc.py
+++ b/egs/librispeech/ASR/zipformer/pretrained_ctc.py
@@ -304,6 +304,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
index c9ef16ffa..6990c90a0 100755
--- a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
+++ b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py
@@ -259,6 +259,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/librispeech/ASR/zipformer_mmi/pretrained.py b/egs/librispeech/ASR/zipformer_mmi/pretrained.py
index 3ba4da5dd..1e7afc777 100755
--- a/egs/librispeech/ASR/zipformer_mmi/pretrained.py
+++ b/egs/librispeech/ASR/zipformer_mmi/pretrained.py
@@ -282,6 +282,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/mgb2/ASR/conformer_ctc/pretrained.py b/egs/mgb2/ASR/conformer_ctc/pretrained.py
index d30ca98d8..0ab2af527 100755
--- a/egs/mgb2/ASR/conformer_ctc/pretrained.py
+++ b/egs/mgb2/ASR/conformer_ctc/pretrained.py
@@ -287,6 +287,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py
index 77ba0873b..81a16f0ff 100755
--- a/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py
@@ -249,6 +249,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/multi_zh-hans/ASR/zipformer/pretrained.py b/egs/multi_zh-hans/ASR/zipformer/pretrained.py
index 69ff382da..c15db11f7 100755
--- a/egs/multi_zh-hans/ASR/zipformer/pretrained.py
+++ b/egs/multi_zh-hans/ASR/zipformer/pretrained.py
@@ -303,6 +303,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/multi_zh_en/ASR/zipformer/pretrained.py b/egs/multi_zh_en/ASR/zipformer/pretrained.py
index 676272e1f..2fcde550b 100755
--- a/egs/multi_zh_en/ASR/zipformer/pretrained.py
+++ b/egs/multi_zh_en/ASR/zipformer/pretrained.py
@@ -306,6 +306,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
index 3305f5bd3..8a74ee745 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
@@ -248,6 +248,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py
index a23e2a04f..8c966a2f6 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py
@@ -226,6 +226,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py
index f365986f6..6e07b5949 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py
@@ -261,6 +261,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
index 8a89c3578..9e58fed00 100644
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py
@@ -256,6 +256,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
index 81afd6a4e..5300fe764 100644
--- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py
+++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
@@ -270,6 +270,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
index 3fdf3b855..0d77bc512 100644
--- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py
@@ -196,6 +196,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
index 98c746ce5..f06c8c211 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py
@@ -196,6 +196,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
index f90dd2b43..aee1a2175 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py
@@ -285,6 +285,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
index c3d67ad92..642de72d7 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -238,6 +238,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py
index c31db6859..cca26feb0 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py
@@ -327,6 +327,7 @@ def create_streaming_feature_extractor() -> OnlineFeature:
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
return OnlineFbank(opts)
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py
index c784853ee..4b4ddd332 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py
@@ -376,6 +376,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
index 1cac20435..17428e19d 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py
@@ -238,6 +238,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index 3a4dc3cb8..27a9b1714 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -378,6 +378,7 @@ def decode_dataset(
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
log_interval = 100
diff --git a/egs/wenetspeech/ASR/zipformer/streaming_decode.py b/egs/wenetspeech/ASR/zipformer/streaming_decode.py
index 94c5fae5f..96f339b07 100755
--- a/egs/wenetspeech/ASR/zipformer/streaming_decode.py
+++ b/egs/wenetspeech/ASR/zipformer/streaming_decode.py
@@ -572,6 +572,7 @@ def decode_dataset(
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
log_interval = 100
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py
index 74a2210c3..2c106c4cb 100755
--- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py
@@ -249,6 +249,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py
index d05bafcfb..6995ff2ff 100755
--- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py
@@ -260,6 +260,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/yesno/ASR/tdnn/jit_pretrained.py b/egs/yesno/ASR/tdnn/jit_pretrained.py
index 7581ecb83..e29415ffb 100755
--- a/egs/yesno/ASR/tdnn/jit_pretrained.py
+++ b/egs/yesno/ASR/tdnn/jit_pretrained.py
@@ -142,6 +142,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py
index ff8c742af..72127aebd 100755
--- a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py
+++ b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py
@@ -164,6 +164,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 23
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py
index 05ba74f9a..f8a057336 100755
--- a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py
+++ b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py
@@ -163,6 +163,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 23
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/yesno/ASR/tdnn/onnx_pretrained.py b/egs/yesno/ASR/tdnn/onnx_pretrained.py
index 72a1d69c8..968a9e9a8 100755
--- a/egs/yesno/ASR/tdnn/onnx_pretrained.py
+++ b/egs/yesno/ASR/tdnn/onnx_pretrained.py
@@ -186,6 +186,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py
index 987c49de6..bea520998 100755
--- a/egs/yesno/ASR/tdnn/pretrained.py
+++ b/egs/yesno/ASR/tdnn/pretrained.py
@@ -164,6 +164,7 @@ def main():
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
fbank = kaldifeat.Fbank(opts)
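For reference, the `opts.mel_opts.high_freq = -400` setting added throughout the hunks above follows the Kaldi convention that a non-positive `high_freq` is an offset from the Nyquist frequency, so with 16 kHz input the top mel bin is capped at roughly 7600 Hz. The snippet below is a minimal sketch of that convention using the same `kaldifeat.FbankOptions` API seen in the hunks; it is illustrative only and not part of the patch.

```python
# Minimal sketch (not from the patch): how the high_freq convention is interpreted.
# In Kaldi-style mel options, high_freq <= 0 means "Nyquist + high_freq",
# so -400 with 16 kHz audio gives an upper mel cutoff of 8000 - 400 = 7600 Hz.
import kaldifeat

opts = kaldifeat.FbankOptions()
opts.frame_opts.samp_freq = 16000
opts.frame_opts.snip_edges = False
opts.mel_opts.num_bins = 80
opts.mel_opts.high_freq = -400  # i.e. about 7600 Hz for 16 kHz input

fbank = kaldifeat.Fbank(opts)
```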
From 716b82cc3ada9e39254acc93465ed85e53d05670 Mon Sep 17 00:00:00 2001
From: Karel Vesely
Date: Fri, 5 Jan 2024 03:21:27 +0100
Subject: [PATCH 031/123] streaming_decode.py, relax the audio range from
[-1,+1] to [-10,+10] (#1448)
- some lhotse AudioTransform classes produce audio signals outside the range [-1,+1];
  for example, Resample produced a peak sample value of 1.0079
- the range [-10,+10] was chosen so that such audio still passes while remaining
  reliably distinguishable from an un-normalized [-32k,+32k] int16-scale signal
  (a sketch of the relaxed check follows this message)
- this is related to: https://github.com/lhotse-speech/lhotse/issues/1254
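A self-contained sketch of the relaxed check, assuming float32 audio as produced by lhotse; the helper name and message wording are illustrative, not taken from the patch:

```python
# Sketch of the relaxed range check applied in the streaming_decode.py scripts:
# float32 audio that is nominally in [-1, 1] may slightly exceed 1 after lhotse
# AudioTransforms (e.g. Resample), while un-normalized int16-scale audio peaks
# near 32768, so a tolerance of 10 separates the two cases reliably.
import numpy as np


def check_audio_is_normalized(audio: np.ndarray, tolerance: float = 10.0) -> None:
    assert audio.dtype == np.float32, audio.dtype
    peak = float(np.abs(audio).max())
    assert peak <= tolerance, (
        f"Audio should be normalized to [-1, 1]; got peak {peak:.4f}. "
        f"The tolerance of {tolerance} only guards against int16-scale input."
    )
```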
---
.../streaming_decode.py | 7 ++++++-
egs/aishell/ASR/zipformer/streaming_decode.py | 12 ++++++------
.../streaming_decode.py | 7 ++++++-
egs/gigaspeech/ASR/zipformer/streaming_decode.py | 7 ++++++-
.../streaming_decode.py | 8 +++++++-
.../streaming_decode.py | 8 +++++++-
.../lstm_transducer_stateless/streaming_decode.py | 8 +++++++-
.../lstm_transducer_stateless3/streaming_decode.py | 8 +++++++-
.../pruned_transducer_stateless/streaming_decode.py | 7 ++++++-
.../pruned_transducer_stateless2/streaming_decode.py | 7 ++++++-
.../pruned_transducer_stateless3/streaming_decode.py | 7 ++++++-
.../pruned_transducer_stateless4/streaming_decode.py | 7 ++++++-
.../pruned_transducer_stateless5/streaming_decode.py | 7 ++++++-
.../streaming_decode.py | 7 ++++++-
.../streaming_decode.py | 7 ++++++-
egs/librispeech/ASR/zipformer/streaming_decode.py | 7 ++++++-
.../pruned_transducer_stateless5/streaming_decode.py | 8 ++++++++
egs/wenetspeech/ASR/zipformer/streaming_decode.py | 12 ++++++------
18 files changed, 114 insertions(+), 27 deletions(-)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
index aa0e07c83..a4b5cd588 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
@@ -342,7 +342,12 @@ def decode_dataset(
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
- assert audio.max() <= 1, "Should be normalized to [-1, 1])"
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
diff --git a/egs/aishell/ASR/zipformer/streaming_decode.py b/egs/aishell/ASR/zipformer/streaming_decode.py
index f54ffbd3c..6a7ef2750 100755
--- a/egs/aishell/ASR/zipformer/streaming_decode.py
+++ b/egs/aishell/ASR/zipformer/streaming_decode.py
@@ -597,12 +597,12 @@ def decode_dataset(
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
- if audio.max() > 1:
- logging.warning(
- f"The audio should be normalized to [-1, 1], audio.max : {audio.max()}."
- f"Clipping to [-1, 1]."
- )
- audio = np.clip(audio, -1, 1)
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
index 7252665a7..6a249dd3f 100755
--- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
+++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
@@ -362,7 +362,12 @@ def decode_dataset(
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
- assert audio.max() <= 1, "Should be normalized to [-1, 1])"
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
diff --git a/egs/gigaspeech/ASR/zipformer/streaming_decode.py b/egs/gigaspeech/ASR/zipformer/streaming_decode.py
index 09df2935c..7cada8c9d 100755
--- a/egs/gigaspeech/ASR/zipformer/streaming_decode.py
+++ b/egs/gigaspeech/ASR/zipformer/streaming_decode.py
@@ -578,7 +578,12 @@ def decode_dataset(
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
- assert audio.max() <= 1, "Should be normalized to [-1, 1])"
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
index 9b8b4cce2..12953c74c 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
@@ -681,8 +681,14 @@ def decode_dataset(
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
+
# The trained model is using normalized samples
- assert audio.max() <= 1, "Should be normalized to [-1, 1])"
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
feature = fbank(samples)
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py
index aaed7d31f..ddc7dbef1 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py
@@ -681,8 +681,14 @@ def decode_dataset(
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
+
# The trained model is using normalized samples
- assert audio.max() <= 1, "Should be normalized to [-1, 1])"
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
feature = fbank(samples)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
index 03472e2c3..14cb0fdfe 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
@@ -673,8 +673,14 @@ def decode_dataset(
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
+
# The trained model is using normalized samples
- assert audio.max() <= 1, "Should be normalized to [-1, 1])"
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
feature = fbank(samples)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py
index c425b1f46..f57bdea67 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py
@@ -673,8 +673,14 @@ def decode_dataset(
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
+
# The trained model is using normalized samples
- assert audio.max() <= 1, "Should be normalized to [-1, 1])"
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
feature = fbank(samples)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
index 8586c66d6..4726d9fad 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
@@ -359,7 +359,12 @@ def decode_dataset(
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
- assert audio.max() <= 1, "Should be normalized to [-1, 1])"
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
index d17c3467a..381561359 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
@@ -361,7 +361,12 @@ def decode_dataset(
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
- assert audio.max() <= 1, "Should be normalized to [-1, 1])"
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
index 5e1acd735..9113cfaa9 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
@@ -362,7 +362,12 @@ def decode_dataset(
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
- assert audio.max() <= 1, "Should be normalized to [-1, 1])"
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
index 229b52e5b..f205ad42f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
@@ -378,7 +378,12 @@ def decode_dataset(
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
- assert audio.max() <= 1, "Should be normalized to [-1, 1])"
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index 8478a65fb..1d980f10e 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -378,7 +378,12 @@ def decode_dataset(
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
- assert audio.max() <= 1, "Should be normalized to [-1, 1])"
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
index e27fb4e63..0961e0d7b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
@@ -345,7 +345,12 @@ def decode_dataset(
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
- assert audio.max() <= 1, "Should be normalized to [-1, 1])"
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py
index 2904f086c..cc2787d76 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py
@@ -345,7 +345,12 @@ def decode_dataset(
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
- assert audio.max() <= 1, "Should be normalized to [-1, 1])"
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py
index 904caf8af..8087c1460 100755
--- a/egs/librispeech/ASR/zipformer/streaming_decode.py
+++ b/egs/librispeech/ASR/zipformer/streaming_decode.py
@@ -577,7 +577,12 @@ def decode_dataset(
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
- assert audio.max() <= 1, "Should be normalized to [-1, 1])"
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index 27a9b1714..b396aa9b8 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -402,6 +402,14 @@ def decode_dataset(
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
+ # The trained model is using normalized samples
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
+
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
diff --git a/egs/wenetspeech/ASR/zipformer/streaming_decode.py b/egs/wenetspeech/ASR/zipformer/streaming_decode.py
index 96f339b07..cb2cf7d35 100755
--- a/egs/wenetspeech/ASR/zipformer/streaming_decode.py
+++ b/egs/wenetspeech/ASR/zipformer/streaming_decode.py
@@ -597,12 +597,12 @@ def decode_dataset(
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
- if audio.max() > 1:
- logging.warning(
- f"The audio should be normalized to [-1, 1], audio.max : {audio.max()}."
- f"Clipping to [-1, 1]."
- )
- audio = np.clip(audio, -1, 1)
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
From b9b56eb879e694684156b6ba441a1c665ff26e19 Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Mon, 8 Jan 2024 14:28:07 +0800
Subject: [PATCH 032/123] Minor fixes to the VCTK data prep scripts (#1441)
* Update prepare.sh
---
egs/vctk/TTS/prepare.sh | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/egs/vctk/TTS/prepare.sh b/egs/vctk/TTS/prepare.sh
index 87150ad31..152c7b168 100755
--- a/egs/vctk/TTS/prepare.sh
+++ b/egs/vctk/TTS/prepare.sh
@@ -7,6 +7,7 @@ set -eou pipefail
stage=0
stop_stage=100
+use_edinburgh_vctk_url=true
dl_dir=$PWD/download
@@ -44,7 +45,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
# ln -sfv /path/to/VCTK $dl_dir/VCTK
#
if [ ! -d $dl_dir/VCTK ]; then
- lhotse download vctk $dl_dir
+ lhotse download vctk --use-edinburgh-vctk-url ${use_edinburgh_vctk_url} $dl_dir
fi
fi
@@ -54,7 +55,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
# to $dl_dir/VCTK
mkdir -p data/manifests
if [ ! -e data/manifests/.vctk.done ]; then
- lhotse prepare vctk --use-edinburgh-vctk-url true $dl_dir/VCTK data/manifests
+ lhotse prepare vctk --use-edinburgh-vctk-url ${use_edinburgh_vctk_url} $dl_dir/VCTK data/manifests
touch data/manifests/.vctk.done
fi
fi
From 5445ea6df6e250f8f1e9b2df3bb9e54afe104f97 Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Mon, 8 Jan 2024 15:09:21 +0800
Subject: [PATCH 033/123] Use shuffled LibriSpeech cuts instead (#1450)
* use shuffled LibriSpeech cuts instead (a sketch of the resulting selection logic follows this message)
* leave the old code in comments for reference
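The cut-selection logic that the train.py hunks below converge on can be summarized by the following sketch; `train_all_shuf_cuts()` and `train_clean_100_cuts()` are the existing icefall data-module helpers, while the wrapper function itself is illustrative only:

```python
# Illustrative wrapper (not from the patch) around the icefall data-module helpers.
def select_train_cuts(librispeech, full_libri: bool):
    if full_libri:
        # Pre-shuffled combination of clean-100, clean-360 and other-500.
        return librispeech.train_all_shuf_cuts()
    # Otherwise use only the 100-hour clean subset.
    return librispeech.train_clean_100_cuts()
```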
---
egs/librispeech/ASR/conformer_ctc3/train.py | 15 ++++++++++++---
egs/librispeech/ASR/conformer_mmi/train.py | 16 +++++++++++++---
.../ASR/lstm_transducer_stateless3/train.py | 15 ++++++++++++---
.../ASR/pruned2_knowledge/train.py | 15 ++++++++++++---
.../train.py | 19 ++++++++++++++++---
.../train.py | 11 ++++++++---
egs/librispeech/ASR/zipformer/train.py | 15 ++++++++++++---
egs/librispeech/ASR/zipformer_mmi/train.py | 8 +++++---
8 files changed, 90 insertions(+), 24 deletions(-)
diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py
index 2cd223945..a2f1125ca 100755
--- a/egs/librispeech/ASR/conformer_ctc3/train.py
+++ b/egs/librispeech/ASR/conformer_ctc3/train.py
@@ -952,10 +952,19 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
- train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
- train_cuts += librispeech.train_clean_360_cuts()
- train_cuts += librispeech.train_other_500_cuts()
+ train_cuts = librispeech.train_all_shuf_cuts()
+
+ # previously we used the following code to load all training cuts
+ # strictly speaking, shuffled training cuts should be used instead
+ # but we leave the code here to demonstrate that there is an option
+ # like this to combine multiple cutsets
+
+ # train_cuts = librispeech.train_clean_100_cuts()
+ # train_cuts += librispeech.train_clean_360_cuts()
+ # train_cuts += librispeech.train_other_500_cuts()
+ else:
+ train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py
index f9f80632e..fe8c85f61 100755
--- a/egs/librispeech/ASR/conformer_mmi/train.py
+++ b/egs/librispeech/ASR/conformer_mmi/train.py
@@ -771,10 +771,20 @@ def run(rank, world_size, args):
valid_ali = None
librispeech = LibriSpeechAsrDataModule(args)
- train_cuts = librispeech.train_clean_100_cuts()
+
if params.full_libri:
- train_cuts += librispeech.train_clean_360_cuts()
- train_cuts += librispeech.train_other_500_cuts()
+ train_cuts = librispeech.train_all_shuf_cuts()
+
+ # previously we used the following code to load all training cuts,
+ # strictly speaking, shuffled training cuts should be used instead,
+ # but we leave the code here to demonstrate that there is an option
+ # like this to combine multiple cutsets
+
+ # train_cuts = librispeech.train_clean_100_cuts()
+ # train_cuts += librispeech.train_clean_360_cuts()
+ # train_cuts += librispeech.train_other_500_cuts()
+ else:
+ train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py
index 6ef4c9860..2c1cef3a3 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py
@@ -989,10 +989,19 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
- train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
- train_cuts += librispeech.train_clean_360_cuts()
- train_cuts += librispeech.train_other_500_cuts()
+ train_cuts = librispeech.train_all_shuf_cuts()
+
+ # previously we used the following code to load all training cuts,
+ # strictly speaking, shuffled training cuts should be used instead,
+ # but we leave the code here to demonstrate that there is an option
+ # like this to combine multiple cutsets
+
+ # train_cuts = librispeech.train_clean_100_cuts()
+ # train_cuts += librispeech.train_clean_360_cuts()
+ # train_cuts += librispeech.train_other_500_cuts()
+ else:
+ train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py
index a4899f7bd..931341cc4 100755
--- a/egs/librispeech/ASR/pruned2_knowledge/train.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/train.py
@@ -817,10 +817,19 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
- train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
- train_cuts += librispeech.train_clean_360_cuts()
- train_cuts += librispeech.train_other_500_cuts()
+ train_cuts = librispeech.train_all_shuf_cuts()
+
+ # previously we used the following code to load all training cuts,
+ # strictly speaking, shuffled training cuts should be used instead,
+ # but we leave the code here to demonstrate that there is an option
+ # like this to combine multiple cutsets
+
+ # train_cuts = librispeech.train_clean_100_cuts()
+ # train_cuts += librispeech.train_clean_360_cuts()
+ # train_cuts += librispeech.train_other_500_cuts()
+ else:
+ train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
index 2d915ff87..e1bdce49d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -1038,13 +1038,26 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
+ assert not (
+ params.mini_libri and params.full_libri
+ ), f"Cannot set both mini-libri and full-libri flags to True, now mini-libri {params.mini_libri} and full-libri {params.full_libri}"
+
if params.mini_libri:
train_cuts = librispeech.train_clean_5_cuts()
else:
- train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
- train_cuts += librispeech.train_clean_360_cuts()
- train_cuts += librispeech.train_other_500_cuts()
+ train_cuts = librispeech.train_all_shuf_cuts()
+
+ # previously we used the following code to load all training cuts,
+ # strictly speaking, shuffled training cuts should be used instead,
+ # but we leave the code here to demonstrate that there is an option
+ # like this to combine multiple cutsets
+
+ # train_cuts = librispeech.train_clean_100_cuts()
+ # train_cuts += librispeech.train_clean_360_cuts()
+ # train_cuts += librispeech.train_other_500_cuts()
+ else:
+ train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py
index 565dc7a16..1642ef4b7 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py
@@ -1150,10 +1150,15 @@ def run(rank, world_size, args):
librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
- train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
- train_cuts += librispeech.train_clean_360_cuts()
- train_cuts += librispeech.train_other_500_cuts()
+ train_cuts = librispeech.train_all_shuf_cuts()
+
+ # previously we used the following code to load all training cuts,
+ # strictly speaking, shuffled training cuts should be used instead,
+ # but we leave the code here to demonstrate that there is an option
+ # like this to combine multiple cutsets
+ else:
+ train_cuts = librispeech.train_clean_100_cuts()
train_cuts = filter_short_and_long_utterances(train_cuts, sp)
diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index 7009f3346..3ccf7d2f1 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -1174,10 +1174,19 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
- train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
- train_cuts += librispeech.train_clean_360_cuts()
- train_cuts += librispeech.train_other_500_cuts()
+ train_cuts = librispeech.train_all_shuf_cuts()
+
+ # previously we used the following code to load all training cuts,
+ # strictly speaking, shuffled training cuts should be used instead,
+ # but we leave the code here to demonstrate that there is an option
+ # like this to combine multiple cutsets
+
+ # train_cuts = librispeech.train_clean_100_cuts()
+ # train_cuts += librispeech.train_clean_360_cuts()
+ # train_cuts += librispeech.train_other_500_cuts()
+ else:
+ train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py
index 4b50acdde..dd8949523 100755
--- a/egs/librispeech/ASR/zipformer_mmi/train.py
+++ b/egs/librispeech/ASR/zipformer_mmi/train.py
@@ -990,11 +990,13 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
- # train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
- # train_cuts += librispeech.train_clean_360_cuts()
- # train_cuts += librispeech.train_other_500_cuts()
train_cuts = librispeech.train_all_shuf_cuts()
+
+ # previously we used the following code to load all training cuts,
+ # strictly speaking, shuffled training cuts should be used instead,
+ # but we leave the code here to demonstrate that there is an option
+ # like this to combine multiple cutsets
else:
train_cuts = librispeech.train_clean_100_cuts()
From e2fcb42f5f176d9e39eb38506ab99d0a3adaf202 Mon Sep 17 00:00:00 2001
From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com>
Date: Tue, 9 Jan 2024 15:41:37 +0800
Subject: [PATCH 034/123] fix typo (#1455)
---
.../RNN-LM/librispeech/lm-training.rst | 27 +++++++------------
1 file changed, 9 insertions(+), 18 deletions(-)
diff --git a/docs/source/recipes/RNN-LM/librispeech/lm-training.rst b/docs/source/recipes/RNN-LM/librispeech/lm-training.rst
index 46499a374..e0c90f2a6 100644
--- a/docs/source/recipes/RNN-LM/librispeech/lm-training.rst
+++ b/docs/source/recipes/RNN-LM/librispeech/lm-training.rst
@@ -4,7 +4,7 @@ Train an RNN language model
======================================
If you have enough text data, you can train a neural network language model (NNLM) to improve
-the WER of your E2E ASR system. This tutorial shows you how to train an RNNLM from
+the WER of your E2E ASR system. This tutorial shows you how to train an RNNLM from
scratch.
.. HINT::
@@ -15,23 +15,23 @@ scratch.
.. note::
This tutorial is based on the LibriSpeech recipe. Please check it out for the necessary
- python scripts for this tutorial. We use the LibriSpeech LM-corpus as the LM training set
+ python scripts for this tutorial. We use the LibriSpeech LM-corpus as the LM training set
for illustration purpose. You can also collect your own data. The data format is quite simple:
each line should contain a complete sentence, and words should be separated by space.
-First, let's download the training data for the RNNLM. This can be done via the
+First, let's download the training data for the RNNLM. This can be done via the
following command:
.. code-block:: bash
- $ wget https://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz
+ $ wget https://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz
$ gzip -d librispeech-lm-norm.txt.gz
As we are training a BPE-level RNNLM, we need to tokenize the training text, which requires a
BPE tokenizer. This can be achieved by executing the following command:
.. code-block:: bash
-
+
$ # if you don't have the BPE
$ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
$ cd icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500
@@ -56,11 +56,11 @@ sentence length.
--out-statistics data/lang_bpe_500/lm_data_stats.txt
-The aforementioned steps can be repeated to create a a validation set for you RNNLM. Let's say
-you have a validation set in ``valid.txt``, you can just set ``--lm-data valid.txt``
+The aforementioned steps can be repeated to create a a validation set for you RNNLM. Let's say
+you have a validation set in ``valid.txt``, you can just set ``--lm-data valid.txt``
and ``--lm-archive data/lang_bpe_500/lm-data-valid.pt`` when calling ``./local/prepare_lm_training_data.py``.
-After completing the previous steps, the training and testing sets for training RNNLM are ready.
+After completing the previous steps, the training and testing sets for training RNNLM are ready.
The next step is to train the RNNLM model. The training command is as follows:
.. code-block:: bash
@@ -77,7 +77,7 @@ The next step is to train the RNNLM model. The training command is as follows:
--use-fp16 0 \
--tie-weights 1 \
--embedding-dim 2048 \
- --hidden_dim 2048 \
+ --hidden-dim 2048 \
--num-layers 3 \
--batch-size 300 \
--lm-data rnn_lm/data/lang_bpe_500/sorted_lm_data.pt \
@@ -93,12 +93,3 @@ The next step is to train the RNNLM model. The training command is as follows:
.. note::
The training of RNNLM can take a long time (usually a couple of days).
-
-
-
-
-
-
-
-
-
From 398401ed277d4f895f624a95919c57edbbde4cba Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sun, 14 Jan 2024 14:38:41 +0800
Subject: [PATCH 035/123] Update kaldifeat installation doc (#1460)
---
docs/source/for-dummies/environment-setup.rst | 4 ++--
docs/source/for-dummies/model-export.rst | 6 +++---
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/docs/source/for-dummies/environment-setup.rst b/docs/source/for-dummies/environment-setup.rst
index 0cb8ecc1d..a68e9d3ed 100644
--- a/docs/source/for-dummies/environment-setup.rst
+++ b/docs/source/for-dummies/environment-setup.rst
@@ -66,13 +66,13 @@ to install dependencies of `icefall`_:
pip install torch==2.0.0+cpu torchaudio==2.0.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
- # If you are using macOS or Windows, please use the following command to install torch and torchaudio
+ # If you are using macOS, please use the following command to install torch and torchaudio
# pip install torch==2.0.0 torchaudio==2.0.0 -f https://download.pytorch.org/whl/torch_stable.html
# Now install k2
# Please refer to https://k2-fsa.github.io/k2/installation/from_wheels.html#linux-cpu-example
- pip install k2==1.24.3.dev20230726+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu.html
+ pip install k2==1.24.4.dev20231220+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu.html
# Install the latest version of lhotse
diff --git a/docs/source/for-dummies/model-export.rst b/docs/source/for-dummies/model-export.rst
index 079ebc712..352a0dc90 100644
--- a/docs/source/for-dummies/model-export.rst
+++ b/docs/source/for-dummies/model-export.rst
@@ -85,7 +85,7 @@ We can also use it to decode files with the following command:
# Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html
# for how to install kaldifeat
- pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
+ pip install kaldifeat==1.25.3.dev20231221+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
./tdnn/pretrained.py \
--checkpoint ./tdnn/exp/pretrained.pt \
@@ -162,7 +162,7 @@ To use ``tdnn/exp/cpu_jit.pt`` with `icefall`_ to decode files, we can use:
# Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html
# for how to install kaldifeat
- pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
+ pip install kaldifeat==1.25.3.dev20231221+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
./tdnn/jit_pretrained.py \
@@ -249,7 +249,7 @@ To use the generated ONNX model files for decoding with `onnxruntime`_, we can u
# Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html
# for how to install kaldifeat
- pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
+ pip install kaldifeat==1.25.3.dev20231221+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html
./tdnn/onnx_pretrained.py \
--nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \
From 7bdde9174c7c95a32a10d6dcbc3764ecb4873b1d Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Tue, 16 Jan 2024 21:08:35 +0800
Subject: [PATCH 036/123] A Zipformer recipe with Byte-level BPE for Aishell-1
(#1464)
* init commit
* Update train.py
* Update decode.py
* Update RESULTS.md
* added `vocab_size`
* removed unused softlinks
* added scripts for testing pretrained models
* set `bpe_model` as required
* re-org the bbpe recipe for aishell (a byte-level BPE tokenization sketch follows this message)
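For orientation, the byte-level BPE text handling this recipe depends on looks roughly like the sketch below. It uses the helpers imported by `decode_bbpe.py` (`byte_encode`, `smart_byte_decode`, `tokenize_by_CJK_char`); the exact call pattern and the sample text are assumptions, not quotes from the recipe:

```python
# Hedged sketch of byte-level BPE tokenization; byte_encode, smart_byte_decode
# and tokenize_by_CJK_char are the icefall helpers imported by decode_bbpe.py.
import sentencepiece as spm

from icefall import byte_encode, smart_byte_decode, tokenize_by_CJK_char

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bbpe_500/bbpe.model")

text = "这是一个测试"
# Separate CJK characters, then map UTF-8 bytes to printable symbols so a
# standard SentencePiece BPE model can operate on them.
token_ids = sp.encode(byte_encode(tokenize_by_CJK_char(text)), out_type=int)

# Decoding reverses the byte mapping back to UTF-8 text.
print(smart_byte_decode(sp.decode(token_ids)))
```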
---
egs/aishell/ASR/RESULTS.md | 56 +-
egs/aishell/ASR/zipformer/decode_bbpe.py | 840 ++++++++++++++++
.../ASR/zipformer/jit_pretrained_bbpe.py | 279 ++++++
egs/aishell/ASR/zipformer/pretrained_bbpe.py | 403 ++++++++
egs/aishell/ASR/zipformer/train_bbpe.py | 942 ++++++++++++++++++
5 files changed, 2518 insertions(+), 2 deletions(-)
create mode 100755 egs/aishell/ASR/zipformer/decode_bbpe.py
create mode 100755 egs/aishell/ASR/zipformer/jit_pretrained_bbpe.py
create mode 100755 egs/aishell/ASR/zipformer/pretrained_bbpe.py
create mode 100755 egs/aishell/ASR/zipformer/train_bbpe.py
diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md
index 0b22f41a1..ff9504274 100644
--- a/egs/aishell/ASR/RESULTS.md
+++ b/egs/aishell/ASR/RESULTS.md
@@ -2,9 +2,61 @@
### Aishell training result (Stateless Transducer)
+#### Zipformer (Byte-level BPE)
+
+[./zipformer](./zipformer/)
+
+It's reworked Zipformer with Pruned RNNT loss, trained with Byte-level BPE, `vocab_size` set to 500.
+
+##### normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M
+
+| | test | dev | comment |
+|------------------------|------|------|-----------------------------------------|
+| greedy search | 4.54 | 4.31 | --epoch 40 --avg 10 |
+| modified beam search | 4.37 | 4.11 | --epoch 40 --avg 10 |
+| fast beam search | 4.43 | 4.17 | --epoch 40 --avg 10 |
+
+```bash
+./prepare.sh
+
+export CUDA_VISIBLE_DEVICES="0,1"
+
+./zipformer/train_bbpe.py \
+ --world-size 2 \
+ --num-epochs 40 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --context-size 2 \
+ --enable-musan 0 \
+ --exp-dir zipformer/exp_bbpe \
+ --max-duration 1000 \
+ --enable-musan 0 \
+ --base-lr 0.045 \
+ --lr-batches 7500 \
+ --lr-epochs 10 \
+ --spec-aug-time-warp-factor 20
+```
+
+Command for decoding is:
+```bash
+for m in greedy_search modified_beam_search fast_beam_search ; do
+ ./zipformer/decode_bbpe.py \
+ --epoch 40 \
+ --avg 10 \
+ --exp-dir ./zipformer_bbpe/exp \
+ --bpe-model data/lang_bbpe_500/bbpe.model \
+ --context-size 2 \
+ --decoding-method $m
+done
+```
+Pretrained models, training logs, decoding logs, tensorboard and decoding results
+are available at
+
+
+
#### Zipformer (Non-streaming)
-[./zipformer](./zipformer)
+[./zipformer](./zipformer/)
It's reworked Zipformer with Pruned RNNT loss.
**Caution**: It uses `--context-size=1`.
@@ -260,7 +312,7 @@ done
Pretrained models, training logs, decoding logs, and decoding results
are available at
-#### Pruned transducer stateless 7 (zipformer)
+#### Pruned transducer stateless 7 (Byte-level BPE)
See
diff --git a/egs/aishell/ASR/zipformer/decode_bbpe.py b/egs/aishell/ASR/zipformer/decode_bbpe.py
new file mode 100755
index 000000000..1ec10b059
--- /dev/null
+++ b/egs/aishell/ASR/zipformer/decode_bbpe.py
@@ -0,0 +1,840 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Mingshuang Luo,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode_bbpe.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp_bbpe \
+ --lang-dir data/lang_bbpe_500 \
+ --bpe-model data/lang_bbpe_500/bbpe.model \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) modified beam search
+./zipformer/decode_bbpe.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp_bbpe \
+ --lang-dir data/lang_bbpe_500 \
+ --bpe-model data/lang_bbpe_500/bbpe.model \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(3) fast beam search (trivial_graph)
+./zipformer/decode_bbpe.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp_bbpe \
+ --lang-dir data/lang_bbpe_500 \
+ --bpe-model data/lang_bbpe_500/bbpe.model \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(4) fast beam search (LG)
+./zipformer/decode_bbpe.py \
+ --epoch 30 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp_bbpe \
+ --lang-dir data/lang_bbpe_500 \
+ --bpe-model data/lang_bbpe_500/bbpe.model \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest oracle WER)
+./zipformer/decode_bbpe.py \
+ --epoch 35 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp_bbpe \
+ --lang-dir data/lang_bbpe_500 \
+ --bpe-model data/lang_bbpe_500/bbpe.model \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import AishellAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from lhotse.cut import Cut
+from train import add_model_arguments, get_model, get_params
+
+from icefall import byte_encode, smart_byte_decode, tokenize_by_CJK_char
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer_bbpe/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bbpe_500/bbpe.model",
+ help="Path to the byte BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bbpe_500/",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - modified_beam_search
+ - fast_beam_search
+ - fast_beam_search_LG
+ - fast_beam_search_nbest_oracle
+ If you use fast_beam_search_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search,
+        fast_beam_search_LG, or fast_beam_search_nbest_oracle.
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding_method is fast_beam_search_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--ilme-scale",
+ type=float,
+ default=0.2,
+ help="""
+ Used only when --decoding_method is fast_beam_search_LG.
+ It specifies the scale for the internal language model estimation.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search, fast_beam_search_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search, fast_beam_search_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--blank-penalty",
+ type=float,
+ default=0.0,
+ help="""
+        The penalty applied to the blank symbol during decoding.
+        Note: It is a positive value subtracted from the blank logits, i.e.,
+        `logits[:, 0] -= blank_penalty` (assuming logits.shape is
+        [batch_size, vocab_size] and the blank id is 0).
+ """,
+ )
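+
+    # Illustrative example of the blank penalty (hypothetical numbers): with
+    # logits = [2.0, 1.5, 0.3] and blank_penalty = 1.0, the blank logit (index 0)
+    # drops to 1.0, so the token at index 1 is then preferred over the blank.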
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ lexicon: Lexicon,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+    - key: It indicates the setting used for decoding. For example,
+           if greedy_search is used, it would be "greedy_search";
+           if beam search with a beam size of 7 is used, it would be
+           "beam_7".
+    - value: It contains the decoding result. `len(value)` equals the
+             batch size. `value[i]` is the decoding result for the i-th
+             utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ decoding_graph:
+        The decoding graph. It can be either a `k2.trivial_graph` or an LG graph.
+        Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
+        or fast_beam_search_nbest_oracle.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ x, x_lens = model.encoder_embed(feature, feature_lens)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+
+ encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ blank_penalty=params.blank_penalty,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.decoding_method == "fast_beam_search_LG":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ blank_penalty=params.blank_penalty,
+ ilme_scale=params.ilme_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([lexicon.word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ ref_texts = []
+ for tx in supervisions["text"]:
+ ref_texts.append(byte_encode(tokenize_by_CJK_char(tx)))
+
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(ref_texts),
+ nbest_scale=params.nbest_scale,
+ blank_penalty=params.blank_penalty,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ blank_penalty=params.blank_penalty,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ blank_penalty=params.blank_penalty,
+ beam=params.beam_size,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ blank_penalty=params.blank_penalty,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ blank_penalty=params.blank_penalty,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(smart_byte_decode(sp.decode(hyp)).split())
+
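+    # The result key encodes the decoding settings, e.g. greedy_search with
+    # blank_penalty=0.0 yields the key "greedy_search_blank_penalty_0.0"
+    # (illustrative example).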
+ key = f"blank_penalty_{params.blank_penalty}"
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search_" + key: hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key += f"_beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ilme_scale_{params.ilme_scale}"
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}_" + key: hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ lexicon: Lexicon,
+ sp: spm.SentencePieceProcessor,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+      lexicon:
+        The lexicon; it provides the word symbol table used by
+        fast_beam_search_LG to map word IDs back to words.
+ sp:
+ SentencePiece model.
+ decoding_graph:
+        The decoding graph. It can be either a `k2.trivial_graph` or an LG graph.
+        Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
+        or fast_beam_search_nbest_oracle.
+ Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if a beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ lexicon=lexicon,
+ decoding_graph=decoding_graph,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = "".join(ref_text.split())
+
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+
+ results_char = []
+ for res in results:
+ results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results_char, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ AishellAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "modified_beam_search",
+ "fast_beam_search",
+ "fast_beam_search_LG",
+ "fast_beam_search_nbest_oracle",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"_ilme_scale_{params.ilme_scale}"
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+ params.suffix += f"-blank-penalty-{params.blank_penalty}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bbpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ lexicon = Lexicon(params.lang_dir)
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ if "fast_beam_search" in params.decoding_method:
+ if "LG" in params.decoding_method:
+ lexicon = Lexicon(params.lang_dir)
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
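+            # Note: the trivial graph accepts any token sequence;
+            # vocab_size - 1 is the largest token id in the output vocabulary.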
+ else:
+ decoding_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ aishell = AishellAsrDataModule(args)
+
+ def remove_short_utt(c: Cut):
+ T = ((c.num_frames - 7) // 2 + 1) // 2
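+        # Worked example (illustrative): a cut with num_frames = 23 gives
+        # T = ((23 - 7) // 2 + 1) // 2 = 4 frames after subsampling.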
+ if T <= 0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
+ )
+ return T > 0
+
+ dev_cuts = aishell.valid_cuts()
+ dev_cuts = dev_cuts.filter(remove_short_utt)
+ dev_dl = aishell.valid_dataloaders(dev_cuts)
+
+ test_cuts = aishell.test_cuts()
+ test_cuts = test_cuts.filter(remove_short_utt)
+ test_dl = aishell.test_dataloaders(test_cuts)
+
+ test_sets = ["dev", "test"]
+ test_dls = [dev_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ lexicon=lexicon,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell/ASR/zipformer/jit_pretrained_bbpe.py b/egs/aishell/ASR/zipformer/jit_pretrained_bbpe.py
new file mode 100755
index 000000000..cd16284f7
--- /dev/null
+++ b/egs/aishell/ASR/zipformer/jit_pretrained_bbpe.py
@@ -0,0 +1,279 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads torchscript models, exported by `torch.jit.script()`
+and uses them to decode waves.
+You can use the following command to get the exported models:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer_bbpe/exp \
+ --bpe ./data/lang_bbpe_500/bbpe.model \
+ --epoch 30 \
+ --avg 9 \
+ --jit 1
+
+Usage of this script:
+
+./zipformer/jit_pretrained_bbpe.py \
+  --nn-model-filename ./zipformer_bbpe/exp/cpu_jit.pt \
+  --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+"""
+
+import argparse
+import logging
+import math
+from typing import List
+
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from torch.nn.utils.rnn import pad_sequence
+
+from icefall import smart_byte_decode
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--nn-model-filename",
+ type=str,
+ required=True,
+ help="Path to the torchscript model cpu_jit.pt",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ required=True,
+ help="""Path to the bbpe.model.""",
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ return parser
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float = 16000
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+ # We use only the first channel
+ ans.append(wave[0].contiguous())
+ return ans
+
+
+def greedy_search(
+ model: torch.jit.ScriptModule,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+) -> List[List[int]]:
+ """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+ Args:
+ model:
+ The transducer model.
+ encoder_out:
+ A 3-D tensor of shape (N, T, C)
+ encoder_out_lens:
+ A 1-D tensor of shape (N,).
+ Returns:
+ Return the decoded results for each utterance.
+ """
+ assert encoder_out.ndim == 3
+ assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+ packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+ input=encoder_out,
+ lengths=encoder_out_lens.cpu(),
+ batch_first=True,
+ enforce_sorted=False,
+ )
+
+ device = encoder_out.device
+ blank_id = model.decoder.blank_id
+
+ batch_size_list = packed_encoder_out.batch_sizes.tolist()
+ N = encoder_out.size(0)
+
+ assert torch.all(encoder_out_lens > 0), encoder_out_lens
+ assert N == batch_size_list[0], (N, batch_size_list)
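+    # batch_size_list[t] is the number of utterances still active at frame t,
+    # e.g. encoder_out_lens of [5, 3, 3] yield batch sizes [3, 3, 3, 1, 1]
+    # (illustrative example).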
+
+ context_size = model.decoder.context_size
+ hyps = [[blank_id] * context_size for _ in range(N)]
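+    # Each hypothesis is primed with `context_size` blanks (e.g. [0, 0] when
+    # context_size=2 and blank_id=0) to feed the stateless decoder.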
+
+ decoder_input = torch.tensor(
+ hyps,
+ device=device,
+ dtype=torch.int64,
+ ) # (N, context_size)
+
+ decoder_out = model.decoder(
+ decoder_input,
+ need_pad=torch.tensor([False]),
+ ).squeeze(1)
+
+ offset = 0
+ for batch_size in batch_size_list:
+ start = offset
+ end = offset + batch_size
+ current_encoder_out = packed_encoder_out.data[start:end]
+        # current_encoder_out's shape: (batch_size, encoder_out_dim)
+ offset = end
+
+ decoder_out = decoder_out[:batch_size]
+
+ logits = model.joiner(
+ current_encoder_out,
+ decoder_out,
+ )
+        # logits' shape: (batch_size, vocab_size)
+
+ assert logits.ndim == 2, logits.shape
+ y = logits.argmax(dim=1).tolist()
+ emitted = False
+ for i, v in enumerate(y):
+ if v != blank_id:
+ hyps[i].append(v)
+ emitted = True
+ if emitted:
+ # update decoder output
+ decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
+ decoder_input = torch.tensor(
+ decoder_input,
+ device=device,
+ dtype=torch.int64,
+ )
+ decoder_out = model.decoder(
+ decoder_input,
+ need_pad=torch.tensor([False]),
+ )
+ decoder_out = decoder_out.squeeze(1)
+
+ sorted_ans = [h[context_size:] for h in hyps]
+ ans = []
+ unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+ for i in range(N):
+ ans.append(sorted_ans[unsorted_indices[i]])
+
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ logging.info(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ model = torch.jit.load(args.nn_model_filename)
+
+ model.eval()
+
+ model.to(device)
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(args.bpe_model)
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = 16000
+ opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {args.sound_files}")
+ waves = read_sound_files(
+ filenames=args.sound_files,
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lengths = [f.size(0) for f in features]
+
+ features = pad_sequence(
+ features,
+ batch_first=True,
+ padding_value=math.log(1e-10),
+ )
+
+ feature_lengths = torch.tensor(feature_lengths, device=device)
+
+ encoder_out, encoder_out_lens = model.encoder(
+ features=features,
+ feature_lengths=feature_lengths,
+ )
+
+ hyps = greedy_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+
+ s = "\n"
+ for filename, hyp in zip(args.sound_files, hyps):
+ words = smart_byte_decode(sp.decode(hyp))
+ s += f"{filename}:\n{words}\n\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aishell/ASR/zipformer/pretrained_bbpe.py b/egs/aishell/ASR/zipformer/pretrained_bbpe.py
new file mode 100755
index 000000000..387bef98a
--- /dev/null
+++ b/egs/aishell/ASR/zipformer/pretrained_bbpe.py
@@ -0,0 +1,403 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads a checkpoint and uses it to decode waves.
+You can generate the checkpoint with the following command:
+
+Note: This is an example for the librispeech dataset; if you are using a
+different dataset, you should change the argument values accordingly.
+
+- For non-streaming model:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp_bbpe \
+ --tokens ./data/lang_bbpe_500/tokens.txt \
+ --epoch 30 \
+ --avg 9
+
+- For streaming model:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp_bbpe \
+ --causal 1 \
+ --tokens ./data/lang_bbpe_500/tokens.txt \
+ --epoch 30 \
+ --avg 9
+
+Usage of this script:
+
+- For non-streaming model:
+
+(1) greedy search
+./zipformer/pretrained_bbpe.py \
+ --checkpoint ./zipformer/exp_bbpe/pretrained.pt \
+  --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(2) modified beam search
+./zipformer/pretrained_bbpe.py \
+ --checkpoint ./zipformer/exp_bbpe/pretrained.pt \
+  --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ --method modified_beam_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) fast beam search
+./zipformer/pretrained_bbpe.py \
+ --checkpoint ./zipformer/exp_bbpe/pretrained.pt \
+  --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ --method fast_beam_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+- For streaming model:
+
+(1) greedy search
+./zipformer/pretrained_bbpe.py \
+ --checkpoint ./zipformer/exp_bbpe/pretrained.pt \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+  --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(2) modified beam search
+./zipformer/pretrained_bbpe.py \
+ --checkpoint ./zipformer/exp_bbpe/pretrained.pt \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+  --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ --method modified_beam_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) fast beam search
+./zipformer/pretrained_bbpe.py \
+ --checkpoint ./zipformer/exp_bbpe/pretrained.pt \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+  --bpe-model ./data/lang_bbpe_500/bbpe.model \
+ --method fast_beam_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+
+You can also use `./zipformer/exp_bbpe/epoch-xx.pt`.
+
+Note: ./zipformer/exp_bbpe/pretrained.pt is generated by ./zipformer/export_bbpe.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_model, get_params
+
+from icefall import smart_byte_decode
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ required=True,
+ help="Path to the checkpoint. "
+ "The checkpoint is assumed to be saved by "
+ "icefall.checkpoint.save_checkpoint().",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ required=True,
+ help="""Path to the bbpe.model.""",
+ )
+
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "sound_files",
+ type=str,
+ nargs="+",
+ help="The input sound file(s) to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. "
+ "The sample rate has to be 16kHz.",
+ )
+
+ parser.add_argument(
+ "--sample-rate",
+ type=int,
+ default=16000,
+ help="The sample rate of the input sound file",
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame. Used only when
+ --method is greedy_search.
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ assert (
+ sample_rate == expected_sample_rate
+ ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+ # We use only the first channel
+ ans.append(wave[0].contiguous())
+ return ans
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = get_params()
+
+ params.update(vars(args))
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bbpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(f"{params}")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+
+ logging.info("Creating model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ checkpoint = torch.load(args.checkpoint, map_location="cpu")
+ model.load_state_dict(checkpoint["model"], strict=False)
+ model.to(device)
+ model.eval()
+
+ logging.info("Constructing Fbank computer")
+ opts = kaldifeat.FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = params.sample_rate
+ opts.mel_opts.num_bins = params.feature_dim
+ opts.mel_opts.high_freq = -400
+
+ fbank = kaldifeat.Fbank(opts)
+
+ logging.info(f"Reading sound files: {params.sound_files}")
+ waves = read_sound_files(
+ filenames=params.sound_files, expected_sample_rate=params.sample_rate
+ )
+ waves = [w.to(device) for w in waves]
+
+ logging.info("Decoding started")
+ features = fbank(waves)
+ feature_lengths = [f.size(0) for f in features]
+
+ features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+ feature_lengths = torch.tensor(feature_lengths, device=device)
+
+ # model forward
+ encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
+
+ num_waves = encoder_out.size(0)
+ hyps = []
+ msg = f"Using {params.method}"
+ logging.info(msg)
+
+ if params.method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(smart_byte_decode(hyp).split())
+ else:
+ for i in range(num_waves):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(f"Unsupported method: {params.method}")
+
+ hyps.append(smart_byte_decode(sp.decode(hyp)).split())
+
+ s = "\n"
+ for filename, hyp in zip(params.sound_files, hyps):
+ words = " ".join(hyp)
+ s += f"{filename}:\n{words}\n\n"
+ logging.info(s)
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/aishell/ASR/zipformer/train_bbpe.py b/egs/aishell/ASR/zipformer/train_bbpe.py
new file mode 100755
index 000000000..a2bf96b29
--- /dev/null
+++ b/egs/aishell/ASR/zipformer/train_bbpe.py
@@ -0,0 +1,942 @@
+#!/usr/bin/env python3
+# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+./zipformer/train_bbpe.py \
+ --world-size 8 \
+ --num-epochs 12 \
+ --start-epoch 1 \
+ --exp-dir zipformer/exp_bbpe \
+ --max-duration 350
+
+# For mixed precision training:
+
+./zipformer/train_bbpe.py \
+ --world-size 8 \
+ --num-epochs 12 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp_bbpe \
+ --max-duration 750
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from typing import Optional, Tuple, Union
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import AishellAsrDataModule
+from lhotse.cut import Cut
+from lhotse.utils import fix_random_seed
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from train import (
+ LRSchedulerType,
+ add_model_arguments,
+ get_adjusted_batch_count,
+ get_model,
+ get_params,
+ load_checkpoint_if_available,
+ save_checkpoint,
+ set_batch_count,
+)
+
+from icefall import byte_encode, diagnostics
+from icefall.checkpoint import remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+ tokenize_by_CJK_char,
+)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer_bbpe/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bbpe_500/bbpe.model",
+ help="Path to the Byte BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.045, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="""The context size in the decoder. 1 means bigram; 2 means tri-gram""",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="""The prune range for rnnt loss, it means how many symbols(context)
+ we are using to compute the loss""",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="""The scale to smooth the loss with lm
+ (output of prediction network) part.""",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="""The scale to smooth the loss with am (output of encoder network) part.""",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="""To get pruning ranges, we will calculate a simple version
+ loss(joiner is just addition), this simple loss also uses for
+ training (as a regularization item). We will scale the simple loss
+ with this parameter before adding to the final loss.""",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
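+
+    # Worked example of the update above (hypothetical numbers): with
+    # average_period=200 and batch_idx_train=1000, the update is
+    # model_avg = 0.2 * model + 0.8 * model_avg.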
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+    Compute the transducer loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+        disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, _ = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ s = params.simple_loss_scale
+        # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
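+        # Worked example (hypothetical numbers): with s = 0.5 and
+        # warm_step = 2000, at batch_idx_train = 1000 this gives
+        # simple_loss_scale = 1.0 - 0.5 * 0.5 = 0.75 and
+        # pruned_loss_scale = 0.1 + 0.9 * 0.5 = 0.55.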
+
+ loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mix precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ cur_batch_idx = params.get("cur_batch_idx", 0)
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+ if batch_idx < cur_batch_idx:
+ continue
+ cur_batch_idx = batch_idx
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ params.cur_batch_idx = batch_idx
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ del params.cur_batch_idx
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bbpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ aishell = AishellAsrDataModule(args)
+
+ train_cuts = aishell.train_cuts()
+ valid_cuts = aishell.valid_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 15 seconds
+ #
+ # Caution: There is a reason to select 15.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 15.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
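+        # Illustrative check: a 10 s cut at a 10 ms frame shift has about
+        # 1000 frames, so T = ((1000 - 7) // 2 + 1) // 2 = 248, which easily
+        # exceeds typical token counts for such an utterance.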
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ def tokenize_and_encode_text(c: Cut):
+ # Text normalize for each sample
+ text = c.supervisions[0].text
+ text = byte_encode(tokenize_by_CJK_char(text))
+ c.supervisions[0].text = text
+ return c
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ train_cuts = train_cuts.map(tokenize_and_encode_text)
+
+ valid_cuts = valid_cuts.map(tokenize_and_encode_text)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = aishell.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_dl = aishell.valid_dataloaders(valid_cuts)
+
+ if False and not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The sentence piece model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ AishellAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
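The filter `remove_short_and_long_utt()` in the training script above rejects cuts whose BPE token count exceeds the number of encoder frames left after subsampling. The standalone sketch below, which is not part of the patch, just works through that arithmetic; the helper names are ours, and the assumption that 100 feature frames correspond to roughly one second (the usual 10 ms frame shift) is an illustration, not something the patch states.

```python
# Minimal sketch of the T >= S check used by remove_short_and_long_utt() above.
# Hypothetical helper names and frame counts; the real script reads num_frames
# from lhotse cuts and tokenizes the supervision text with a BPE model.

def frames_after_subsampling(num_frames: int) -> int:
    # Same expression as the conv subsampling in zipformer.py quoted above.
    return ((num_frames - 7) // 2 + 1) // 2

def keep_for_pruned_rnnt(num_frames: int, num_tokens: int) -> bool:
    # Pruned RNN-T requires T >= S: at least one encoder frame per output token.
    return frames_after_subsampling(num_frames) >= num_tokens

if __name__ == "__main__":
    # A ~1-second cut (about 100 frames at a 10 ms shift) leaves T = 23 frames,
    # so an utterance with more than 23 BPE tokens would be excluded.
    print(frames_after_subsampling(100))   # 23
    print(keep_for_pruned_rnnt(100, 23))   # True
    print(keep_for_pruned_rnnt(100, 24))   # False
```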
From 5dfc3ed7f995f30a4b64e57071a095566f8126ea Mon Sep 17 00:00:00 2001
From: Yifan Yang <64255737+yfyeung@users.noreply.github.com>
Date: Sun, 21 Jan 2024 02:10:42 +0800
Subject: [PATCH 037/123] Fix buffer size of DynamicBucketingSampler (#1468)
* Fix buffer size
* Fix for flake8
---------
Co-authored-by: yifanyeung
---
.../ASR/pruned_transducer_stateless2/asr_datamodule.py | 3 ++-
egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 ++
.../ASR/transducer_stateless_modified-2/asr_datamodule.py | 2 ++
.../ASR/pruned_transducer_stateless5/asr_datamodule.py | 2 ++
.../ASR/pruned_transducer_stateless5/asr_datamodule.py | 3 ++-
.../ASR/pruned_transducer_stateless2/asr_datamodule.py | 3 ++-
.../ASR_v2/pruned_transducer_stateless7/asr_datamodule.py | 2 ++
egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py | 2 ++
egs/ami/SURT/dprnn_zipformer/asr_datamodule.py | 2 ++
.../ASR/pruned_transducer_stateless7/asr_datamodule.py | 2 ++
.../commonvoice_fr.py | 2 ++
egs/csj/ASR/local/utils/asr_datamodule.py | 2 ++
egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py | 2 ++
.../ASR/pruned_transducer_stateless2/asr_datamodule.py | 2 ++
egs/gigaspeech/ASR/zipformer/asr_datamodule.py | 2 ++
egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py | 2 ++
egs/libriheavy/ASR/zipformer/asr_datamodule.py | 2 ++
egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py | 2 ++
egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py | 2 ++
.../ASR/pruned_transducer_stateless3/asr_datamodule.py | 6 ++++++
.../ASR/pruned_transducer_stateless7/gigaspeech.py | 2 ++
egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 ++
egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py | 2 ++
egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py | 2 ++
egs/ljspeech/TTS/vits/tts_datamodule.py | 2 ++
egs/mgb2/ASR/conformer_ctc/asr_datamodule.py | 2 ++
egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py | 2 ++
egs/multi_zh_en/ASR/zipformer/asr_datamodule.py | 2 ++
.../ASR/pruned_transducer_stateless2/asr_datamodule.py | 2 ++
egs/swbd/ASR/conformer_ctc/asr_datamodule.py | 3 ++-
.../ASR/pruned_transducer_stateless5/asr_datamodule.py | 3 ++-
egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py | 2 ++
egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 ++
egs/vctk/TTS/vits/tts_datamodule.py | 2 ++
.../ASR/pruned_transducer_stateless2/asr_datamodule.py | 3 ++-
.../ASR/pruned_transducer_stateless5/asr_datamodule.py | 2 ++
egs/yesno/ASR/tdnn/asr_datamodule.py | 2 ++
37 files changed, 78 insertions(+), 6 deletions(-)
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
index d491996b2..e29dd8ab5 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -288,8 +288,9 @@ class Aidatatang_200zhAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
- buffer_size=50000,
)
else:
logging.info("Using SimpleCutSampler.")
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 6abe6c084..aacbd153d 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -275,6 +275,8 @@ class AishellAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
index cd8dd821c..ed453afd2 100644
--- a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py
@@ -226,6 +226,8 @@ class AsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 8f6a88f59..f9cdfb621 100644
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -296,6 +296,8 @@ class AiShell2AsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
index e6db2651f..c10456da5 100644
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -306,7 +306,8 @@ class Aishell4AsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
- buffer_size=100000,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 5ad80817a..410741215 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -288,7 +288,8 @@ class AlimeetingAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
- buffer_size=30000,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
else:
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py
index 9d288218a..6b56c8a6a 100644
--- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py
@@ -263,6 +263,8 @@ class AlimeetingAsrDataModule:
max_cuts=self.args.max_cuts,
shuffle=False,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
logging.info("About to create train dataloader")
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py
index 79474f1d8..554facfc1 100644
--- a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py
+++ b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py
@@ -269,6 +269,8 @@ class AmiAsrDataModule:
max_cuts=self.args.max_cuts,
shuffle=False,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
logging.info("About to create train dataloader")
diff --git a/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py b/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py
index 1549c1631..ea8b62242 100644
--- a/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py
+++ b/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py
@@ -254,6 +254,8 @@ class AmiAsrDataModule:
max_cuts=self.args.max_cuts,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py
index 546e9f9dd..c40d9419b 100644
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py
@@ -308,6 +308,8 @@ class CommonVoiceAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py
index cafa4111d..79cf86b84 100644
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py
@@ -310,6 +310,8 @@ class CommonVoiceAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/csj/ASR/local/utils/asr_datamodule.py b/egs/csj/ASR/local/utils/asr_datamodule.py
index 042b6ecbf..7bf7bdef0 100644
--- a/egs/csj/ASR/local/utils/asr_datamodule.py
+++ b/egs/csj/ASR/local/utils/asr_datamodule.py
@@ -336,6 +336,8 @@ class CSJAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
index a93e224d5..569978424 100644
--- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
+++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py
@@ -261,6 +261,8 @@ class GigaSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
else:
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index b5b27ce95..40339365c 100644
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -294,6 +294,8 @@ class GigaSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
else:
diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py
index 6adfdbfbb..850ab7c10 100644
--- a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py
+++ b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py
@@ -311,6 +311,8 @@ class GigaSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py b/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py
index c1abdbdb5..500df9ea4 100644
--- a/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py
+++ b/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py
@@ -256,6 +256,8 @@ class LibriCssAsrDataModule:
max_cuts=self.args.max_cuts,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/libriheavy/ASR/zipformer/asr_datamodule.py b/egs/libriheavy/ASR/zipformer/asr_datamodule.py
index df761c1b8..e23c9b1b7 100644
--- a/egs/libriheavy/ASR/zipformer/asr_datamodule.py
+++ b/egs/libriheavy/ASR/zipformer/asr_datamodule.py
@@ -310,6 +310,8 @@ class LibriHeavyAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py
index 690003377..1a4c9a532 100644
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py
@@ -341,6 +341,8 @@ class LibriHeavyAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
else:
diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py
index ee7556e49..be36c06b6 100644
--- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py
@@ -286,6 +286,8 @@ class LibriSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
bucket_method="equal_duration",
drop_last=True,
)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py
index 057624272..87c62789e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py
@@ -223,6 +223,8 @@ class AsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
@@ -256,6 +258,8 @@ class AsrDataModule:
max_duration=self.args.max_duration,
shuffle=False,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=False,
)
logging.info("About to create dev dataloader")
@@ -282,6 +286,8 @@ class AsrDataModule:
max_duration=self.args.max_duration,
shuffle=False,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py
index cd432fd6f..306f30c2f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py
@@ -294,6 +294,8 @@ class GigaSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
else:
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index c500eb3e5..dd9e9ef1f 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -311,6 +311,8 @@ class LibriSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py
index 3acd22ae4..84bd3fc4b 100644
--- a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py
@@ -304,6 +304,8 @@ class LibriSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py
index 2f8e658c5..e1a29bd9c 100644
--- a/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py
+++ b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py
@@ -227,6 +227,8 @@ class LibriSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py
index 81bb9ed13..8ff868bc8 100644
--- a/egs/ljspeech/TTS/vits/tts_datamodule.py
+++ b/egs/ljspeech/TTS/vits/tts_datamodule.py
@@ -196,6 +196,8 @@ class LJSpeechTtsDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py
index 7753d1674..48921d71f 100644
--- a/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py
+++ b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py
@@ -266,6 +266,8 @@ class MGB2AsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py
index 02cfa1346..341579acb 100644
--- a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py
+++ b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py
@@ -297,6 +297,8 @@ class AsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py
index be6e94472..662ae01c5 100644
--- a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py
+++ b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py
@@ -294,6 +294,8 @@ class AsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index cf70fc0f8..7cd6771ce 100644
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -236,6 +236,8 @@ class SPGISpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=False,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
logging.info("About to create train dataloader")
diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py
index ce8634a1d..0f6f02e8d 100644
--- a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py
+++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py
@@ -298,8 +298,9 @@ class SwitchBoardAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
- buffer_size=50000,
)
else:
logging.info("Using SimpleCutSampler.")
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 5269a1778..6f0833db6 100644
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -306,8 +306,9 @@ class TAL_CSASRAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
num_cuts_for_bins_estimate=20000,
- buffer_size=60000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
index d4a9e4bc9..a67cf8d04 100644
--- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
+++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
@@ -256,6 +256,8 @@ class TedLiumAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
else:
diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 5d1b3c367..8606a490b 100644
--- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -222,6 +222,8 @@ class TimitAsrDataModule(DataModule):
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
else:
diff --git a/egs/vctk/TTS/vits/tts_datamodule.py b/egs/vctk/TTS/vits/tts_datamodule.py
index 8b2a96b09..52fc5179f 100644
--- a/egs/vctk/TTS/vits/tts_datamodule.py
+++ b/egs/vctk/TTS/vits/tts_datamodule.py
@@ -204,6 +204,8 @@ class VctkTtsDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 1dbfb9709..58da1d68c 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -292,7 +292,8 @@ class WenetSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
- buffer_size=300000,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
else:
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 7594fb28e..7b37b1331 100644
--- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -296,6 +296,8 @@ class Xbmu_AmdoAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py
index dc66b217d..b9ce8fb4e 100644
--- a/egs/yesno/ASR/tdnn/asr_datamodule.py
+++ b/egs/yesno/ASR/tdnn/asr_datamodule.py
@@ -193,6 +193,8 @@ class YesNoAsrDataModule(DataModule):
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=True,
)
else:
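Every hunk in this patch applies the same substitution: hard-coded `buffer_size` values, where present, are replaced by a size derived from `--num-buckets`, and a matching `shuffle_buffer_size` is added next to it. The sketch below shows the resulting sampler construction in isolation; the cut-set path, `num_buckets` value, and `max_duration` are placeholders, and it assumes lhotse is installed with a version of `DynamicBucketingSampler` that accepts these keyword arguments, as the patched data modules do.

```python
# Minimal sketch (assumed placeholder paths and values) of the sampler setup
# the patched data modules end up with: buffer sizes scale with num_buckets.
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

num_buckets = 30  # placeholder; comes from --num-buckets in the recipes
train_cuts = CutSet.from_file("data/fbank/train_cuts.jsonl.gz")  # hypothetical path

train_sampler = DynamicBucketingSampler(
    train_cuts,
    max_duration=200.0,                      # seconds of audio per batch
    shuffle=True,
    num_buckets=num_buckets,
    buffer_size=num_buckets * 2000,          # cuts buffered for duration bucketing
    shuffle_buffer_size=num_buckets * 5000,  # cuts buffered for shuffling
    drop_last=True,
)
```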
From ebe97a07b082f9bee4a18b5f8e54c453187a74bb Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Tue, 23 Jan 2024 16:26:24 +0800
Subject: [PATCH 038/123] Reworked README.md (#1470)
* Rework README.md
Co-authored-by: Fangjun Kuang
---------
Co-authored-by: Fangjun Kuang
---
README.md | 439 +++++++++++++++++++++++++-----------------------------
1 file changed, 201 insertions(+), 238 deletions(-)
diff --git a/README.md b/README.md
index 15e9e17e6..61920be65 100644
--- a/README.md
+++ b/README.md
@@ -2,46 +2,83 @@
-## Introduction
+# Introduction
-icefall contains ASR recipes for various datasets
-using .
+The icefall peoject contains speech related recipes for various datasets
+using [k2-fsa](https://github.com/k2-fsa/k2) and [lhotse](https://github.com/lhotse-speech/lhotse).
-You can use to deploy models
-trained with icefall.
+You can use [sherpa](https://github.com/k2-fsa/sherpa), [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn) or [sherpa-onnx](https://github.com/k2-fsa/sherpa-onnx) for deployment with models
+in icefall; these frameworks also support models not included in icefall; please refer to respective documents for more details.
You can try pre-trained models from within your browser without the need
-to download or install anything by visiting
-See for more details.
+to download or install anything by visiting this [huggingface space](https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition).
+Please refer to [document](https://k2-fsa.github.io/icefall/huggingface/spaces.html) for more details.
-## Installation
+# Installation
-Please refer to
+Please refer to [document](https://icefall.readthedocs.io/en/latest/installation/index.html)
for installation.
-## Recipes
+# Recipes
-Please refer to
-for more information.
+Please refer to [document](https://icefall.readthedocs.io/en/latest/recipes/index.html)
+for more details.
-We provide the following recipes:
+## ASR: Automatic Speech Recognition
+### Supported Datasets
- [yesno][yesno]
- - [LibriSpeech][librispeech]
- - [GigaSpeech][gigaspeech]
- - [AMI][ami]
+
+ - [Aidatatang_200zh][aidatatang_200zh]
- [Aishell][aishell]
- [Aishell2][aishell2]
- [Aishell4][aishell4]
+ - [Alimeeting][alimeeting]
+ - [AMI][ami]
+ - [CommonVoice][commonvoice]
+ - [Corpus of Spontaneous Japanese][csj]
+ - [GigaSpeech][gigaspeech]
+ - [LibriCSS][libricss]
+ - [LibriSpeech][librispeech]
+ - [Libriheavy][libriheavy]
+ - [Multi-Dialect Broadcast News Arabic Speech Recognition][mgb2]
+ - [PeopleSpeech][peoplespeech]
+ - [SPGISpeech][spgispeech]
+ - [Switchboard][swbd]
- [TIMIT][timit]
- [TED-LIUM3][tedlium3]
- - [Aidatatang_200zh][aidatatang_200zh]
- - [WenetSpeech][wenetspeech]
- - [Alimeeting][alimeeting]
- - [Switchboard][swbd]
- [TAL_CSASR][tal_csasr]
+ - [Voxpopuli][voxpopuli]
+ - [XBMU-AMDO31][xbmu-amdo31]
+ - [WenetSpeech][wenetspeech]
+
+More datasets will be added in the future.
-### yesno
+### Supported Models
+
+The [LibriSpeech][librispeech] recipe supports the most comprehensive set of models; you are welcome to try them out.
+
+#### CTC
+ - TDNN LSTM CTC
+ - Conformer CTC
+ - Zipformer CTC
+
+#### MMI
+ - Conformer MMI
+ - Zipformer MMI
+
+#### Transducer
+ - Conformer-based Encoder
+ - LSTM-based Encoder
+ - Zipformer-based Encoder
+ - LSTM-based Predictor
+ - [Stateless Predictor](https://research.google/pubs/rnn-transducer-with-stateless-prediction-network/)
+
+If you are willing to contribute to icefall, please refer to [contributing](https://icefall.readthedocs.io/en/latest/contributing/index.html) for more details.
+
+We would like to highlight the performance of some of the recipes here.
+
+### [yesno][yesno]
This is the simplest ASR recipe in `icefall` and can be run on CPU.
Training takes less than 30 seconds and gives you the following WER:
@@ -52,350 +89,264 @@ Training takes less than 30 seconds and gives you the following WER:
We provide a Colab notebook for this recipe: [](https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing)
-### LibriSpeech
+### [LibriSpeech][librispeech]
-Please see
+Please see [RESULTS.md](https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md)
for the **latest** results.
-We provide 5 models for this recipe:
-
-- [conformer CTC model][LibriSpeech_conformer_ctc]
-- [TDNN LSTM CTC model][LibriSpeech_tdnn_lstm_ctc]
-- [Transducer: Conformer encoder + LSTM decoder][LibriSpeech_transducer]
-- [Transducer: Conformer encoder + Embedding decoder][LibriSpeech_transducer_stateless]
-- [Transducer: Zipformer encoder + Embedding decoder][LibriSpeech_zipformer]
-
-#### Conformer CTC Model
-
-The best WER we currently have is:
+#### [Conformer CTC](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conformer_ctc)
| | test-clean | test-other |
|-----|------------|------------|
| WER | 2.42 | 5.73 |
-We provide a Colab notebook to run a pre-trained conformer CTC model: [](https://colab.research.google.com/drive/1huyupXAcHsUrKaWfI83iMEJ6J0Nh0213?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1huyupXAcHsUrKaWfI83iMEJ6J0Nh0213?usp=sharing)
-#### TDNN LSTM CTC Model
-
-The WER for this model is:
+#### [TDNN LSTM CTC](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/tdnn_lstm_ctc)
| | test-clean | test-other |
|-----|------------|------------|
| WER | 6.59 | 17.69 |
-We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing)
-#### Transducer: Conformer encoder + LSTM decoder
+#### [Transducer (Conformer Encoder + LSTM Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/transducer)
-Using Conformer as encoder and LSTM as decoder.
+| | test-clean | test-other |
+|---------------|------------|------------|
+| greedy search | 3.07 | 7.51 |
-The best WER with greedy search is:
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing)
-| | test-clean | test-other |
-|-----|------------|------------|
-| WER | 3.07 | 7.51 |
+#### [Transducer (Conformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/transducer)
-We provide a Colab notebook to run a pre-trained RNN-T conformer model: [](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing)
-
-#### Transducer: Conformer encoder + Embedding decoder
-
-Using Conformer as encoder. The decoder consists of 1 embedding layer
-and 1 convolutional layer.
-
-The best WER using modified beam search with beam size 4 is:
-
-| | test-clean | test-other |
-|-----|------------|------------|
-| WER | 2.56 | 6.27 |
-
-Note: No auxiliary losses are used in the training and no LMs are used
-in the decoding.
-
-We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing)
+| | test-clean | test-other |
+|---------------------------------------|------------|------------|
+| modified_beam_search (`beam_size=4`) | 2.56 | 6.27 |
-#### k2 pruned RNN-T
+We provide a Colab notebook to run test the pre-trained model: [](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing)
+
+
+#### [Transducer (Zipformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/zipformer)
+
+WER (modified_beam_search `beam_size=4` unless further stated)
+
+1. LibriSpeech-960hr
| Encoder | Params | test-clean | test-other | epochs | devices |
|-----------------|--------|------------|------------|---------|------------|
-| zipformer | 65.5M | 2.21 | 4.79 | 50 | 4 32G-V100 |
-| zipformer-small | 23.2M | 2.42 | 5.73 | 50 | 2 32G-V100 |
-| zipformer-large | 148.4M | 2.06 | 4.63 | 50 | 4 32G-V100 |
-| zipformer-large | 148.4M | 2.00 | 4.38 | 174 | 8 80G-A100 |
+| Zipformer | 65.5M | 2.21 | 4.79 | 50 | 4 32G-V100 |
+| Zipformer-small | 23.2M | 2.42 | 5.73 | 50 | 2 32G-V100 |
+| Zipformer-large | 148.4M | 2.06 | 4.63 | 50 | 4 32G-V100 |
+| Zipformer-large | 148.4M | 2.00 | 4.38 | 174 | 8 80G-A100 |
-Note: No auxiliary losses are used in the training and no LMs are used
-in the decoding.
+2. LibriSpeech-960hr + GigaSpeech
-#### k2 pruned RNN-T + GigaSpeech
-
-| | test-clean | test-other |
-|-----|------------|------------|
-| WER | 1.78 | 4.08 |
-
-Note: No auxiliary losses are used in the training and no LMs are used
-in the decoding.
-
-#### k2 pruned RNN-T + GigaSpeech + CommonVoice
-
-| | test-clean | test-other |
-|-----|------------|------------|
-| WER | 1.90 | 3.98 |
-
-Note: No auxiliary losses are used in the training and no LMs are used
-in the decoding.
+| Encoder | Params | test-clean | test-other |
+|-----------------|--------|------------|------------|
+| Zipformer | 65.5M | 1.78 | 4.08 |
-### GigaSpeech
+3. LibriSpeech-960hr + GigaSpeech + CommonVoice
-We provide three models for this recipe:
+| Encoder | Params | test-clean | test-other |
+|-----------------|--------|------------|------------|
+| Zipformer | 65.5M | 1.90 | 3.98 |
-- [Conformer CTC model][GigaSpeech_conformer_ctc]
-- [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2].
-- [Transducer: Zipformer encoder + Embedding decoder][GigaSpeech_zipformer]
-#### Conformer CTC
+### [GigaSpeech][gigaspeech]
+
+#### [Conformer CTC](https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR/conformer_ctc)
| | Dev | Test |
|-----|-------|-------|
| WER | 10.47 | 10.58 |
-#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
+#### [Transducer (pruned_transducer_stateless2)](https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR/pruned_transducer_stateless2)
+
+Conformer Encoder + Stateless Predictor + k2 Pruned RNN-T Loss
| | Dev | Test |
|----------------------|-------|-------|
-| greedy search | 10.51 | 10.73 |
-| fast beam search | 10.50 | 10.69 |
-| modified beam search | 10.40 | 10.51 |
+| greedy_search | 10.51 | 10.73 |
+| fast_beam_search | 10.50 | 10.69 |
+| modified_beam_search | 10.40 | 10.51 |
-#### Transducer: Zipformer encoder + Embedding decoder
+#### [Transducer (Zipformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR/zipformer)
| | Dev | Test |
|----------------------|-------|-------|
-| greedy search | 10.31 | 10.50 |
-| fast beam search | 10.26 | 10.48 |
-| modified beam search | 10.25 | 10.38 |
+| greedy_search | 10.31 | 10.50 |
+| fast_beam_search | 10.26 | 10.48 |
+| modified_beam_search | 10.25 | 10.38 |
-### Aishell
+### [Aishell][aishell]
-We provide three models for this recipe: [conformer CTC model][Aishell_conformer_ctc],
-[TDNN LSTM CTC model][Aishell_tdnn_lstm_ctc], and [Transducer Stateless Model][Aishell_pruned_transducer_stateless7],
-
-#### Conformer CTC Model
-
-The best CER we currently have is:
-
-| | test |
-|-----|------|
-| CER | 4.26 |
-
-#### TDNN LSTM CTC Model
-
-The CER for this model is:
+#### [TDNN LSTM CTC](https://github.com/k2-fsa/icefall/tree/master/egs/aishell/ASR/tdnn_lstm_ctc)
| | test |
|-----|-------|
| CER | 10.16 |
-We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing)
-#### Transducer Stateless Model
-
-The best CER we currently have is:
+#### [Transducer (Conformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/aishell/ASR/transducer_stateless)
| | test |
|-----|------|
| CER | 4.38 |
-We provide a Colab notebook to run a pre-trained TransducerStateless model: [](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing)
+
+#### [Transducer (Zipformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/aishell/ASR/zipformer)
+
+WER (modified_beam_search `beam_size=4`)
+
+| Encoder | Params | dev | test | epochs |
+|-----------------|--------|-----|------|---------|
+| Zipformer | 73.4M | 4.13| 4.40 | 55 |
+| Zipformer-small | 30.2M | 4.40| 4.67 | 55 |
+| Zipformer-large | 157.3M | 4.03| 4.28 | 56 |
-### Aishell2
+### [Aishell4][aishell4]
-We provide one model for this recipe: [Transducer Stateless Model][Aishell2_pruned_transducer_stateless5].
-
-#### Transducer Stateless Model
-
-The best WER we currently have is:
-
-| | dev-ios | test-ios |
-|-----|------------|------------|
-| WER | 5.32 | 5.56 |
-
-
-### Aishell4
-
-We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5].
-
-#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets)
-
-The best CER we currently have is:
+#### [Transducer (pruned_transducer_stateless5)](https://github.com/k2-fsa/icefall/tree/master/egs/aishell4/ASR/pruned_transducer_stateless5)
+1. Trained with all subsets:
| | test |
|-----|------------|
| CER | 29.08 |
-
-We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
-### TIMIT
+### [TIMIT][timit]
-We provide two models for this recipe: [TDNN LSTM CTC model][TIMIT_tdnn_lstm_ctc]
-and [TDNN LiGRU CTC model][TIMIT_tdnn_ligru_ctc].
+#### [TDNN LSTM CTC](https://github.com/k2-fsa/icefall/tree/master/egs/timit/ASR/tdnn_lstm_ctc)
-#### TDNN LSTM CTC Model
-
-The best PER we currently have is:
-
-||TEST|
-|--|--|
+| |TEST|
+|---|----|
|PER| 19.71% |
-We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [](https://colab.research.google.com/drive/1Hs9DA4V96uapw_30uNp32OMJgkuR5VVd?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1Hs9DA4V96uapw_30uNp32OMJgkuR5VVd?usp=sharing)
-#### TDNN LiGRU CTC Model
+#### [TDNN LiGRU CTC](https://github.com/k2-fsa/icefall/tree/master/egs/timit/ASR/tdnn_ligru_ctc)
-The PER for this model is:
-
-||TEST|
-|--|--|
+| |TEST|
+|---|----|
|PER| 17.66% |
-We provide a Colab notebook to run a pre-trained TDNN LiGRU CTC model: [](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
-### TED-LIUM3
+### [TED-LIUM3][tedlium3]
-We provide two models for this recipe: [Transducer Stateless: Conformer encoder + Embedding decoder][TED-LIUM3_transducer_stateless] and [Pruned Transducer Stateless: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][TED-LIUM3_pruned_transducer_stateless].
+#### [Transducer (Conformer Encoder + Embedding Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/tedlium3/ASR/transducer_stateless)
-#### Transducer Stateless: Conformer encoder + Embedding decoder
-
-The best WER using modified beam search with beam size 4 is:
-
-| | dev | test |
-|-----|-------|--------|
-| WER | 6.91 | 6.33 |
-
-Note: No auxiliary losses are used in the training and no LMs are used in the decoding.
-
-We provide a Colab notebook to run a pre-trained Transducer Stateless model: [](https://colab.research.google.com/drive/1MmY5bBxwvKLNT4A2DJnwiqRXhdchUqPN?usp=sharing)
-
-#### Pruned Transducer Stateless: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
-
-The best WER using modified beam search with beam size 4 is:
-
-| | dev | test |
-|-----|-------|--------|
-| WER | 6.77 | 6.14 |
-
-We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1je_1zGrOkGVVd4WLzgkXRHxl-I27yWtz?usp=sharing)
+| | dev | test |
+|--------------------------------------|-------|--------|
+| modified_beam_search (`beam_size=4`) | 6.91 | 6.33 |
-### Aidatatang_200zh
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1MmY5bBxwvKLNT4A2DJnwiqRXhdchUqPN?usp=sharing)
-We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aidatatang_200zh_pruned_transducer_stateless2].
+#### [Transducer (pruned_transducer_stateless)](https://github.com/k2-fsa/icefall/tree/master/egs/tedlium3/ASR/pruned_transducer_stateless)
-#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
+| | dev | test |
+|--------------------------------------|-------|--------|
+| modified_beam_search (`beam_size=4`) | 6.77 | 6.14 |
+
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1je_1zGrOkGVVd4WLzgkXRHxl-I27yWtz?usp=sharing)
+
+
+### [Aidatatang_200zh][aidatatang_200zh]
+
+#### [Transducer (pruned_transducer_stateless2)](https://github.com/k2-fsa/icefall/tree/master/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2)
| | Dev | Test |
|----------------------|-------|-------|
-| greedy search | 5.53 | 6.59 |
-| fast beam search | 5.30 | 6.34 |
-| modified beam search | 5.27 | 6.33 |
+| greedy_search | 5.53 | 6.59 |
+| fast_beam_search | 5.30 | 6.34 |
+| modified_beam_search | 5.27 | 6.33 |
-We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1wNSnSj3T5oOctbh5IGCa393gKOoQw2GH?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1wNSnSj3T5oOctbh5IGCa393gKOoQw2GH?usp=sharing)
-### WenetSpeech
+### [WenetSpeech][wenetspeech]
-We provide some models for this recipe: [Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless2] and [Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless5].
-
-#### Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset, offline ASR)
+#### [Transducer (pruned_transducer_stateless2)](https://github.com/k2-fsa/icefall/tree/master/egs/wenetspeech/ASR/pruned_transducer_stateless2)
| | Dev | Test-Net | Test-Meeting |
|----------------------|-------|----------|--------------|
-| greedy search | 7.80 | 8.75 | 13.49 |
-| modified beam search| 7.76 | 8.71 | 13.41 |
-| fast beam search | 7.94 | 8.74 | 13.80 |
+| greedy_search | 7.80 | 8.75 | 13.49 |
+| fast_beam_search | 7.94 | 8.74 | 13.80 |
+| modified_beam_search | 7.76 | 8.71 | 13.41 |
+
+We provide a Colab notebook to run the pre-trained model: [](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing)
+
+#### [Transducer **Streaming** (pruned_transducer_stateless5) ](https://github.com/k2-fsa/icefall/tree/master/egs/wenetspeech/ASR/pruned_transducer_stateless5)
-#### Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset)
-**Streaming**:
| | Dev | Test-Net | Test-Meeting |
|----------------------|-------|----------|--------------|
| greedy_search | 8.78 | 10.12 | 16.16 |
-| modified_beam_search | 8.53| 9.95 | 15.81 |
| fast_beam_search| 9.01 | 10.47 | 16.28 |
+| modified_beam_search | 8.53| 9.95 | 15.81 |
-We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless2 model: [](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing)
-### Alimeeting
+### [Alimeeting][alimeeting]
-We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Alimeeting_pruned_transducer_stateless2].
-
-#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with far subset)
+#### [Transducer (pruned_transducer_stateless2)](https://github.com/k2-fsa/icefall/tree/master/egs/alimeeting/ASR/pruned_transducer_stateless2)
| | Eval | Test-Net |
|----------------------|--------|----------|
-| greedy search | 31.77 | 34.66 |
-| fast beam search | 31.39 | 33.02 |
-| modified beam search | 30.38 | 34.25 |
+| greedy_search | 31.77 | 34.66 |
+| fast_beam_search | 31.39 | 33.02 |
+| modified_beam_search | 30.38 | 34.25 |
-We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing)
-### TAL_CSASR
+### [TAL_CSASR][tal_csasr]
-We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][TAL_CSASR_pruned_transducer_stateless5].
-#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
+#### [Transducer (pruned_transducer_stateless5)](https://github.com/k2-fsa/icefall/tree/master/egs/tal_csasr/ASR/pruned_transducer_stateless5)
The best results for Chinese CER(%) and English WER(%) respectively (zh: Chinese, en: English):
|decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en |
|--|--|--|--|--|--|--|
|greedy_search| 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13|
-|modified_beam_search| 7.15 | 6.35 | 18.95 | 7.22| 6.50 | 18.70 |
|fast_beam_search| 7.18 | 6.39| 18.90 | 7.27| 6.55 | 18.77|
+|modified_beam_search| 7.15 | 6.35 | 18.95 | 7.22| 6.50 | 18.70 |
-We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing)
-## Deployment with C++
+## TTS: Text-to-Speech
-Once you have trained a model in icefall, you may want to deploy it with C++,
-without Python dependencies.
+### Supported Datasets
-Please refer to the documentation
-
+ - [LJSpeech][ljspeech]
+ - [VCTK][vctk]
+
+### Supported Models
+
+ - [VITS](https://arxiv.org/abs/2106.06103)
+
+# Deployment with C++
+
+Once you have trained a model in icefall, you may want to deploy it with C++ without Python dependencies.
+
+Please refer to the [document](https://icefall.readthedocs.io/en/latest/recipes/Non-streaming-ASR/librispeech/conformer_ctc.html#deployment-with-c)
for how to do this.
We also provide a Colab notebook, showing you how to run a torch scripted model in [k2][k2] with C++.
Please see: [](https://colab.research.google.com/drive/1BIGLWzS36isskMXHKcqC9ysN6pspYXs_?usp=sharing)
-[LibriSpeech_tdnn_lstm_ctc]: egs/librispeech/ASR/tdnn_lstm_ctc
-[LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc
-[LibriSpeech_transducer]: egs/librispeech/ASR/transducer
-[LibriSpeech_transducer_stateless]: egs/librispeech/ASR/transducer_stateless
-[LibriSpeech_zipformer]: egs/librispeech/ASR/zipformer
-[Aishell_tdnn_lstm_ctc]: egs/aishell/ASR/tdnn_lstm_ctc
-[Aishell_conformer_ctc]: egs/aishell/ASR/conformer_ctc
-[Aishell_pruned_transducer_stateless7]: egs/aishell/ASR/pruned_transducer_stateless7_bbpe
-[Aishell2_pruned_transducer_stateless5]: egs/aishell2/ASR/pruned_transducer_stateless5
-[Aishell4_pruned_transducer_stateless5]: egs/aishell4/ASR/pruned_transducer_stateless5
-[TIMIT_tdnn_lstm_ctc]: egs/timit/ASR/tdnn_lstm_ctc
-[TIMIT_tdnn_ligru_ctc]: egs/timit/ASR/tdnn_ligru_ctc
-[TED-LIUM3_transducer_stateless]: egs/tedlium3/ASR/transducer_stateless
-[TED-LIUM3_pruned_transducer_stateless]: egs/tedlium3/ASR/pruned_transducer_stateless
-[GigaSpeech_conformer_ctc]: egs/gigaspeech/ASR/conformer_ctc
-[GigaSpeech_pruned_transducer_stateless2]: egs/gigaspeech/ASR/pruned_transducer_stateless2
-[GigaSpeech_zipformer]: egs/gigaspeech/ASR/zipformer
-[Aidatatang_200zh_pruned_transducer_stateless2]: egs/aidatatang_200zh/ASR/pruned_transducer_stateless2
-[WenetSpeech_pruned_transducer_stateless2]: egs/wenetspeech/ASR/pruned_transducer_stateless2
-[WenetSpeech_pruned_transducer_stateless5]: egs/wenetspeech/ASR/pruned_transducer_stateless5
-[Alimeeting_pruned_transducer_stateless2]: egs/alimeeting/ASR/pruned_transducer_stateless2
-[TAL_CSASR_pruned_transducer_stateless5]: egs/tal_csasr/ASR/pruned_transducer_stateless5
[yesno]: egs/yesno/ASR
[librispeech]: egs/librispeech/ASR
[aishell]: egs/aishell/ASR
@@ -411,3 +362,15 @@ Please see: [
---
README.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 61920be65..f92c85ad4 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
# Introduction
-The icefall peoject contains speech related recipes for various datasets
+The icefall project contains speech-related recipes for various datasets
using [k2-fsa](https://github.com/k2-fsa/k2) and [lhotse](https://github.com/lhotse-speech/lhotse).
You can use [sherpa](https://github.com/k2-fsa/sherpa), [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn) or [sherpa-onnx](https://github.com/k2-fsa/sherpa-onnx) for deployment with models
@@ -373,4 +373,4 @@ Please see: [
---
README.md | 12 ++++++------
.../ASR/local/compute_fbank_peoples_speech_splits.py | 4 ++--
.../ASR/local/compute_fbank_wenetspeech_splits.py | 4 ++--
3 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/README.md b/README.md
index f92c85ad4..cc817702b 100644
--- a/README.md
+++ b/README.md
@@ -116,7 +116,7 @@ We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing)
@@ -127,7 +127,7 @@ We provide a Colab notebook to test the pre-trained model: [ | 2.56 | 6.27 |
-We provide a Colab notebook to run test the pre-trained model: [](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing)
#### [Transducer (Zipformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/zipformer)
@@ -147,14 +147,14 @@ WER (modified_beam_search `beam_size=4` unless further stated)
| Encoder | Params | test-clean | test-other |
|-----------------|--------|------------|------------|
-| Zipformer | 65.5M | 1.78 | 4.08 |
+| Zipformer | 65.5M | 1.78 | 4.08 |
3. LibriSpeech-960hr + GigaSpeech + CommonVoice
| Encoder | Params | test-clean | test-other |
|-----------------|--------|------------|------------|
-| Zipformer | 65.5M | 1.90 | 3.98 |
+| Zipformer | 65.5M | 1.90 | 3.98 |
### [GigaSpeech][gigaspeech]
@@ -246,7 +246,7 @@ We provide a Colab notebook to test the pre-trained model: [](https://github.com/k2-fsa/icefall/tree/master/egs/tedlium3/ASR/transducer_stateless)
+#### [Transducer (Conformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/tedlium3/ASR/transducer_stateless)
| | dev | test |
|--------------------------------------|-------|--------|
@@ -287,7 +287,7 @@ We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing)
+We provide a Colab notebook to test the pre-trained model: [](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing)
#### [Transducer **Streaming** (pruned_transducer_stateless5) ](https://github.com/k2-fsa/icefall/tree/master/egs/wenetspeech/ASR/pruned_transducer_stateless5)
diff --git a/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py b/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py
index c2ab3d07d..6f05b9f8c 100755
--- a/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py
+++ b/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py
@@ -67,14 +67,14 @@ def get_args():
"--start",
type=int,
default=0,
- help="Process pieces starting from this number (inclusive).",
+ help="Process pieces starting from this number (included).",
)
parser.add_argument(
"--stop",
type=int,
default=-1,
- help="Stop processing pieces until this number (exclusive).",
+ help="Stop processing pieces until this number (excluded).",
)
return parser.parse_args()
diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
index 99d39bbdc..a87801462 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
@@ -78,14 +78,14 @@ def get_parser():
"--start",
type=int,
default=0,
- help="Process pieces starting from this number (inclusive).",
+ help="Process pieces starting from this number (included).",
)
parser.add_argument(
"--stop",
type=int,
default=-1,
- help="Stop processing pieces until this number (exclusive).",
+ help="Stop processing pieces until this number (excluded).",
)
return parser
From c401a2646b347bf1fff0c2ce1a4ee13b0f482448 Mon Sep 17 00:00:00 2001
From: Zengwei Yao
Date: Fri, 26 Jan 2024 15:50:11 +0800
Subject: [PATCH 041/123] minor fix of zipformer/optim.py (#1474)
---
egs/librispeech/ASR/zipformer/optim.py | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py
index 714d8db9a..aaffbfed5 100644
--- a/egs/librispeech/ASR/zipformer/optim.py
+++ b/egs/librispeech/ASR/zipformer/optim.py
@@ -298,11 +298,14 @@ class ScaledAdam(BatchedOptimizer):
# case 2 or case 4
# the input is groups of parameter or named parameter.
for cur_group in iterable_or_groups:
- assert "named_params" in cur_group
- name_list = [x[0] for x in cur_group["named_params"]]
- p_list = [x[1] for x in cur_group["named_params"]]
- del cur_group["named_params"]
- cur_group["params"] = p_list
+ if "named_params" in cur_group:
+ name_list = [x[0] for x in cur_group["named_params"]]
+ p_list = [x[1] for x in cur_group["named_params"]]
+ del cur_group["named_params"]
+ cur_group["params"] = p_list
+ else:
+ assert "params" in cur_group
+ name_list = ["foo" for _ in cur_group["params"]]
param_groups.append(cur_group)
param_groups_names.append(name_list)
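
The hunk above lets ScaledAdam accept both icefall-style groups built from `model.named_parameters()` ("named_params") and plain PyTorch groups ("params"). Below is a minimal standalone sketch of that normalization, for illustration only; in icefall this logic lives inside the optimizer's `__init__`, and the placeholder name "foo" mirrors the patch.

```python
# Sketch of the param-group normalization this patch adds to ScaledAdam.__init__.
from typing import Any, Dict, List, Tuple

import torch.nn as nn


def normalize_param_groups(
    groups: List[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], List[List[str]]]:
    param_groups, param_groups_names = [], []
    for cur_group in groups:
        if "named_params" in cur_group:
            # Group built from model.named_parameters(): split names and tensors.
            name_list = [x[0] for x in cur_group["named_params"]]
            p_list = [x[1] for x in cur_group["named_params"]]
            del cur_group["named_params"]
            cur_group["params"] = p_list
        else:
            # Plain PyTorch-style group: keep it and use placeholder names.
            assert "params" in cur_group
            name_list = ["foo" for _ in cur_group["params"]]
        param_groups.append(cur_group)
        param_groups_names.append(name_list)
    return param_groups, param_groups_names


if __name__ == "__main__":
    m1, m2 = nn.Linear(4, 2), nn.Linear(2, 2)
    groups, names = normalize_param_groups(
        [
            {"named_params": list(m1.named_parameters()), "lr": 0.045},
            {"params": list(m2.parameters()), "lr": 0.01},
        ]
    )
    print([g["lr"] for g in groups], names)
```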
From 8d39f9508bd5b27627165696758cad2e96dca20b Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Fri, 26 Jan 2024 19:18:33 +0800
Subject: [PATCH 042/123] Fix torchscript export to use tokens.txt instead of
lang_dir (#1475)
---
.../pruned_transducer_stateless2/export.py | 25 +++++++++------
.../ASR/pruned_transducer_stateless2/lstmp.py | 1 +
.../scaling_converter.py | 1 +
.../pruned_transducer_stateless2/export.py | 24 ++++++++------
.../ASR/pruned_transducer_stateless2/lstmp.py | 1 +
.../scaling_converter.py | 1 +
.../ASR/lstm_transducer_stateless2/export.py | 7 ++---
.../pruned_stateless_emformer_rnnt2/export.py | 1 +
.../pruned_transducer_stateless/decoder.py | 2 +-
.../export.py | 8 ++---
.../pruned_transducer_stateless5/export.py | 31 +++++++++----------
.../ASR/pruned_transducer_stateless5/lstmp.py | 1 +
.../scaling_converter.py | 1 +
.../pruned_transducer_stateless2/export.py | 23 +++++++-------
.../pruned_transducer_stateless5/export.py | 25 +++++++--------
15 files changed, 83 insertions(+), 69 deletions(-)
mode change 100644 => 100755 egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
create mode 120000 egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py
create mode 120000 egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py
create mode 120000 egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py
create mode 120000 egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling_converter.py
create mode 120000 egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py
create mode 120000 egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling_converter.py
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
old mode 100644
new mode 100755
index e348f7b2b..5179bfa1c
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
@@ -20,7 +21,7 @@
Usage:
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
- --lang-dir data/lang_char \
+ --tokens data/lang_char/tokens.txt \
--epoch 29 \
--avg 19
@@ -45,12 +46,13 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
+from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -85,10 +87,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="The lang dir",
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt.",
)
parser.add_argument(
@@ -122,10 +124,14 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
+ # Load tokens.txt here
+ token_table = k2.SymbolTable.from_file(params.tokens)
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ # Load id of the <blk> token and the vocab size
+ # <blk> is defined in local/train_bpe_model.py
+ params.blank_id = token_table["<blk>"]
+ params.unk_id = token_table["<unk>"]
+ params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
@@ -152,6 +158,7 @@ def main():
model.eval()
if params.jit:
+ convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py
new file mode 120000
index 000000000..b82e115fc
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py
new file mode 120000
index 000000000..db93d155b
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
index b6190e8a6..4a44f7bcb 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py
@@ -22,7 +22,7 @@
Usage:
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@@ -47,12 +47,13 @@ import argparse
import logging
from pathlib import Path
-import sentencepiece as spm
+import k2
import torch
+from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -98,10 +99,10 @@ def get_parser():
)
parser.add_argument(
- "--bpe-model",
+ "--tokens",
type=str,
- default="data/lang_bpe_500/bpe.model",
- help="Path to the BPE model",
+ default="data/lang_bpe_500/tokens.txt",
+ help="Path to the tokens.txt.",
)
parser.add_argument(
@@ -135,12 +136,14 @@ def main():
logging.info(f"device: {device}")
- sp = spm.SentencePieceProcessor()
- sp.load(params.bpe_model)
+ # Load tokens.txt here
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ # Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
- params.blank_id = sp.piece_to_id("<blk>")
- params.vocab_size = sp.get_piece_size()
+ params.blank_id = token_table["<blk>"]
+ params.unk_id = token_table["<unk>"]
+ params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
@@ -183,6 +186,7 @@ def main():
model.eval()
if params.jit:
+ convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py
new file mode 120000
index 000000000..b82e115fc
--- /dev/null
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling_converter.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling_converter.py
new file mode 120000
index 000000000..db93d155b
--- /dev/null
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py
index 5712da25e..aeed58dec 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py
@@ -218,10 +218,9 @@ def export_decoder_model_jit_trace(
decoder_filename:
The filename to save the exported model.
"""
- y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
- need_pad = torch.tensor([False])
-
- traced_model = torch.jit.trace(decoder_model, (y, need_pad))
+ # TODO(fangjun): Change the function name since we are actually using
+ # torch.jit.script instead of torch.jit.trace
+ traced_model = torch.jit.script(decoder_model)
traced_model.save(decoder_filename)
logging.info(f"Saved to {decoder_filename}")
diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py
index ec2c9d580..e42a5c6ef 100755
--- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py
+++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py
@@ -159,6 +159,7 @@ def main():
# Load id of the <blk> token and the vocab size
params.blank_id = token_table["<blk>"]
+ params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py
index 03847b449..b961611f7 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py
@@ -91,7 +91,7 @@ class Decoder(nn.Module):
Returns:
Return a tensor of shape (N, U, embedding_dim).
"""
- embedding_out = self.embedding(y)
+ embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
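
The one-line change above lets the decoder accept label tensors padded with a negative id: clamping avoids an out-of-range embedding lookup, and the boolean mask zeroes out the padded rows. A small self-contained illustration (the padding value -1 is just an example):

```python
# Illustration of the clamp-and-mask trick from the hunk above.
import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=5, embedding_dim=3)
y = torch.tensor([[2, 4, -1, -1]])  # -1 marks padding positions (example value)

# A plain embedding(y) would fail on -1; clamp first, then zero the padded rows.
out = embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)

print(out.shape)               # torch.Size([1, 4, 3])
print(out[0, 2:].abs().sum())  # zero: padded positions contribute nothing
```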
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py
index 59a7eb589..67041012d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py
@@ -26,7 +26,7 @@ Usage:
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@@ -87,7 +87,7 @@ cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/e
ln -s pretrained.pt epoch-999.pt
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--use-averaged-model False \
--epoch 999 \
--avg 1 \
@@ -113,7 +113,7 @@ cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/e
ln -s pretrained.pt epoch-999.pt
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens data/lang_bpe_500/tokens.txt \
--use-averaged-model False \
--epoch 999 \
--avg 1 \
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
index bc33dd160..0f6190a41 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
@@ -23,7 +23,7 @@
Usage:
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp \
- --lang-dir ./data/lang_char \
+ --tokens ./data/lang_char/tokens.txt \
--epoch 30 \
--avg 24 \
--use-averaged-model True
@@ -50,8 +50,9 @@ import argparse
import logging
from pathlib import Path
-import sentencepiece as spm
+import k2
import torch
+from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
@@ -60,8 +61,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -118,13 +118,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="""The lang dir
- It contains language related input files such as
- "lexicon.txt"
- """,
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt.",
)
parser.add_argument(
@@ -160,13 +157,14 @@ def main():
logging.info(f"device: {device}")
- bpe_model = params.lang_dir + "/bpe.model"
- sp = spm.SentencePieceProcessor()
- sp.load(bpe_model)
+ # Load tokens.txt here
+ token_table = k2.SymbolTable.from_file(params.tokens)
- lexicon = Lexicon(params.lang_dir)
- params.blank_id = lexicon.token_table["<blk>"]
- params.vocab_size = max(lexicon.tokens) + 1
+ # Load id of the <blk> token and the vocab size
+ # <blk> is defined in local/train_bpe_model.py
+ params.blank_id = token_table["<blk>"]
+ params.unk_id = token_table["<unk>"]
+ params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
@@ -256,6 +254,7 @@ def main():
model.eval()
if params.jit:
+ convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py
new file mode 120000
index 000000000..b82e115fc
--- /dev/null
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
\ No newline at end of file
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling_converter.py
new file mode 120000
index 000000000..db93d155b
--- /dev/null
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
index 5d25daf5e..2f6ef488e 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py
@@ -24,7 +24,7 @@ Usage:
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
- --lang-dir data/lang_char \
+ --tokens data/lang_char/tokens.txt \
--epoch 10 \
--avg 2 \
--jit 1
@@ -47,7 +47,7 @@ for how to use them.
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
- --lang-dir data/lang_char \
+ --tokens data/lang_char/tokens.txt \
--epoch 10 \
--avg 2 \
--jit-trace 1
@@ -63,7 +63,7 @@ Check ./jit_pretrained.py for usage.
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
- --lang-dir data/lang_char \
+ --tokens data/lang_char/tokens.txt \
--epoch 10 \
--avg 2
@@ -91,14 +91,14 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -133,10 +133,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="The lang dir",
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -313,10 +313,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
-
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
index cb541070e..5ff1f4a3b 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
@@ -20,7 +20,7 @@
Usage for offline:
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp_L_offline \
- --lang-dir data/lang_char \
+ --tokens data/lang_char/tokens.txt \
--epoch 4 \
--avg 1
@@ -28,7 +28,7 @@ It will generate a file exp_dir/pretrained.pt for offline ASR.
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp_L_offline \
- --lang-dir data/lang_char \
+ --tokens data/lang_char/tokens.txt \
--epoch 4 \
--avg 1 \
--jit True
@@ -38,7 +38,7 @@ It will generate a file exp_dir/cpu_jit.pt for offline ASR.
Usage for streaming:
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp_L_streaming \
- --lang-dir data/lang_char \
+ --tokens data/lang_char/tokens.txt \
--epoch 7 \
--avg 1
@@ -46,7 +46,7 @@ It will generate a file exp_dir/pretrained.pt for streaming ASR.
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp_L_streaming \
- --lang-dir data/lang_char \
+ --tokens data/lang_char/tokens.txt \
--epoch 7 \
--avg 1 \
--jit True
@@ -73,13 +73,13 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -114,10 +114,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="The lang dir",
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -152,10 +152,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
-
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
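
All of the exports touched by this patch now derive `blank_id`, `unk_id`, and `vocab_size` from a `tokens.txt` symbol table rather than a lang dir or BPE model. The dependency-free sketch below approximates what that lookup amounts to; the `symbol id`-per-line format and the `<blk>`/`<unk>` names follow icefall's lang dirs, and the `num_tokens` bookkeeping is only approximated.

```python
# Rough stand-in for k2.SymbolTable.from_file + icefall's num_tokens,
# as used by the --tokens code paths above. For illustration only.
from typing import Dict


def read_tokens(filename: str) -> Dict[str, int]:
    table: Dict[str, int] = {}
    with open(filename, encoding="utf-8") as f:
        for line in f:
            sym, idx = line.split()
            table[sym] = int(idx)
    return table


if __name__ == "__main__":
    # Example path; any tokens.txt produced by the recipes' prepare.sh works.
    token_table = read_tokens("data/lang_bpe_500/tokens.txt")
    blank_id = token_table["<blk>"]
    unk_id = token_table["<unk>"]
    # Approximation of icefall's num_tokens(token_table) + 1; the real helper
    # skips disambiguation symbols such as "#0" before counting.
    vocab_size = max(i for s, i in token_table.items() if not s.startswith("#")) + 1
    print(blank_id, unk_id, vocab_size)
```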
From 1c30847947f9d8b1416ef3e70408c07eab807f3d Mon Sep 17 00:00:00 2001
From: Yuekai Zhang
Date: Sat, 27 Jan 2024 00:32:30 +0800
Subject: [PATCH 043/123] Whisper Fine-tuning Recipe on Aishell1 (#1466)
* add decode seamlessm4t
* add requirements
* add decoding with avg model
* add token files
* add custom tokenizer
* support deepspeed to finetune large model
* support large-v3
* add model saving
* using monkey patch to replace models
* add manifest dir option
---
egs/aishell/ASR/README.md | 7 +
egs/aishell/ASR/RESULTS.md | 67 +-
.../ASR/local/compute_fbank_aishell.py | 45 +-
egs/aishell/ASR/prepare.sh | 13 +
egs/aishell/ASR/whisper/asr_datamodule.py | 1 +
egs/aishell/ASR/whisper/decode.py | 503 ++++++++++
egs/aishell/ASR/whisper/ds_config_zero1.json | 38 +
egs/aishell/ASR/whisper/label_smoothing.py | 1 +
egs/aishell/ASR/whisper/optim.py | 1 +
egs/aishell/ASR/whisper/requirements.txt | 10 +
egs/aishell/ASR/whisper/train.py | 927 ++++++++++++++++++
.../whisper_encoder_forward_monkey_patch.py | 29 +
.../ASR/local/compute_fbank_musan.py | 59 +-
icefall/dist.py | 2 +-
14 files changed, 1682 insertions(+), 21 deletions(-)
create mode 120000 egs/aishell/ASR/whisper/asr_datamodule.py
create mode 100755 egs/aishell/ASR/whisper/decode.py
create mode 100644 egs/aishell/ASR/whisper/ds_config_zero1.json
create mode 120000 egs/aishell/ASR/whisper/label_smoothing.py
create mode 120000 egs/aishell/ASR/whisper/optim.py
create mode 100755 egs/aishell/ASR/whisper/requirements.txt
create mode 100755 egs/aishell/ASR/whisper/train.py
create mode 100644 egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py
diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md
index 176f065e5..b54719162 100644
--- a/egs/aishell/ASR/README.md
+++ b/egs/aishell/ASR/README.md
@@ -24,3 +24,10 @@ The following table lists the differences among them.
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
We place an additional Conv1d layer right after the input embedding layer.
+
+# Whisper
+
+Recipe to fine-tune large pretrained models:
+| | Encoder | Decoder | Comment |
+|------------------------------------|-----------|--------------------|-----------------------------------------------------------------------------------|
+| `whisper` | Transformer | Transformer | supports fine-tuning using DeepSpeed |
diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md
index ff9504274..46d712fb2 100644
--- a/egs/aishell/ASR/RESULTS.md
+++ b/egs/aishell/ASR/RESULTS.md
@@ -1,5 +1,63 @@
## Results
+### Aishell training results (Fine-tuning Pretrained Models)
+#### Whisper
+[./whisper](./whisper)
+##### Fine-tuning results on the Aishell test set for Whisper medium, large-v2, and large-v3
+
+| | test (before fine-tuning) | test (after fine-tuning) | comment |
+|------------------------|------|------|-----------------------------------------|
+| medium | 7.23 | 3.27 | --epoch 10 --avg 4, ddp |
+| large-v2 | 6.56 | 2.47 | --epoch 10 --avg 6, deepspeed zero stage1 |
+| large-v3 | 6.06 | 2.84 | --epoch 5 --avg 3, deepspeed zero stage1 |
+
+Command for training is:
+```bash
+pip install -r whisper/requirements.txt
+
+./prepare.sh --stage 30 --stop_stage 30
+
+#fine-tuning with deepspeed zero stage 1
+torchrun --nproc-per-node 8 ./whisper/train.py \
+ --max-duration 200 \
+ --exp-dir whisper/exp_large_v2 \
+ --model-name large-v2 \
+ --deepspeed \
+ --deepspeed_config ./whisper/ds_config_zero1.json
+
+# fine-tuning with ddp
+torchrun --nproc-per-node 8 ./whisper/train.py \
+ --max-duration 200 \
+ --exp-dir whisper/exp_medium \
+ --base-lr 1e-5 \
+ --model-name medium
+```
+
+Command for decoding using fine-tuned models:
+```bash
+git lfs install
+git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
+ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt
+
+python3 ./whisper/decode.py \
+ --exp-dir whisper/exp_large_v2 \
+ --model-name large-v2 \
+ --epoch 999 --avg 1 \
+ --beam-size 10 --max-duration 50
+```
+Command for decoding using pretrained models (before fine-tuning):
+```bash
+python3 ./whisper/decode.py \
+ --exp-dir whisper/exp_large_v2 \
+ --model-name large-v2 \
+ --epoch -1 --avg 1 \
+ --remove-whisper-encoder-input-length-restriction False \
+ --beam-size 10 --max-duration 50
+```
+Fine-tuned models, training logs, decoding logs, tensorboard and decoding results
+are available at <https://huggingface.co/yuekai/icefall_asr_aishell_whisper>.
+
+
### Aishell training result (Stateless Transducer)
#### Zipformer (Byte-level BPE)
@@ -71,7 +129,7 @@ It's reworked Zipformer with Pruned RNNT loss.
Command for training is:
```bash
-./prepare.sh
+./prepare.sh
export CUDA_VISIBLE_DEVICES="0,1"
@@ -136,7 +194,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
- --max-duration 1200
+ --max-duration 1200
```
Command for decoding is:
@@ -186,7 +244,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
- --max-duration 800
+ --max-duration 800
```
Command for decoding is:
@@ -202,7 +260,7 @@ for m in greedy_search modified_beam_search fast_beam_search ; do
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
- --encoder-unmasked-dim 192,192,256,320,256,192
+ --encoder-unmasked-dim 192,192,256,320,256,192
done
```
@@ -755,7 +813,6 @@ python3 ./transducer_stateless/decode.py \
--max-sym-per-frame 3
```
-### Aishell training results (Transducer-stateless)
#### 2022-02-18
(Pingfeng Luo) : The tensorboard log for training is available at
And pretrained model is available at
diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py
index c7000da1c..3c48f0aa1 100755
--- a/egs/aishell/ASR/local/compute_fbank_aishell.py
+++ b/egs/aishell/ASR/local/compute_fbank_aishell.py
@@ -29,7 +29,14 @@ import os
from pathlib import Path
import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse import (
+ CutSet,
+ Fbank,
+ FbankConfig,
+ LilcomChunkyWriter,
+ WhisperFbank,
+ WhisperFbankConfig,
+)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
@@ -42,9 +49,14 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
-def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
+def compute_fbank_aishell(
+ num_mel_bins: int = 80,
+ perturb_speed: bool = False,
+ whisper_fbank: bool = False,
+ output_dir: str = "data/fbank",
+):
src_dir = Path("data/manifests")
- output_dir = Path("data/fbank")
+ output_dir = Path(output_dir)
num_jobs = min(15, os.cpu_count())
dataset_parts = (
@@ -68,8 +80,12 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
list(manifests.keys()),
dataset_parts,
)
-
- extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+ if whisper_fbank:
+ extractor = WhisperFbank(
+ WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+ )
+ else:
+ extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
@@ -82,7 +98,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False):
supervisions=m["supervisions"],
)
if "train" in partition and perturb_speed:
- logging.info(f"Doing speed perturb")
+ logging.info("Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
@@ -111,6 +127,18 @@ def get_args():
default=False,
help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
)
+ parser.add_argument(
+ "--whisper-fbank",
+ type=str2bool,
+ default=False,
+ help="Use WhisperFbank instead of Fbank. Default: False.",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default="data/fbank",
+ help="Output directory. Default: data/fbank.",
+ )
return parser.parse_args()
@@ -121,5 +149,8 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_aishell(
- num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
+ num_mel_bins=args.num_mel_bins,
+ perturb_speed=args.perturb_speed,
+ whisper_fbank=args.whisper_fbank,
+ output_dir=args.output_dir,
)
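
The diff above switches the feature extractor to lhotse's `WhisperFbank` when `--whisper-fbank true` is given (whisper large-v3 expects 128 filters). A hedged sketch of the same selection outside the recipe, assuming a lhotse version that ships `WhisperFbank`/`WhisperFbankConfig` (the script itself imports them):

```python
# Sketch of the extractor switch added above.
from lhotse import Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig


def make_extractor(num_mel_bins: int = 80, whisper_fbank: bool = False, device: str = "cpu"):
    if whisper_fbank:
        # Whisper-style log-mel features; whisper large-v3 expects 128 filters.
        return WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device=device))
    return Fbank(FbankConfig(num_mel_bins=num_mel_bins))


if __name__ == "__main__":
    extractor = make_extractor(num_mel_bins=128, whisper_fbank=True)
    print(type(extractor).__name__)
```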
diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh
index 9f73a2073..b7be89bc8 100755
--- a/egs/aishell/ASR/prepare.sh
+++ b/egs/aishell/ASR/prepare.sh
@@ -376,3 +376,16 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
--vocab-size 4336 \
--master-port 12345
fi
+
+# whisper large-v3 uses 128 mel bins; the other models use 80 mel bins
+whisper_mel_bins=80
+output_dir=data/fbank_whisper
+if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
+ log "Stage 30: Compute ${whisper_mel_bins} dim fbank for whisper model fine-tuning"
+ if [ ! -f $output_dir/.aishell.whisper.done ]; then
+ mkdir -p $output_dir
+ ./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true --output-dir $output_dir
+ ./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true --output-dir $output_dir
+ touch $output_dir/.aishell.whisper.done
+ fi
+fi
diff --git a/egs/aishell/ASR/whisper/asr_datamodule.py b/egs/aishell/ASR/whisper/asr_datamodule.py
new file mode 120000
index 000000000..fa1b8cca3
--- /dev/null
+++ b/egs/aishell/ASR/whisper/asr_datamodule.py
@@ -0,0 +1 @@
+../tdnn_lstm_ctc/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/whisper/decode.py b/egs/aishell/ASR/whisper/decode.py
new file mode 100755
index 000000000..7f841dcb7
--- /dev/null
+++ b/egs/aishell/ASR/whisper/decode.py
@@ -0,0 +1,503 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
+# Fangjun Kuang,
+# Wei Kang)
+# 2024 Yuekai Zhang
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+# Command for decoding using fine-tuned models:
+git lfs install
+git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
+ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt
+
+python3 ./whisper/decode.py \
+ --exp-dir whisper/exp_large_v2 \
+ --model-name large-v2 \
+ --epoch 999 --avg 1 \
+ --manifest-dir data/fbank_whisper \
+ --beam-size 10 --max-duration 50
+
+# Command for decoding using pretrained models (before fine-tuning):
+
+python3 ./whisper/decode.py \
+ --exp-dir whisper/exp_large_v2 \
+ --model-name large-v2 \
+ --epoch -1 --avg 1 \
+ --manifest-dir data/fbank_whisper \
+ --remove-whisper-encoder-input-length-restriction False \
+ --beam-size 10 --max-duration 50
+
+"""
+
+import argparse
+import logging
+import re
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+import whisper
+from asr_datamodule import AishellAsrDataModule
+from tn.chinese.normalizer import Normalizer
+from whisper.normalizers import BasicTextNormalizer
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+from zhconv import convert
+
+from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
+from icefall.env import get_env_info
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+
+def average_checkpoints(
+ filenames: List[Path], device: torch.device = torch.device("cpu")
+) -> dict:
+ """Average a list of checkpoints.
+ The function is mainly used for averaging deepspeed-converted checkpoints, which only contain the model state_dict.
+
+ Args:
+ filenames:
+ Filenames of the checkpoints to be averaged. We assume all
+ checkpoints are saved by :func:`save_checkpoint`.
+ device:
+ Move checkpoints to this device before averaging.
+ Returns:
+ Return a dict (i.e., state_dict) which is the average of all
+ model state dicts contained in the checkpoints.
+ """
+ n = len(filenames)
+
+ if "model" in torch.load(filenames[0], map_location=device):
+ avg = torch.load(filenames[0], map_location=device)["model"]
+ else:
+ avg = torch.load(filenames[0], map_location=device)
+
+ # Identify shared parameters. Two parameters are said to be shared
+ # if they have the same data_ptr
+ uniqued: Dict[int, str] = dict()
+
+ for k, v in avg.items():
+ v_data_ptr = v.data_ptr()
+ if v_data_ptr in uniqued:
+ continue
+ uniqued[v_data_ptr] = k
+
+ uniqued_names = list(uniqued.values())
+
+ for i in range(1, n):
+ if "model" in torch.load(filenames[i], map_location=device):
+ state_dict = torch.load(filenames[i], map_location=device)["model"]
+ else:
+ state_dict = torch.load(filenames[i], map_location=device)
+ for k in uniqued_names:
+ avg[k] += state_dict[k]
+
+ for k in uniqued_names:
+ if avg[k].is_floating_point():
+ avg[k] /= n
+ else:
+ avg[k] //= n
+
+ return avg
+
+
+def remove_punctuation(text: str or List[str]):
+ """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
+
+ Args:
+ text: It can be a string or a list of strings.
+ Returns:
+ Return a string or a list of strings without any punctuation.
+ """
+ punctuation = "!,.;:?、!,。;:?《》 "
+ if isinstance(text, str):
+ text = re.sub(r"[{}]+".format(punctuation), "", text).strip()
+ return text
+ elif isinstance(text, list):
+ result_text = []
+ for t in text:
+ t = re.sub(r"[{}]+".format(punctuation), "", t).strip()
+ result_text.append(t)
+ return result_text
+ else:
+ raise Exception(f"Not support type {type(text)}")
+
+
+def to_simple(text: str or List[str]):
+ """Convert traditional Chinese to simplified Chinese.
+ Args:
+ text: It can be a string or a list of strings.
+ Returns:
+ Return a string or a list of strings converted to simplified Chinese.
+ """
+ if isinstance(text, str):
+ text = convert(text, "zh-cn")
+ return text
+ elif isinstance(text, list):
+ result_text = []
+ for t in text:
+ t = convert(t, "zh-cn")
+ result_text.append(t)
+ return result_text
+ else:
+ raise Exception(f"Not support type{type(text)}")
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=-1,
+ help="It specifies the checkpoint to use for decoding."
+ "Note: Epoch counts from 0.",
+ )
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=1,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch'. ",
+ )
+
+ parser.add_argument(
+ "--method",
+ type=str,
+ default="beam-search",
+ help="""Decoding method.
+ Supported values are:
+ - beam-search
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=1,
+ help="beam size for beam search decoding",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="whisper/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default="large-v2",
+ choices=["large-v2", "large-v3", "medium", "small", "tiny"],
+ help="""The model name to use.
+ """,
+ )
+
+ parser.add_argument(
+ "--remove-whisper-encoder-input-length-restriction",
+ type=str2bool,
+ default=True,
+ help="replace whisper encoder forward method to remove input length restriction",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ params = AttributeDict(
+ {
+ "env_info": get_env_info(),
+ }
+ )
+ return params
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ batch: dict,
+) -> Dict[str, List[List[int]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: "beam-search"
+ - value: A list of lists. Each sublist is a list of token IDs.
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ batch:
+ It is returned by :meth:`torch.utils.data.DataLoader.__iter__`.
+ Returns:
+ Return a dict, whose key may be "beam-search".
+ """
+ dtype = torch.float16
+ device = torch.device("cuda")
+
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+ feature = feature.to(device, dtype=dtype).transpose(1, 2)
+ if not params.remove_whisper_encoder_input_length_restriction:
+ T = 3000
+ if feature.shape[2] < T:
+ feature = torch.cat(
+ [
+ feature,
+ torch.zeros(
+ feature.shape[0], feature.shape[1], T - feature.shape[2]
+ ).to(device, dtype=dtype),
+ ],
+ 2,
+ )
+
+ supervisions = batch["supervisions"]
+ feature_len = supervisions["num_frames"]
+ feature_len = feature_len.to(device, dtype=dtype)
+ results = model.decode(feature, params.decoding_options)
+ hyps = [result.text for result in results]
+
+ hyps = remove_punctuation(hyps)
+ hyps = to_simple(hyps)
+ hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
+
+ return {"beam-search": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ The dataloader.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ Returns:
+ Return a dict, whose key may be "beam-search".
+ """
+ results = []
+
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ batch=batch,
+ )
+
+ for lm_scale, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[lm_scale].extend(this_batch)
+
+ num_cuts += len(batch["supervisions"]["text"])
+
+ if batch_idx % 100 == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+
+ enable_log = True
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ if enable_log:
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ # we compute CER for aishell dataset.
+ results_char = []
+ for res in results:
+ results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
+ )
+ test_set_wers[key] = wer
+
+ if enable_log:
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
+ with open(errs_info, "w") as f:
+ print("settings\tCER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ AishellAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+ setup_logger(
+ f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
+ )
+
+ options = whisper.DecodingOptions(
+ task="transcribe",
+ language="zh",
+ without_timestamps=True,
+ beam_size=params.beam_size,
+ )
+ params.decoding_options = options
+ params.cleaner = BasicTextNormalizer()
+ params.normalizer = Normalizer()
+
+ logging.info("Decoding started")
+ logging.info(params)
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+
+ logging.info(f"device: {device}")
+
+ if params.remove_whisper_encoder_input_length_restriction:
+ replace_whisper_encoder_forward()
+ model = whisper.load_model(params.model_name, "cpu")
+ if params.epoch > 0:
+ if params.avg > 1:
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ checkpoint = torch.load(
+ f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+ )
+ if "model" not in checkpoint:
+ # deepspeed converted checkpoint only contains model state_dict
+ filenames = [
+ f"{params.exp_dir}/epoch-{epoch}.pt"
+ for epoch in range(start, params.epoch + 1)
+ ]
+ model.load_state_dict(average_checkpoints(filenames))
+ else:
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ # save checkpoints
+ filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
+ torch.save(model.state_dict(), filename)
+ else:
+ checkpoint = torch.load(
+ f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+ )
+ if "model" not in checkpoint:
+ model.load_state_dict(checkpoint, strict=True)
+ else:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ model.to(device)
+ model.eval()
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ aishell = AishellAsrDataModule(args)
+ valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
+ test_dl = aishell.test_dataloaders(aishell.test_cuts())
+ test_sets = ["valid", "test"]
+ test_dls = [valid_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ )
+
+ save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+
+ logging.info("Done!")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
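
As a side note, the hypothesis post-processing chain in `decode_one_batch` (punctuation removal, traditional-to-simplified conversion, WeTextProcessing normalization) can be reproduced standalone. The snippet below mirrors the helpers defined in this decode.py and assumes the `zhconv` and `WeTextProcessing` packages from `whisper/requirements.txt` are installed.

```python
# Standalone mirror of the post-processing applied to hypotheses above.
import re

from tn.chinese.normalizer import Normalizer  # provided by WeTextProcessing
from zhconv import convert

punctuation = "!,.;:?、!,。;:?《》 "
normalizer = Normalizer()

hyp = "這是一個測試,你好。"
hyp = re.sub(r"[{}]+".format(punctuation), "", hyp).strip()  # remove_punctuation
hyp = convert(hyp, "zh-cn")                                  # to_simple
hyp = normalizer.normalize(hyp)                              # text normalization
print(hyp)
```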
diff --git a/egs/aishell/ASR/whisper/ds_config_zero1.json b/egs/aishell/ASR/whisper/ds_config_zero1.json
new file mode 100644
index 000000000..bf8cc0452
--- /dev/null
+++ b/egs/aishell/ASR/whisper/ds_config_zero1.json
@@ -0,0 +1,38 @@
+{
+ "fp16": {
+ "enabled": true,
+ "loss_scale": 0,
+ "loss_scale_window": 100,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 0.01
+ },
+ "zero_optimization": {
+ "stage": 1,
+ "allgather_partitions": true,
+ "allgather_bucket_size": 2e8,
+ "overlap_comm": true,
+ "reduce_scatter": true,
+ "reduce_bucket_size": 2e8,
+ "contiguous_gradients": true
+ },
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 1e-5
+ }
+ },
+ "scheduler": {
+ "type": "WarmupLR",
+ "params": {
+ "warmup_min_lr": 0,
+ "warmup_max_lr": 1e-5,
+ "warmup_num_steps": 100
+ }
+ },
+ "gradient_accumulation_steps": 1,
+ "gradient_clipping": 5,
+ "steps_per_print": 50,
+ "train_micro_batch_size_per_gpu": 1,
+ "wall_clock_breakdown": false
+}
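
For orientation, configs like the ZeRO stage-1 JSON above are consumed through DeepSpeed's standard argument/engine plumbing. The sketch below shows the usual `deepspeed.initialize` wiring with a stand-in model; it is not necessarily the exact code path in `whisper/train.py` and should be run under the `deepspeed` or `torchrun` launcher.

```python
# Minimal sketch of feeding ds_config_zero1.json to DeepSpeed; placeholder model,
# not the exact wiring used in whisper/train.py. Run under a distributed launcher.
import argparse

import deepspeed
import torch.nn as nn

parser = argparse.ArgumentParser()
parser = deepspeed.add_config_arguments(parser)  # adds --deepspeed / --deepspeed_config
args = parser.parse_args(
    ["--deepspeed", "--deepspeed_config", "./whisper/ds_config_zero1.json"]
)

model = nn.Linear(80, 10)  # placeholder for the Whisper model
model_engine, optimizer, _, scheduler = deepspeed.initialize(
    args=args, model=model, model_parameters=model.parameters()
)
# A training step then looks like:
#   loss = compute_loss(model_engine, batch)   # user-defined
#   model_engine.backward(loss)
#   model_engine.step()
```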
diff --git a/egs/aishell/ASR/whisper/label_smoothing.py b/egs/aishell/ASR/whisper/label_smoothing.py
new file mode 120000
index 000000000..e9d239fff
--- /dev/null
+++ b/egs/aishell/ASR/whisper/label_smoothing.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/label_smoothing.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/whisper/optim.py b/egs/aishell/ASR/whisper/optim.py
new file mode 120000
index 000000000..5eaa3cffd
--- /dev/null
+++ b/egs/aishell/ASR/whisper/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/whisper/requirements.txt b/egs/aishell/ASR/whisper/requirements.txt
new file mode 100755
index 000000000..0708f2344
--- /dev/null
+++ b/egs/aishell/ASR/whisper/requirements.txt
@@ -0,0 +1,10 @@
+k2
+kaldialign
+git+https://github.com/lhotse-speech/lhotse
+sentencepiece
+tensorboard
+librosa
+git+https://github.com/yuekaizhang/whisper.git
+zhconv
+WeTextProcessing
+deepspeed
diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py
new file mode 100755
index 000000000..d16793eb2
--- /dev/null
+++ b/egs/aishell/ASR/whisper/train.py
@@ -0,0 +1,927 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
+# 2024 Yuekai Zhang
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+#fine-tuning with deepspeed zero stage 1
+torchrun --nproc-per-node 8 ./whisper/train.py \
+ --max-duration 200 \
+ --exp-dir whisper/exp_large_v2 \
+ --model-name large-v2 \
+ --manifest-dir data/fbank_whisper \
+ --deepspeed \
+ --deepspeed_config ./whisper/ds_config_zero1.json
+
+# fine-tuning with ddp
+torchrun --nproc-per-node 8 ./whisper/train.py \
+ --max-duration 200 \
+ --exp-dir whisper/exp_medium \
+ --manifest-dir data/fbank_whisper \
+ --base-lr 1e-5 \
+ --model-name medium
+"""
+
+
+import argparse
+import copy
+import logging
+import random
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import deepspeed
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+import whisper
+from asr_datamodule import AishellAsrDataModule
+from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+from label_smoothing import LabelSmoothingLoss
+from lhotse import CutSet, load_manifest
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.functional import pad as pad_tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import update_averaged_model
+from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ filter_uneven_sized_batch,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for module in model.modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=10,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default="large-v2",
+ choices=["large-v2", "large-v3", "medium", "small", "tiny"],
+ help="""The model name to use.
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=1e-5, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=6,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=True,
+ help="Whether to use half precision training.",
+ )
+
+ parser = deepspeed.add_config_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - frame_shift_ms: The frame shift in milliseconds.
+ - allowed_excess_duration_ratio: The allowed excess duration ratio.
+ - best_train_loss: The best training loss so far.
+ - best_valid_loss: The best validation loss so far.
+ - best_train_epoch: The epoch where the best training loss is achieved.
+ - best_valid_epoch: The epoch where the best validation loss is achieved.
+ - batch_idx_train: The batch index of the current batch.
+ - log_interval: Log training stats every `log_interval` batches.
+ - reset_interval: Reset the stats every `reset_interval` batches.
+ - valid_interval: Run validation every `valid_interval` batches.
+ - env_info: The environment information.
+ """
+ params = AttributeDict(
+ {
+ "frame_shift_ms": 10.0,
+ "subsampling_factor": 2,
+ "allowed_excess_duration_ratio": 0.1,
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 5000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ tokenizer: whisper.tokenizer.Tokenizer,
+ model: Union[nn.Module, DDP],
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute the loss for the given batch.
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ tokenizer:
+ The tokenizer used to encode the text.
+ model:
+ The model for training.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ Whether it is training.
+ Returns:
+ Return a tuple of two elements. The first element is the loss tensor;
+ the second is a MetricsTracker with statistics (e.g., frame count and
+ loss value) for this batch.
+ """
+ # For an uneven-sized batch, the total duration after padding may cause OOM.
+ # Hence, for each batch, which is sorted in descending order of length,
+ # we simply drop the last few shortest samples, so that the retained total frames
+ # (after padding) do not exceed `allowed_max_frames`:
+ # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
+ # where `max_frames = max_duration * 1000 // frame_shift_ms`.
+ # We set allowed_excess_duration_ratio=0.1.
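+ # Worked example with illustrative numbers (not necessarily the recipe's
+ # defaults): max_duration=200 s and frame_shift_ms=10 give
+ # max_frames = 200 * 1000 // 10 = 20000, and with
+ # allowed_excess_duration_ratio=0.1 we get allowed_max_frames = 22000.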
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+
+ def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor:
+ padding_size = max(tensor.shape[0] for tensor in tensors)
+ dims = len(tensors[0].shape)
+ padded_tensors = []
+ for tensor in tensors:
+ padding = [0] * 2 * dims
+ padding[-1] = padding_size - tensor.shape[0]
+ padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
+ return torch.stack([tensor for tensor in padded_tensors], dim=0)
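+ # Example of what _batch_tensors does (shapes only): two 1-D token tensors
+ # of lengths 5 and 3 are right-padded with `pad_value` to length 5 and
+ # stacked into a tensor of shape (2, 5).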
+
+ max_frames = params.max_duration * 1000 // params.frame_shift_ms
+ allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
+ batch = filter_uneven_sized_batch(batch, allowed_max_frames)
+
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+
+ assert feature.ndim == 3
+ feature = feature.to(device)
+ feature = feature.transpose(1, 2) # (N, C, T)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+
+ texts = batch["supervisions"]["text"]
+ # remove spaces in texts
+ texts = [text.replace(" ", "") for text in texts]
+
+ text_tokens_list = [
+ list(tokenizer.sot_sequence_including_notimestamps)
+ + tokenizer.encode(text)
+ + [tokenizer.eot]
+ for text in texts
+ ]
+ # convert it to torch tensor
+ text_tokens_list = [
+ torch.LongTensor(text_tokens) for text_tokens in text_tokens_list
+ ]
+
+ # 50256 is the index of the padding token for all whisper models
+ prev_outputs_tokens = _batch_tensors(
+ [tokens[:-1] for tokens in text_tokens_list], pad_value=50256
+ )
+ target_tokens = _batch_tensors(
+ [tokens[1:] for tokens in text_tokens_list], pad_value=50256
+ )
+ target_lengths = torch.LongTensor(
+ [tokens.shape[0] - 1 for tokens in text_tokens_list]
+ )
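+ # Teacher-forcing sketch: for an encoded sequence [sot..., t1, ..., tn, eot],
+ # prev_outputs_tokens drops the final token and target_tokens drops the first
+ # one, so position i of the decoder input is trained to predict token i+1.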
+
+ decoder_criterion = LabelSmoothingLoss(
+ ignore_index=50256, label_smoothing=0.1, reduction="sum"
+ )
+
+ # ignore the first 3 tokens, which are always <|lang_id|>, <|transcribe|>, <|notimestamps|>
+ ignore_prefix_size = 3
+ with torch.set_grad_enabled(is_training):
+ encoder_out = model.encoder(feature)
+ text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
+ text_logits = text_logits[:, ignore_prefix_size:, :]
+ target_tokens = target_tokens[:, ignore_prefix_size:]
+ loss = decoder_criterion(text_logits, target_tokens.to(device))
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ tokenizer: whisper.tokenizer.Tokenizer,
+ model: Union[nn.Module, DDP],
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ tokenizer: whisper.tokenizer.Tokenizer,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ if params.deepspeed:
+ # With DeepSpeed we call backward()/step() on the engine itself:
+ # it handles loss scaling, gradient accumulation and the
+ # optimizer/scheduler steps internally, so the GradScaler and
+ # optimizer calls in the other branch are not used here.
+ model.backward(loss)
+ model.step()
+ else:
+ scaler.scale(loss).backward()
+ set_batch_count(model, params.batch_idx_train)
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ display_and_save_batch(batch, params=params)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ and not params.deepspeed
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+ if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
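+ # Example of the schedule above: a scale of 0.5 is doubled at the next
+ # multiple of 100 batches, a scale in [1.0, 8.0) is doubled only every 400
+ # batches, and a scale below 1e-5 raises an error and stops training.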
+ if batch_idx % params.log_interval == 0:
+ try:
+ cur_lr = scheduler.get_last_lr()[0]
+ except: # noqa
+ cur_lr = 0.0
+ cur_grad_scale = (
+ scaler._scale.item()
+ if (params.use_fp16 and not params.deepspeed)
+ else 1.0
+ )
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (
+ f"grad_scale: {scaler._scale.item()}"
+ if (params.use_fp16 and not params.deepspeed)
+ else ""
+ )
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info(params)
+
+ logging.info("About to create model")
+
+ replace_whisper_encoder_forward()
+ model = whisper.load_model(params.model_name, "cpu")
+ del model.alignment_heads
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ tokenizer = whisper.tokenizer.get_tokenizer(
+ model.is_multilingual,
+ num_languages=model.num_languages,
+ language="zh",
+ task="transcribe",
+ )
+
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ else:
+ device = torch.device("cpu")
+ logging.info(f"Device: {device}")
+ model.to(device)
+
+ optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if world_size > 1:
+ if params.deepspeed:
+ logging.info("Using DeepSpeed")
+ model, optimizer, _, scheduler = deepspeed.initialize(
+ args=params, model=model, model_parameters=model.parameters()
+ )
+ else:
+ logging.info("Using DDP")
+ setup_dist(use_ddp_launch=True)
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 2**22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ aishell = AishellAsrDataModule(args)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = aishell.train_dataloaders(aishell.train_cuts())
+ valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ logging.info(f"start training from epoch {params.start_epoch}")
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ if not params.deepspeed:
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ if params.deepspeed:
+ model.save_checkpoint(
+ save_dir=params.exp_dir,
+ tag=f"epoch-{params.cur_epoch}",
+ client_state={},
+ )
+ if rank == 0:
+ convert_zero_checkpoint_to_fp32_state_dict(
+ params.exp_dir,
+ f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
+ tag=f"epoch-{params.cur_epoch}",
+ )
+ else:
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1 and not params.deepspeed:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+
+def main():
+ parser = get_parser()
+ AishellAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = get_world_size()
+ rank = get_rank()
+
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ run(rank=rank, world_size=world_size, args=args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py b/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py
new file mode 100644
index 000000000..5bfbdce3b
--- /dev/null
+++ b/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py
@@ -0,0 +1,29 @@
+import torch
+import torch.nn.functional as F
+import whisper
+
+
+def forward(self, x: torch.Tensor):
+ """
+ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
+ the mel spectrogram of the audio
+ """
+ x = F.gelu(self.conv1(x))
+ x = F.gelu(self.conv2(x))
+ x = x.permute(0, 2, 1)
+
+ x = (x + self.positional_embedding[: x.shape[1], :]).to(x.dtype)
+
+ for block in self.blocks:
+ x = block(x)
+
+ x = self.ln_post(x)
+ return x
+
+
+def replace_whisper_encoder_forward():
+ """
+ This function monkey patches the forward method of the whisper encoder.
+ Call it before the model is loaded; it lets whisper process audio shorter than 30 s by slicing the positional embedding to the actual input length.
+ """
+ whisper.model.AudioEncoder.forward = forward
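+
+
+ # Usage sketch (the model name below is illustrative, not fixed by this file):
+ #   replace_whisper_encoder_forward()
+ #   model = whisper.load_model("large-v2", "cpu")
+ #   # model.encoder now accepts mel inputs shorter than 30 s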
diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py
index 62036467e..d7781687f 100755
--- a/egs/librispeech/ASR/local/compute_fbank_musan.py
+++ b/egs/librispeech/ASR/local/compute_fbank_musan.py
@@ -22,16 +22,25 @@ It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
-
+import argparse
import logging
import os
from pathlib import Path
import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine
+from lhotse import (
+ CutSet,
+ Fbank,
+ FbankConfig,
+ LilcomChunkyWriter,
+ MonoCut,
+ WhisperFbank,
+ WhisperFbankConfig,
+ combine,
+)
from lhotse.recipes.utils import read_manifests_if_cached
-from icefall.utils import get_executor
+from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
@@ -45,11 +54,12 @@ def is_cut_long(c: MonoCut) -> bool:
return c.duration > 5
-def compute_fbank_musan():
+def compute_fbank_musan(
+ num_mel_bins: int = 80, whisper_fbank: bool = False, output_dir: str = "data/fbank"
+):
src_dir = Path("data/manifests")
- output_dir = Path("data/fbank")
+ output_dir = Path(output_dir)
num_jobs = min(15, os.cpu_count())
- num_mel_bins = 80
dataset_parts = (
"music",
@@ -81,7 +91,12 @@ def compute_fbank_musan():
logging.info("Extracting features for Musan")
- extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+ if whisper_fbank:
+ extractor = WhisperFbank(
+ WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+ )
+ else:
+ extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
# create chunks of Musan with duration 5 - 10 seconds
@@ -102,8 +117,36 @@ def compute_fbank_musan():
musan_cuts.to_file(musan_cuts_path)
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--num-mel-bins",
+ type=int,
+ default=80,
+ help="""The number of mel bins for Fbank""",
+ )
+ parser.add_argument(
+ "--whisper-fbank",
+ type=str2bool,
+ default=False,
+ help="Use WhisperFbank instead of Fbank. Default: False.",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default="data/fbank",
+ help="Output directory. Default: data/fbank.",
+ )
+ return parser.parse_args()
+
+
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
- compute_fbank_musan()
+ args = get_args()
+ compute_fbank_musan(
+ num_mel_bins=args.num_mel_bins,
+ whisper_fbank=args.whisper_fbank,
+ output_dir=args.output_dir,
+ )
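+
+ # Example invocation (values shown are the script defaults except for
+ # --whisper-fbank; WhisperFbank uses a CUDA device, see the extractor above):
+ #   python local/compute_fbank_musan.py \
+ #     --num-mel-bins 80 --whisper-fbank true --output-dir data/fbank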
diff --git a/icefall/dist.py b/icefall/dist.py
index 922f31a2f..ee76e994a 100644
--- a/icefall/dist.py
+++ b/icefall/dist.py
@@ -22,7 +22,7 @@ from torch import distributed as dist
def setup_dist(
- rank, world_size, master_port=None, use_ddp_launch=False, master_addr=None
+ rank=None, world_size=None, master_port=None, use_ddp_launch=False, master_addr=None
):
"""
rank and world_size are used only if use_ddp_launch is False.
From 37b975cac9ac237b71e957dd0770b091124fcb60 Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Sat, 27 Jan 2024 06:41:56 +0800
Subject: [PATCH 044/123] fixed a CI test for `wenetspeech` (#1476)
* Comply with issue #1149
https://github.com/k2-fsa/icefall/issues/1149
---
...enetspeech-pruned-transducer-stateless2.sh | 6 ++---
.../pruned_transducer_stateless2/export.py | 19 +++++++-------
.../pruned_transducer_stateless3/export.py | 19 +++++++-------
.../export-onnx.py | 21 +++++++---------
.../ASR/transducer_stateless/export.py | 19 +++++++-------
.../transducer_stateless_modified-2/export.py | 18 ++++++-------
.../transducer_stateless_modified/export.py | 19 +++++++-------
.../pruned_transducer_stateless5/export.py | 20 +++++++--------
.../pruned_transducer_stateless5/export.py | 19 ++++++--------
.../pruned_transducer_stateless2/export.py | 19 +++++++-------
.../pruned_transducer_stateless7/export.py | 23 ++++++++---------
.../export-onnx-zh.py | 18 ++++++-------
egs/swbd/ASR/conformer_ctc/export.py | 19 +++++++-------
egs/tedlium3/ASR/conformer_ctc2/export.py | 18 ++++++-------
.../export-onnx.py | 25 ++++++++-----------
.../export-onnx-streaming.py | 19 +++++++-------
.../export-onnx.py | 18 ++++++-------
17 files changed, 150 insertions(+), 169 deletions(-)
diff --git a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
index a3a2d3080..981b74b76 100755
--- a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
+++ b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
@@ -30,7 +30,7 @@ log "Test exporting to ONNX format"
./pruned_transducer_stateless2/export-onnx.py \
--exp-dir $repo/exp \
- --lang-dir $repo/data/lang_char \
+ --tokens $repo/data/lang_char/tokens.txt \
--epoch 99 \
--avg 1
@@ -38,14 +38,14 @@ log "Export to torchscript model"
./pruned_transducer_stateless2/export.py \
--exp-dir $repo/exp \
- --lang-dir $repo/data/lang_char \
+ --tokens $repo/data/lang_char/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
./pruned_transducer_stateless2/export.py \
--exp-dir $repo/exp \
- --lang-dir $repo/data/lang_char \
+ --tokens $repo/data/lang_char/tokens.txt \
--epoch 99 \
--avg 1 \
--jit-trace 1
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/export.py b/egs/aishell/ASR/pruned_transducer_stateless2/export.py
index 2ce5cfe69..c2dc0d5f3 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/export.py
@@ -47,12 +47,12 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -106,10 +106,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
- type=Path,
- default=Path("data/lang_char"),
- help="The lang dir",
+ "--tokens",
+ type=str,
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -136,10 +136,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
-
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
index 723414167..2248c7a08 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py
@@ -47,6 +47,7 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@@ -57,8 +58,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -123,10 +123,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
- type=Path,
- default=Path("data/lang_char"),
- help="The lang dir",
+ "--tokens",
+ type=str,
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -153,10 +153,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
-
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
params.datatang_prob = 0
logging.info(params)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py
index 39d988cd0..4981fb71a 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py
@@ -49,14 +49,14 @@ import logging
from pathlib import Path
from typing import Dict, Tuple
+import k2
import onnx
-import sentencepiece as spm
import torch
import torch.nn as nn
from decoder2 import Decoder
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
-from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from zipformer import Zipformer
from icefall.checkpoint import (
@@ -65,8 +65,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.lexicon import Lexicon
-from icefall.utils import setup_logger, str2bool
+from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@@ -123,12 +122,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- help="""The lang dir
- It contains language related input files such as
- "lexicon.txt"
- """,
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -404,9 +401,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py
index 01de5d772..bfd0ecb0c 100755
--- a/egs/aishell/ASR/transducer_stateless/export.py
+++ b/egs/aishell/ASR/transducer_stateless/export.py
@@ -23,7 +23,7 @@
Usage:
./transducer_stateless/export.py \
--exp-dir ./transducer_stateless/exp \
- --lang-dir data/lang_char \
+ --tokens data/lang_char/tokens.txt \
--epoch 20 \
--avg 10
@@ -47,6 +47,7 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
import torch.nn as nn
from conformer import Conformer
@@ -56,8 +57,7 @@ from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
-from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, str2bool
+from icefall.utils import AttributeDict, num_tokens, str2bool
def get_parser():
@@ -92,10 +92,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="The lang dir",
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -192,10 +192,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
-
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
index c1081c32b..4f2c71d18 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py
@@ -46,6 +46,7 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
import torch.nn as nn
from conformer import Conformer
@@ -56,7 +57,7 @@ from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, str2bool
+from icefall.utils import AttributeDict, num_tokens, str2bool
def get_parser():
@@ -99,10 +100,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
- type=Path,
- default=Path("data/lang_char"),
- help="The lang dir",
+ "--tokens",
+ type=str,
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -190,10 +191,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
-
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py
index 3e14ad69c..487748947 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/export.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/export.py
@@ -46,6 +46,7 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
import torch.nn as nn
from conformer import Conformer
@@ -55,8 +56,7 @@ from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
-from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, str2bool
+from icefall.utils import AttributeDict, num_tokens, str2bool
def get_parser():
@@ -99,10 +99,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
- type=Path,
- default=Path("data/lang_char"),
- help="The lang dir",
+ "--tokens",
+ type=str,
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -190,10 +190,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
-
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
index 8a5be94d0..c92c7ab83 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py
@@ -22,7 +22,7 @@
Usage:
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp \
- --lang-dir data/lang_char
+ --tokens ./data/lang_char/tokens.txt \
--epoch 25 \
--avg 5
@@ -48,6 +48,7 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model
@@ -57,8 +58,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -115,10 +115,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="The lang dir",
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -154,10 +154,10 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
- params.blank_id = lexicon.token_table[""]
- params.unk_id = lexicon.token_table[""]
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.unk_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
index bf9856c60..246820833 100755
--- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py
@@ -48,6 +48,7 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model
@@ -57,8 +58,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -115,13 +115,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="""The lang dir
- It contains language related input files such as
- "lexicon.txt"
- """,
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -157,9 +154,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
- params.blank_id = lexicon.token_table[""]
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
index 8e5cc6075..5dc73c52b 100644
--- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py
@@ -20,7 +20,7 @@
Usage:
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
- --lang-dir data/lang_char \
+ --tokens ./data/lang_char/tokens.txt \
--epoch 29 \
--avg 18
@@ -45,12 +45,12 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -85,10 +85,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="The lang dir",
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -122,10 +122,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
-
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py
index 23a88dd29..8bafaef44 100755
--- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py
@@ -26,7 +26,7 @@ Usage:
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens ./data/lang_char/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
+ --tokens ./data/lang_char/tokens.txt \
--epoch 20 \
--avg 10
@@ -86,9 +86,8 @@ import argparse
import logging
from pathlib import Path
-import sentencepiece as spm
+import k2
import torch
-import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@@ -98,8 +97,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -156,10 +154,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="The lang dir",
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -199,10 +197,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
-
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py
index 2a52e2eec..1ce770128 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py
@@ -28,7 +28,7 @@ popd
2. Export the model to ONNX
./lstm_transducer_stateless2/export-onnx-zh.py \
- --lang-dir ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char \
+ --tokens ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char/tokens.txt \
--use-averaged-model 1 \
--epoch 11 \
--avg 1 \
@@ -55,6 +55,7 @@ import logging
from pathlib import Path
from typing import Dict, Optional, Tuple
+import k2
import onnx
import torch
import torch.nn as nn
@@ -70,8 +71,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.lexicon import Lexicon
-from icefall.utils import setup_logger, str2bool
+from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@@ -128,10 +128,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="The lang dir",
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt.",
)
parser.add_argument(
@@ -441,9 +441,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/swbd/ASR/conformer_ctc/export.py b/egs/swbd/ASR/conformer_ctc/export.py
index 1bb6277ad..44b2e95d6 100755
--- a/egs/swbd/ASR/conformer_ctc/export.py
+++ b/egs/swbd/ASR/conformer_ctc/export.py
@@ -23,12 +23,12 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
from conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, str2bool
+from icefall.utils import AttributeDict, num_tokens, str2bool
def get_parser():
@@ -63,11 +63,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_bpe_500",
- help="""It contains language related input files such as "lexicon.txt"
- """,
+ default="data/lang_bpe_500/tokens.txt",
+ help="Path to the tokens.txt.",
)
parser.add_argument(
@@ -105,9 +104,9 @@ def main():
logging.info(params)
- lexicon = Lexicon(params.lang_dir)
- max_token_id = max(lexicon.tokens)
- num_classes = max_token_id + 1 # +1 for the blank
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
device = torch.device("cpu")
if torch.cuda.is_available():
@@ -119,7 +118,7 @@ def main():
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
- num_classes=num_classes,
+ num_classes=params.vocab_size,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=False,
diff --git a/egs/tedlium3/ASR/conformer_ctc2/export.py b/egs/tedlium3/ASR/conformer_ctc2/export.py
index 009bea230..b5bf911c2 100755
--- a/egs/tedlium3/ASR/conformer_ctc2/export.py
+++ b/egs/tedlium3/ASR/conformer_ctc2/export.py
@@ -45,6 +45,7 @@ import argparse
import logging
from pathlib import Path
+import k2
import torch
from conformer import Conformer
from scaling_converter import convert_scaled_to_non_scaled
@@ -56,8 +57,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, str2bool
+from icefall.utils import AttributeDict, num_tokens, str2bool
def get_parser() -> argparse.ArgumentParser:
@@ -118,10 +118,10 @@ def get_parser() -> argparse.ArgumentParser:
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_bpe_500",
- help="The lang dir",
+ default="data/lang_bpe_500/tokens.txt",
+ help="Path to the tokens.txt.",
)
parser.add_argument(
@@ -166,9 +166,9 @@ def main():
params = get_params()
params.update(vars(args))
- lexicon = Lexicon(params.lang_dir)
- max_token_id = max(lexicon.tokens)
- num_classes = max_token_id + 1 # +1 for the blank
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
device = torch.device("cpu")
if torch.cuda.is_available():
@@ -182,7 +182,7 @@ def main():
model = Conformer(
num_features=params.feature_dim,
- num_classes=num_classes,
+ num_classes=params.vocab_size,
subsampling_factor=params.subsampling_factor,
d_model=params.dim_model,
nhead=params.nhead,
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py
index 140b1d37f..8aea79fe3 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py
@@ -28,7 +28,7 @@ popd
2. Export the model to ONNX
./pruned_transducer_stateless2/export-onnx.py \
- --lang-dir $repo/data/lang_char \
+ --tokens $repo/data/lang_char/tokens.txt \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp
@@ -48,6 +48,7 @@ import logging
from pathlib import Path
from typing import Dict, Tuple
+import k2
import onnx
import torch
import torch.nn as nn
@@ -57,14 +58,8 @@ from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model
-from icefall.checkpoint import (
- average_checkpoints,
- average_checkpoints_with_averaged_model,
- find_checkpoints,
- load_checkpoint,
-)
-from icefall.lexicon import Lexicon
-from icefall.utils import setup_logger, str2bool
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@@ -110,10 +105,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="The lang dir",
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -397,9 +392,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py
index 921766ad4..30068d01a 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py
@@ -58,13 +58,13 @@ import logging
from pathlib import Path
from typing import Dict, Tuple
+import k2
import onnx
-from icefall.lexicon import Lexicon
import torch
import torch.nn as nn
from conformer import Conformer
-from onnxruntime.quantization import QuantType, quantize_dynamic
from decoder import Decoder
+from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@@ -74,7 +74,8 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.utils import setup_logger, str2bool
+from icefall.lexicon import Lexicon
+from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@@ -131,10 +132,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="The lang dir",
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -490,9 +491,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py
index 037c7adf1..1c9eb8648 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py
@@ -28,7 +28,7 @@ popd
2. Export the model to ONNX
./pruned_transducer_stateless5/export-onnx.py \
- --lang-dir $repo/data/lang_char \
+ --tokens $repo/data/lang_char/tokens.txt \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
@@ -55,6 +55,7 @@ import logging
from pathlib import Path
from typing import Dict, Tuple
+import k2
import onnx
import torch
import torch.nn as nn
@@ -70,8 +71,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.lexicon import Lexicon
-from icefall.utils import setup_logger, str2bool
+from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@@ -128,10 +128,10 @@ def get_parser():
)
parser.add_argument(
- "--lang-dir",
+ "--tokens",
type=str,
- default="data/lang_char",
- help="The lang dir",
+ default="data/lang_char/tokens.txt",
+ help="Path to the tokens.txt",
)
parser.add_argument(
@@ -417,9 +417,9 @@ def main():
logging.info(f"device: {device}")
- lexicon = Lexicon(params.lang_dir)
- params.blank_id = 0
- params.vocab_size = max(lexicon.tokens) + 1
+ token_table = k2.SymbolTable.from_file(params.tokens)
+ params.blank_id = token_table[""]
+ params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
From b07d5472c58eb11bc86cdcf787ecfdb634a55c73 Mon Sep 17 00:00:00 2001
From: Henry Li Xinyuan <57513260+HSTEHSTEHSTE@users.noreply.github.com>
Date: Wed, 31 Jan 2024 09:53:36 -0500
Subject: [PATCH 045/123] Implement recipe for Fluent Speech Commands dataset
(#1469)
---------
Signed-off-by: Xinyuan Li
---
egs/fluent_speech_commands/SLU/README.md | 9 +
.../SLU/local/compile_hlg.py | 136 ++++
.../SLU/local/compute_fbank_slu.py | 97 +++
.../SLU/local/generate_lexicon.py | 59 ++
.../SLU/local/prepare_lang.py | 371 +++++++++++
egs/fluent_speech_commands/SLU/prepare.sh | 103 +++
egs/fluent_speech_commands/SLU/shared | 1 +
.../SLU/transducer/__init__.py | 0
.../SLU/transducer/beam_search.py | 71 ++
.../SLU/transducer/conformer.py | 1 +
.../SLU/transducer/decode.py | 346 ++++++++++
.../SLU/transducer/decoder.py | 1 +
.../SLU/transducer/encoder_interface.py | 1 +
.../SLU/transducer/joiner.py | 1 +
.../SLU/transducer/model.py | 1 +
.../SLU/transducer/slu_datamodule.py | 289 ++++++++
.../SLU/transducer/subsampling.py | 1 +
.../SLU/transducer/test_conformer.py | 1 +
.../SLU/transducer/test_decoder.py | 1 +
.../SLU/transducer/test_joiner.py | 1 +
.../SLU/transducer/test_transducer.py | 1 +
.../SLU/transducer/train.py | 625 ++++++++++++++++++
.../SLU/transducer/transformer.py | 1 +
icefall/shared/make_kn_lm.py | 9 +-
24 files changed, 2124 insertions(+), 3 deletions(-)
create mode 100755 egs/fluent_speech_commands/SLU/README.md
create mode 100755 egs/fluent_speech_commands/SLU/local/compile_hlg.py
create mode 100755 egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py
create mode 100755 egs/fluent_speech_commands/SLU/local/generate_lexicon.py
create mode 100755 egs/fluent_speech_commands/SLU/local/prepare_lang.py
create mode 100755 egs/fluent_speech_commands/SLU/prepare.sh
create mode 120000 egs/fluent_speech_commands/SLU/shared
create mode 100755 egs/fluent_speech_commands/SLU/transducer/__init__.py
create mode 100755 egs/fluent_speech_commands/SLU/transducer/beam_search.py
create mode 120000 egs/fluent_speech_commands/SLU/transducer/conformer.py
create mode 100755 egs/fluent_speech_commands/SLU/transducer/decode.py
create mode 120000 egs/fluent_speech_commands/SLU/transducer/decoder.py
create mode 120000 egs/fluent_speech_commands/SLU/transducer/encoder_interface.py
create mode 120000 egs/fluent_speech_commands/SLU/transducer/joiner.py
create mode 120000 egs/fluent_speech_commands/SLU/transducer/model.py
create mode 100755 egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py
create mode 120000 egs/fluent_speech_commands/SLU/transducer/subsampling.py
create mode 120000 egs/fluent_speech_commands/SLU/transducer/test_conformer.py
create mode 120000 egs/fluent_speech_commands/SLU/transducer/test_decoder.py
create mode 120000 egs/fluent_speech_commands/SLU/transducer/test_joiner.py
create mode 120000 egs/fluent_speech_commands/SLU/transducer/test_transducer.py
create mode 100755 egs/fluent_speech_commands/SLU/transducer/train.py
create mode 120000 egs/fluent_speech_commands/SLU/transducer/transformer.py
diff --git a/egs/fluent_speech_commands/SLU/README.md b/egs/fluent_speech_commands/SLU/README.md
new file mode 100755
index 000000000..a203a9bfb
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/README.md
@@ -0,0 +1,9 @@
+## Fluent Speech Commands recipe
+
+This is a recipe for the Fluent Speech Commands dataset, a spoken language understanding corpus in which short utterances (such as "turn the lights on in the kitchen") are mapped to action frames (such as {"action": "activate", "object": "lights", "location": "kitchen"}). The training set contains 23,132 utterances and the test set contains 3,793 utterances.
+
+Dataset Paper link:
+
+cd icefall/egs/fluent_speech_commands/
+Training: python transducer/train.py
+Decoding: python transducer/decode.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/local/compile_hlg.py b/egs/fluent_speech_commands/SLU/local/compile_hlg.py
new file mode 100755
index 000000000..a7df8f966
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/local/compile_hlg.py
@@ -0,0 +1,136 @@
+#!/usr/bin/env python3
+
+"""
+This script takes as input lang_dir and generates HLG from
+
+ - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
+ - L, the lexicon, built from lang_dir/L_disambig.pt
+
+ Caution: We use a lexicon that contains disambiguation symbols
+
+ - G, the LM, built from lang_dir/G.fst.txt
+
+The generated HLG is saved in $lang_dir/HLG.pt
+"""
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import torch
+
+from icefall.lexicon import Lexicon
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Input and output directory.
+ """,
+ )
+
+ return parser.parse_args()
+
+
+def compile_HLG(lang_dir: str) -> k2.Fsa:
+ """
+ Args:
+ lang_dir:
+ The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+
+ Return:
+ An FSA representing HLG.
+ """
+ lexicon = Lexicon(lang_dir)
+ max_token_id = max(lexicon.tokens)
+ logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
+ H = k2.ctc_topo(max_token_id)
+ L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+
+ logging.info("Loading G.fst.txt")
+ with open(lang_dir / "G.fst.txt") as f:
+ G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+
+ first_token_disambig_id = lexicon.token_table["#0"]
+ first_word_disambig_id = lexicon.word_table["#0"]
+
+ L = k2.arc_sort(L)
+ G = k2.arc_sort(G)
+
+ logging.info("Intersecting L and G")
+ LG = k2.compose(L, G)
+ logging.info(f"LG shape: {LG.shape}")
+
+ logging.info("Connecting LG")
+ LG = k2.connect(LG)
+ logging.info(f"LG shape after k2.connect: {LG.shape}")
+
+ logging.info(type(LG.aux_labels))
+ logging.info("Determinizing LG")
+
+ LG = k2.determinize(LG)
+ logging.info(type(LG.aux_labels))
+
+ logging.info("Connecting LG after k2.determinize")
+ LG = k2.connect(LG)
+
+ logging.info("Removing disambiguation symbols on LG")
+
+ # LG.labels[LG.labels >= first_token_disambig_id] = 0
+ # see https://github.com/k2-fsa/k2/pull/1140
+ labels = LG.labels
+ labels[labels >= first_token_disambig_id] = 0
+ LG.labels = labels
+
+ assert isinstance(LG.aux_labels, k2.RaggedTensor)
+ LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
+
+ LG = k2.remove_epsilon(LG)
+ logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
+
+ LG = k2.connect(LG)
+ LG.aux_labels = LG.aux_labels.remove_values_eq(0)
+
+ logging.info("Arc sorting LG")
+ LG = k2.arc_sort(LG)
+
+ logging.info("Composing H and LG")
+ # CAUTION: The name of the inner_labels is fixed
+ # to `tokens`. If you want to change it, please
+ # also change other places in icefall that are using
+ # it.
+ HLG = k2.compose(H, LG, inner_labels="tokens")
+
+ logging.info("Connecting LG")
+ HLG = k2.connect(HLG)
+
+ logging.info("Arc sorting LG")
+ HLG = k2.arc_sort(HLG)
+ logging.info(f"HLG.shape: {HLG.shape}")
+
+ return HLG
+
+
+def main():
+ args = get_args()
+ lang_dir = Path(args.lang_dir)
+
+ if (lang_dir / "HLG.pt").is_file():
+ logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
+ return
+
+ logging.info(f"Processing {lang_dir}")
+
+ HLG = compile_HLG(lang_dir)
+ logging.info(f"Saving HLG.pt to {lang_dir}")
+ torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
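+
+ # Usage sketch (the directory name is recipe-specific and illustrative; it
+ # must contain lexicon.txt, L_disambig.pt and G.fst.txt as described above):
+ #   python local/compile_hlg.py --lang-dir <lang_dir>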
diff --git a/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py b/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py
new file mode 100755
index 000000000..a51b7b47b
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python3
+
+"""
+This file computes fbank features of the Fluent Speech Commands dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or it wastes a
+# lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_slu(manifest_dir, fbanks_dir):
+ src_dir = Path(manifest_dir)
+ output_dir = Path(fbanks_dir)
+
+ # This dataset is rather small, so we use only one job
+ num_jobs = min(1, os.cpu_count())
+ num_mel_bins = 23
+
+ dataset_parts = (
+ "train",
+ "valid",
+ "test",
+ )
+ prefix = "slu"
+ suffix = "jsonl.gz"
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ prefix=prefix,
+ suffix=suffix,
+ )
+ assert manifests is not None
+
+ assert len(manifests) == len(dataset_parts), (
+ len(manifests),
+ len(dataset_parts),
+ list(manifests.keys()),
+ dataset_parts,
+ )
+
+ extractor = Fbank(FbankConfig(sampling_rate=16000, num_mel_bins=num_mel_bins))
+
+ with get_executor() as ex: # Initialize the executor only once.
+ for partition, m in manifests.items():
+ cuts_file = output_dir / f"{prefix}_cuts_{partition}.{suffix}"
+ if cuts_file.is_file():
+ logging.info(f"{partition} already exists - skipping.")
+ continue
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
+ )
+ if "train" in partition:
+ cut_set = (
+ cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+ )
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+                # This dataset is small, so we stick to a single job
+                # even when an executor is specified.
+                num_jobs=num_jobs if ex is None else 1,
+ executor=ex,
+ storage_type=LilcomChunkyWriter,
+ )
+ cut_set.to_file(cuts_file)
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("manifest_dir")
+parser.add_argument("fbanks_dir")
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ args = parser.parse_args()
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ compute_fbank_slu(args.manifest_dir, args.fbanks_dir)
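+
+# A minimal sketch for sanity-checking the generated cuts (paths assume the
+# defaults used by prepare.sh, i.e. fbanks_dir=data/fbanks):
+#
+#   cuts = CutSet.from_file("data/fbanks/slu_cuts_train.jsonl.gz")
+#   first = next(iter(cuts))
+#   print(first.duration, first.num_frames, first.features.num_features)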
diff --git a/egs/fluent_speech_commands/SLU/local/generate_lexicon.py b/egs/fluent_speech_commands/SLU/local/generate_lexicon.py
new file mode 100755
index 000000000..6263e062f
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/local/generate_lexicon.py
@@ -0,0 +1,59 @@
+import argparse
+
+import pandas
+from tqdm import tqdm
+
+
+def generate_lexicon(corpus_dir, lm_dir):
+ data = pandas.read_csv(
+ str(corpus_dir) + "/data/train_data.csv", index_col=0, header=0
+ )
+ vocab_transcript = set()
+ vocab_frames = set()
+ transcripts = data["transcription"].tolist()
+ frames = list(
+ i
+ for i in zip(
+ data["action"].tolist(), data["object"].tolist(), data["location"].tolist()
+ )
+ )
+
+ for transcript in tqdm(transcripts):
+ for word in transcript.split():
+ vocab_transcript.add(word)
+
+ for frame in tqdm(frames):
+ for word in frame:
+ vocab_frames.add("_".join(word.split()))
+
+ with open(lm_dir + "/words_transcript.txt", "w") as lexicon_transcript_file:
+ lexicon_transcript_file.write(" 1" + "\n")
+ lexicon_transcript_file.write(" 2" + "\n")
+ lexicon_transcript_file.write(" 0" + "\n")
+ id = 3
+ for vocab in vocab_transcript:
+ lexicon_transcript_file.write(vocab + " " + str(id) + "\n")
+ id += 1
+
+ with open(lm_dir + "/words_frames.txt", "w") as lexicon_frames_file:
+ lexicon_frames_file.write(" 1" + "\n")
+ lexicon_frames_file.write(" 2" + "\n")
+ lexicon_frames_file.write(" 0" + "\n")
+ id = 3
+ for vocab in vocab_frames:
+ lexicon_frames_file.write(vocab + " " + str(id) + "\n")
+ id += 1
+
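+# Illustrative example (hypothetical rows, not actual dataset contents):
+# a transcript "turn on the lights" contributes the words "turn", "on",
+# "the" and "lights" to vocab_transcript, while a frame value such as
+# "change language" is written as the single symbol "change_language"
+# in vocab_frames (multi-word values are joined with underscores).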
+
+parser = argparse.ArgumentParser()
+parser.add_argument("corpus_dir")
+parser.add_argument("lm_dir")
+
+
+def main():
+ args = parser.parse_args()
+
+ generate_lexicon(args.corpus_dir, args.lm_dir)
+
+
+main()
diff --git a/egs/fluent_speech_commands/SLU/local/prepare_lang.py b/egs/fluent_speech_commands/SLU/local/prepare_lang.py
new file mode 100755
index 000000000..2a71dcf81
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/local/prepare_lang.py
@@ -0,0 +1,371 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+
+"""
+This script takes as input the lexicon files "words_frames.txt" and
+"words_transcript.txt" under the given lm_dir, each consisting of words
+and tokens, and does the following:
+
+1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
+
+2. Generate tokens.txt, the token table mapping a token to a unique integer.
+
+3. Generate words.txt, the word table mapping a word to a unique integer.
+
+4. Generate L.pt, in k2 format. It can be loaded by
+
+ d = torch.load("L.pt")
+ lexicon = k2.Fsa.from_dict(d)
+
+5. Generate L_disambig.pt, in k2 format.
+"""
+import argparse
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+import k2
+import torch
+
+from icefall.lexicon import read_lexicon, write_lexicon
+
+Lexicon = List[Tuple[str, List[str]]]
+
+
+def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
+ """Write a symbol to ID mapping to a file.
+
+ Note:
+ No need to implement `read_mapping` as it can be done
+ through :func:`k2.SymbolTable.from_file`.
+
+ Args:
+ filename:
+ Filename to save the mapping.
+ sym2id:
+ A dict mapping symbols to IDs.
+ Returns:
+ Return None.
+ """
+ with open(filename, "w", encoding="utf-8") as f:
+ for sym, i in sym2id.items():
+ f.write(f"{sym} {i}\n")
+
+
+def get_tokens(lexicon: Lexicon) -> List[str]:
+ """Get tokens from a lexicon.
+
+ Args:
+ lexicon:
+ It is the return value of :func:`read_lexicon`.
+ Returns:
+ Return a list of unique tokens.
+ """
+ ans = set()
+ for _, tokens in lexicon:
+ ans.update(tokens)
+ sorted_ans = sorted(list(ans))
+ return sorted_ans
+
+
+def get_words(lexicon: Lexicon) -> List[str]:
+ """Get words from a lexicon.
+
+ Args:
+ lexicon:
+ It is the return value of :func:`read_lexicon`.
+ Returns:
+ Return a list of unique words.
+ """
+ ans = set()
+ for word, _ in lexicon:
+ ans.add(word)
+ sorted_ans = sorted(list(ans))
+ return sorted_ans
+
+
+def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
+ """It adds pseudo-token disambiguation symbols #1, #2 and so on
+ at the ends of tokens to ensure that all pronunciations are different,
+ and that none is a prefix of another.
+
+ See also add_lex_disambig.pl from kaldi.
+
+ Args:
+ lexicon:
+ It is returned by :func:`read_lexicon`.
+ Returns:
+ Return a tuple with two elements:
+
+ - The output lexicon with disambiguation symbols
+ - The ID of the max disambiguation symbol that appears
+ in the lexicon
+ """
+
+ # (1) Work out the count of each token-sequence in the
+ # lexicon.
+ count = defaultdict(int)
+ for _, tokens in lexicon:
+ count[" ".join(tokens)] += 1
+
+ # (2) For each left sub-sequence of each token-sequence, note down
+ # that it exists (for identifying prefixes of longer strings).
+ issubseq = defaultdict(int)
+ for _, tokens in lexicon:
+ tokens = tokens.copy()
+ tokens.pop()
+ while tokens:
+ issubseq[" ".join(tokens)] = 1
+ tokens.pop()
+
+ # (3) For each entry in the lexicon:
+ # if the token sequence is unique and is not a
+ # prefix of another word, no disambig symbol.
+ # Else output #1, or #2, #3, ... if the same token-seq
+ # has already been assigned a disambig symbol.
+ ans = []
+
+ # We start with #1 since #0 has its own purpose
+ first_allowed_disambig = 1
+ max_disambig = first_allowed_disambig - 1
+ last_used_disambig_symbol_of = defaultdict(int)
+
+ for word, tokens in lexicon:
+ tokenseq = " ".join(tokens)
+ assert tokenseq != ""
+ if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
+ ans.append((word, tokens))
+ continue
+
+ cur_disambig = last_used_disambig_symbol_of[tokenseq]
+ if cur_disambig == 0:
+ cur_disambig = first_allowed_disambig
+ else:
+ cur_disambig += 1
+
+ if cur_disambig > max_disambig:
+ max_disambig = cur_disambig
+ last_used_disambig_symbol_of[tokenseq] = cur_disambig
+ tokenseq += f" #{cur_disambig}"
+ ans.append((word, tokenseq.split()))
+ return ans, max_disambig
+
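+# Illustrative example (hypothetical lexicon, not from the dataset):
+#
+#   add_disambig_symbols([("foo", ["f", "o"]), ("bar", ["f", "o"])])
+#   -> ([("foo", ["f", "o", "#1"]), ("bar", ["f", "o", "#2"])], 2)
+#
+# Both entries share the token sequence "f o", so each gets its own
+# disambiguation symbol; the returned 2 is the largest disambig ID used.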
+
+def generate_id_map(symbols: List[str]) -> Dict[str, int]:
+ """Generate ID maps, i.e., map a symbol to a unique ID.
+
+ Args:
+ symbols:
+ A list of unique symbols.
+ Returns:
+ A dict containing the mapping between symbols and IDs.
+ """
+ return {sym: i for i, sym in enumerate(symbols)}
+
+
+def add_self_loops(
+ arcs: List[List[Any]], disambig_token: int, disambig_word: int
+) -> List[List[Any]]:
+ """Adds self-loops to states of an FST to propagate disambiguation symbols
+ through it. They are added on each state with non-epsilon output symbols
+ on at least one arc out of the state.
+
+ See also fstaddselfloops.pl from Kaldi. One difference is that
+ Kaldi uses OpenFst style FSTs and it has multiple final states.
+ This function uses k2 style FSTs and it does not need to add self-loops
+ to the final state.
+
+ The input label of a self-loop is `disambig_token`, while the output
+ label is `disambig_word`.
+
+ Args:
+ arcs:
+ A list-of-list. The sublist contains
+ `[src_state, dest_state, label, aux_label, score]`
+ disambig_token:
+ It is the token ID of the symbol `#0`.
+ disambig_word:
+ It is the word ID of the symbol `#0`.
+
+ Return:
+ Return new `arcs` containing self-loops.
+ """
+ states_needs_self_loops = set()
+ for arc in arcs:
+ src, dst, ilabel, olabel, score = arc
+ if olabel != 0:
+ states_needs_self_loops.add(src)
+
+ ans = []
+ for s in states_needs_self_loops:
+ ans.append([s, s, disambig_token, disambig_word, 0])
+
+ return arcs + ans
+
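+# Illustrative example (hypothetical IDs, not from this recipe's tables):
+#
+#   add_self_loops([[0, 1, 5, 7, 0], [1, 2, 3, 0, 0]],
+#                  disambig_token=90, disambig_word=99)
+#   -> [[0, 1, 5, 7, 0], [1, 2, 3, 0, 0], [0, 0, 90, 99, 0]]
+#
+# Only state 0 has an outgoing arc with a non-epsilon output label (7),
+# so only state 0 receives a self-loop.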
+
+def lexicon_to_fst(
+ lexicon: Lexicon,
+ token2id: Dict[str, int],
+ word2id: Dict[str, int],
+ sil_token: str = "!SIL",
+ sil_prob: float = 0.5,
+ need_self_loops: bool = False,
+) -> k2.Fsa:
+ """Convert a lexicon to an FST (in k2 format) with optional silence at
+ the beginning and end of each word.
+
+ Args:
+ lexicon:
+ The input lexicon. See also :func:`read_lexicon`
+ token2id:
+ A dict mapping tokens to IDs.
+ word2id:
+ A dict mapping words to IDs.
+ sil_token:
+ The silence token.
+ sil_prob:
+ The probability for adding a silence at the beginning and end
+ of the word.
+ need_self_loops:
+ If True, add self-loop to states with non-epsilon output symbols
+ on at least one arc out of the state. The input label for this
+ self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+ Returns:
+ Return an instance of `k2.Fsa` representing the given lexicon.
+ """
+ assert sil_prob > 0.0 and sil_prob < 1.0
+ # CAUTION: we use score, i.e, negative cost.
+ sil_score = math.log(sil_prob)
+ no_sil_score = math.log(1.0 - sil_prob)
+
+ start_state = 0
+ loop_state = 1 # words enter and leave from here
+ sil_state = 2 # words terminate here when followed by silence; this state
+ # has a silence transition to loop_state.
+ next_state = 3 # the next un-allocated state, will be incremented as we go.
+ arcs = []
+
+    # assert token2id["<eps>"] == 0
+    # assert word2id["<eps>"] == 0
+
+ eps = 0
+ sil_token = word2id[sil_token]
+
+ arcs.append([start_state, loop_state, eps, eps, no_sil_score])
+ arcs.append([start_state, sil_state, eps, eps, sil_score])
+ arcs.append([sil_state, loop_state, sil_token, eps, 0])
+
+ for word, tokens in lexicon:
+ assert len(tokens) > 0, f"{word} has no pronunciations"
+ cur_state = loop_state
+
+ word = word2id[word]
+ tokens = [word2id[i] for i in tokens]
+
+ for i in range(len(tokens) - 1):
+ w = word if i == 0 else eps
+ arcs.append([cur_state, next_state, tokens[i], w, 0])
+
+ cur_state = next_state
+ next_state += 1
+
+ # now for the last token of this word
+ # It has two out-going arcs, one to the loop state,
+ # the other one to the sil_state.
+ i = len(tokens) - 1
+ w = word if i == 0 else eps
+ arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
+ arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
+
+ if need_self_loops:
+ disambig_token = word2id["#0"]
+ disambig_word = word2id["#0"]
+ arcs = add_self_loops(
+ arcs,
+ disambig_token=disambig_token,
+ disambig_word=disambig_word,
+ )
+
+ final_state = next_state
+ arcs.append([loop_state, final_state, -1, -1, 0])
+ arcs.append([final_state])
+
+ arcs = sorted(arcs, key=lambda arc: arc[0])
+ arcs = [[str(i) for i in arc] for arc in arcs]
+ arcs = [" ".join(arc) for arc in arcs]
+ arcs = "\n".join(arcs)
+
+ fsa = k2.Fsa.from_str(arcs, acceptor=False)
+ return fsa
+
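+# Illustrative example (hypothetical symbol table; in this recipe the token
+# and word inventories coincide, so word IDs are also used as token IDs):
+#
+#   word2id = {"<eps>": 0, "!SIL": 1, "activate": 2}
+#   L = lexicon_to_fst([("activate", ["activate"])],
+#                      token2id=word2id, word2id=word2id)
+#
+# The resulting FSA has arcs labelled 2/2 from the loop state back to itself
+# and to the silence state, with the optional !SIL path weighted by sil_prob.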
+
+parser = argparse.ArgumentParser()
+parser.add_argument("lm_dir")
+
+
+def main():
+ args = parser.parse_args()
+
+ out_dir = Path(args.lm_dir)
+ lexicon_filenames = [out_dir / "words_frames.txt", out_dir / "words_transcript.txt"]
+ names = ["frames", "transcript"]
+ sil_token = "!SIL"
+ sil_prob = 0.5
+
+ for name, lexicon_filename in zip(names, lexicon_filenames):
+ lexicon = read_lexicon(lexicon_filename)
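+        # In this recipe every word is its own token, so the token and
+        # word inventories coincide (lexicon entries map word -> [word]).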
+ tokens = get_words(lexicon)
+ words = get_words(lexicon)
+ new_lexicon = []
+ for lexicon_item in lexicon:
+ new_lexicon.append((lexicon_item[0], [lexicon_item[0]]))
+ lexicon = new_lexicon
+
+ lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+
+ for i in range(max_disambig + 1):
+ disambig = f"#{i}"
+ assert disambig not in tokens
+ tokens.append(f"#{i}")
+
+ tokens = [""] + tokens
+ words = ["eps"] + words + ["#0", "!SIL"]
+
+ token2id = generate_id_map(tokens)
+ word2id = generate_id_map(words)
+
+ write_mapping(out_dir / ("tokens_" + name + ".txt"), token2id)
+ write_mapping(out_dir / ("words_" + name + ".txt"), word2id)
+ write_lexicon(out_dir / ("lexicon_disambig_" + name + ".txt"), lexicon_disambig)
+
+ L = lexicon_to_fst(
+ lexicon,
+ token2id=word2id,
+ word2id=word2id,
+ sil_token=sil_token,
+ sil_prob=sil_prob,
+ )
+
+ L_disambig = lexicon_to_fst(
+ lexicon_disambig,
+ token2id=word2id,
+ word2id=word2id,
+ sil_token=sil_token,
+ sil_prob=sil_prob,
+ need_self_loops=True,
+ )
+ torch.save(L.as_dict(), out_dir / ("L_" + name + ".pt"))
+ torch.save(L_disambig.as_dict(), out_dir / ("L_disambig_" + name + ".pt"))
+
+ if False:
+ # Just for debugging, will remove it
+ L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
+ L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
+ L_disambig.labels_sym = L.labels_sym
+ L_disambig.aux_labels_sym = L.aux_labels_sym
+ L.draw(out_dir / "L.png", title="L")
+ L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
+
+
+main()
diff --git a/egs/fluent_speech_commands/SLU/prepare.sh b/egs/fluent_speech_commands/SLU/prepare.sh
new file mode 100755
index 000000000..3ff339d91
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/prepare.sh
@@ -0,0 +1,103 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+stage=1
+stop_stage=5
+
+data_dir=path/to/fluent/speech/commands
+target_root_dir=data/
+
+lang_dir=${target_root_dir}/lang_phone
+lm_dir=${target_root_dir}/lm
+manifest_dir=${target_root_dir}/manifests
+fbanks_dir=${target_root_dir}/fbanks
+
+. shared/parse_options.sh || exit 1
+
+mkdir -p $lang_dir
+mkdir -p $lm_dir
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "data_dir: $data_dir"
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare slu manifest"
+ mkdir -p $manifest_dir
+ lhotse prepare slu $data_dir $manifest_dir
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Compute fbank for SLU"
+ mkdir -p $fbanks_dir
+ python ./local/compute_fbank_slu.py $manifest_dir $fbanks_dir
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Prepare lang"
+ # NOTE: " SIL" is added for implementation convenience
+  # as the graph compiler code requires that there is an OOV word
+ # in the lexicon.
+ python ./local/generate_lexicon.py $data_dir $lm_dir
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Train LM"
+ # We use a unigram G
+ ./shared/make_kn_lm.py \
+ -ngram-order 1 \
+ -text $lm_dir/words_transcript.txt \
+ -lm $lm_dir/G_transcript.arpa
+
+ ./shared/make_kn_lm.py \
+ -ngram-order 1 \
+ -text $lm_dir/words_frames.txt \
+ -lm $lm_dir/G_frames.arpa
+
+ python ./local/prepare_lang.py $lm_dir
+
+ if [ ! -f $lm_dir/G_transcript.fst.txt ]; then
+ python -m kaldilm \
+ --read-symbol-table="$lm_dir/words_transcript.txt" \
+ $lm_dir/G_transcript.arpa > $lm_dir/G_transcript.fst.txt
+ fi
+
+ if [ ! -f $lm_dir/G_frames.fst.txt ]; then
+ python -m kaldilm \
+ --read-symbol-table="$lm_dir/words_frames.txt" \
+ $lm_dir/G_frames.arpa > $lm_dir/G_frames.fst.txt
+ fi
+
+ mkdir -p $lm_dir/frames
+ mkdir -p $lm_dir/transcript
+
+ chmod -R +777 .
+
+ for i in G_frames.arpa G_frames.fst.txt L_disambig_frames.pt L_frames.pt lexicon_disambig_frames.txt tokens_frames.txt words_frames.txt;
+ do
+ j=${i//"_frames"/}
+ mv "$lm_dir/$i" $lm_dir/frames/$j
+ done
+
+ for i in G_transcript.arpa G_transcript.fst.txt L_disambig_transcript.pt L_transcript.pt lexicon_disambig_transcript.txt tokens_transcript.txt words_transcript.txt;
+ do
+ j=${i//"_transcript"/}
+ mv "$lm_dir/$i" $lm_dir/transcript/$j
+ done
+fi
+
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Compile HLG"
+ ./local/compile_hlg.py --lang-dir $lm_dir/frames
+ ./local/compile_hlg.py --lang-dir $lm_dir/transcript
+
+fi
diff --git a/egs/fluent_speech_commands/SLU/shared b/egs/fluent_speech_commands/SLU/shared
new file mode 120000
index 000000000..32a374a7f
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/shared
@@ -0,0 +1 @@
+../../icefall/shared/
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/__init__.py b/egs/fluent_speech_commands/SLU/transducer/__init__.py
new file mode 100755
index 000000000..e69de29bb
diff --git a/egs/fluent_speech_commands/SLU/transducer/beam_search.py b/egs/fluent_speech_commands/SLU/transducer/beam_search.py
new file mode 100755
index 000000000..a16aa0123
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/beam_search.py
@@ -0,0 +1,71 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import torch
+from transducer.model import Transducer
+
+
+def greedy_search(
+ model: Transducer, encoder_out: torch.Tensor, id2word: dict
+) -> List[str]:
+ """
+ Args:
+ model:
+ An instance of `Transducer`.
+ encoder_out:
+        A tensor of shape (N, T, C) from the encoder. Only N == 1 is
+        supported for now.
+      id2word:
+        A dict mapping word IDs to words.
+    Returns:
+      Return the decoded result as a list of words.
+ """
+ assert encoder_out.ndim == 3
+
+ # support only batch_size == 1 for now
+ assert encoder_out.size(0) == 1, encoder_out.size(0)
+ blank_id = model.decoder.blank_id
+ device = model.device
+
+ sos = torch.tensor([blank_id], device=device).reshape(1, 1)
+ decoder_out, (h, c) = model.decoder(sos)
+ T = encoder_out.size(1)
+ t = 0
+ hyp = []
+ max_u = 1000 # terminate after this number of steps
+ u = 0
+
+ while t < T and u < max_u:
+ # fmt: off
+ current_encoder_out = encoder_out[:, t:t+1, :]
+ # fmt: on
+ logits = model.joiner(current_encoder_out, decoder_out)
+
+ log_prob = logits.log_softmax(dim=-1)
+ # log_prob is (N, 1, 1)
+ # TODO: Use logits.argmax()
+ y = log_prob.argmax()
+ if y != blank_id:
+ hyp.append(y.item())
+ y = y.reshape(1, 1)
+ decoder_out, (h, c) = model.decoder(y, (h, c))
+ u += 1
+ else:
+ t += 1
+ # id2word = {1: "YES", 2: "NO"}
+
+ hyp = [id2word[i] for i in hyp]
+
+ return hyp
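+
+# A minimal usage sketch (hypothetical shapes; decode.py shows the real call):
+#
+#   encoder_out, _ = model.encoder(x=features, x_lens=feature_lens)
+#   words = greedy_search(model=model,
+#                         encoder_out=encoder_out[0:1],
+#                         id2word=id2word)
+#
+# The loop advances t (time) on blank and u (output position) on non-blank,
+# so at most max_u labels are emitted per utterance.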
diff --git a/egs/fluent_speech_commands/SLU/transducer/conformer.py b/egs/fluent_speech_commands/SLU/transducer/conformer.py
new file mode 120000
index 000000000..8be0dc864
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/conformer.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/decode.py b/egs/fluent_speech_commands/SLU/transducer/decode.py
new file mode 100755
index 000000000..ba2b9aaea
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/decode.py
@@ -0,0 +1,346 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+from pathlib import Path
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+from transducer.beam_search import greedy_search
+from transducer.conformer import Conformer
+from transducer.decoder import Decoder
+from transducer.joiner import Joiner
+from transducer.model import Transducer
+from transducer.slu_datamodule import SluDataModule
+
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.env import get_env_info
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ write_error_stats,
+)
+
+
+def get_id2word(params):
+ id2word = {}
+
+ # 0 is blank
+ id = 1
+ try:
+ with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
+ for line in lexicon_file:
+ if len(line.strip()) > 0:
+ id2word[id] = line.split()[0]
+ id += 1
+    except FileNotFoundError:
+        # If the lexicon is missing we simply return an empty mapping.
+        pass
+
+ return id2word
+
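+# Illustrative example (hypothetical lexicon contents): if
+# lexicon_disambig.txt starts with lines "activate ..." and "lights ...",
+# then id2word == {1: "activate", 2: "lights", ...}; ID 0 is reserved for blank.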
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=6,
+ help="It specifies the checkpoint to use for decoding."
+ "Note: Epoch counts from 0.",
+ )
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=1,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch'. ",
+ )
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="transducer/exp",
+ help="Directory from which to load the checkpoints",
+ )
+ parser.add_argument("--lang-dir", type=str, default="data/lm/frames")
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ params = AttributeDict(
+ {
+ "feature_dim": 23,
+ "lang_dir": Path("data/lm/frames"),
+ # encoder/decoder params
+ "vocab_size": 3, # blank, yes, no
+ "blank_id": 0,
+ "embedding_dim": 32,
+ "hidden_dim": 16,
+ "num_decoder_layers": 4,
+ }
+ )
+
+ vocab_size = 1
+ with open(params.lang_dir / "lexicon_disambig.txt") as lexicon_file:
+ for line in lexicon_file:
+ if (
+ len(line.strip()) > 0
+ ): # and '' not in line and '' not in line and '' not in line:
+ vocab_size += 1
+ params.vocab_size = vocab_size
+
+ return params
+
+
+def decode_one_batch(
+ params: AttributeDict, model: nn.Module, batch: dict, id2word: dict
+) -> List[List[str]]:
+    """Decode one batch and return the result in a list-of-list.
+    Each sub-list contains the decoded words for an utterance in the batch.
+
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+
+ model:
+ The neural model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ (https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py)
+ Returns:
+ Return the decoding result. `len(ans)` == batch size.
+ """
+ device = model.device
+ feature = batch["inputs"]
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+ feature_lens = batch["supervisions"]["num_frames"].to(device)
+
+ encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+
+ hyps = []
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ hyp = greedy_search(model=model, encoder_out=encoder_out_i, id2word=id2word)
+ hyps.append(hyp)
+
+ # hyps = [[word_table[i] for i in ids] for ids in hyps]
+ return hyps
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+) -> List[Tuple[str, List[str], List[str]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+    Returns:
+      Return a list of tuples. Each tuple contains three elements
+      (cut_id, ref_words, hyp_words): the cut ID, the reference
+      transcript and the predicted result.
+    """
+ results = []
+
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ id2word = get_id2word(params)
+
+ results = []
+ for batch_idx, batch in enumerate(dl):
+ texts = [
+ " ".join(a.supervisions[0].custom["frames"])
+ for a in batch["supervisions"]["cut"]
+ ]
+ texts = [
+ " " + a.replace("change language", "change_language") + " "
+ for a in texts
+ ]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps = decode_one_batch(
+ params=params, model=model, batch=batch, id2word=id2word
+ )
+
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results.extend(this_batch)
+
+ num_cuts += len(batch["supervisions"]["text"])
+
+ if batch_idx % 100 == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ exp_dir: Path,
+ test_set_name: str,
+    results: List[Tuple[str, List[str], List[str]]],
+) -> None:
+ """Save results to `exp_dir`.
+ Args:
+ exp_dir:
+        The output directory. This function creates the following files inside
+ this directory:
+
+            - recogs-{test_set_name}.txt
+
+ It contains the reference and hypothesis results, like below::
+
+ ref=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
+ hyp=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
+ ref=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
+ hyp=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
+
+ - errs-{test_set_name}.txt
+
+ It contains the detailed WER.
+ test_set_name:
+ The name of the test set, which will be part of the result filename.
+ results:
+        A list of tuples, each of which contains (cut_id, ref_words, hyp_words).
+ Returns:
+ Return None.
+ """
+ recog_path = exp_dir / f"recogs-{test_set_name}.txt"
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = exp_dir / f"errs-{test_set_name}.txt"
+ with open(errs_filename, "w") as f:
+ write_error_stats(f, f"{test_set_name}", results)
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+
+def get_transducer_model(params: AttributeDict):
+ # encoder = Tdnn(
+ # num_features=params.feature_dim,
+ # output_dim=params.hidden_dim,
+ # )
+ encoder = Conformer(
+ num_features=params.feature_dim,
+ output_dim=params.hidden_dim,
+ )
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ embedding_dim=params.embedding_dim,
+ blank_id=params.blank_id,
+ num_layers=params.num_decoder_layers,
+ hidden_dim=params.hidden_dim,
+ embedding_dropout=0.4,
+ rnn_dropout=0.4,
+ )
+ joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
+ transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
+ return transducer
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ SluDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+ params["env_info"] = get_env_info()
+
+ setup_logger(f"{params.exp_dir}/log/log-decode")
+ logging.info("Decoding started")
+ logging.info(params)
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ model = get_transducer_model(params)
+
+ if params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if start >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.load_state_dict(average_checkpoints(filenames))
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ slu = SluDataModule(args)
+ test_dl = slu.test_dataloaders()
+ results = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ )
+
+ test_set_name = str(args.feature_dir).split("/")[-2]
+ save_results(exp_dir=params.exp_dir, test_set_name=test_set_name, results=results)
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/fluent_speech_commands/SLU/transducer/decoder.py b/egs/fluent_speech_commands/SLU/transducer/decoder.py
new file mode 120000
index 000000000..e99310f91
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/decoder.py
@@ -0,0 +1 @@
+../../../yesno/ASR/transducer/decoder.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/encoder_interface.py b/egs/fluent_speech_commands/SLU/transducer/encoder_interface.py
new file mode 120000
index 000000000..653c5b09a
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/encoder_interface.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/joiner.py b/egs/fluent_speech_commands/SLU/transducer/joiner.py
new file mode 120000
index 000000000..75fa64868
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer/joiner.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/model.py b/egs/fluent_speech_commands/SLU/transducer/model.py
new file mode 120000
index 000000000..10f6ddad1
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer/model.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py b/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py
new file mode 100755
index 000000000..fa715abdd
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py
@@ -0,0 +1,289 @@
+# Copyright 2021 Piotr Żelasko
+# 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import List
+
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
+from lhotse.dataset import (
+ CutConcatenate,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SimpleCutSampler,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from torch.utils.data import DataLoader
+
+from icefall.dataset.datamodule import DataModule
+from icefall.utils import str2bool
+
+
+class SluDataModule(DataModule):
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+ """
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ super().add_arguments(parser)
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--feature-dir",
+ type=Path,
+ default=Path("data/fbanks"),
+ help="Path to directory with train/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=30.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=False,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=10,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ def train_dataloaders(self) -> DataLoader:
+ logging.info("About to get train cuts")
+ cuts_train = self.train_cuts()
+
+ logging.info("About to create train dataset")
+ transforms = []
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+                input_strategy=OnTheFlyFeatures(
+                    # Wrap the config in an Fbank extractor (OnTheFlyFeatures
+                    # expects an extractor) and use the dataset's 16 kHz rate.
+                    Fbank(FbankConfig(sampling_rate=16000, num_mel_bins=23))
+                ),
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=True,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=True,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self) -> DataLoader:
+ logging.info("About to get valid cuts")
+ cuts_valid = self.valid_cuts()
+
+ logging.debug("About to create valid dataset")
+ valid = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23)))
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create valid dataloader")
+ valid_dl = DataLoader(
+ valid,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ persistent_workers=True,
+ )
+ return valid_dl
+
+ def test_dataloaders(self) -> DataLoader:
+ logging.info("About to get test cuts")
+ cuts_test = self.test_cuts()
+
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23)))
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts_test,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ persistent_workers=True,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info("About to get train cuts")
+ cuts_train = load_manifest_lazy(
+ self.args.feature_dir / "slu_cuts_train.jsonl.gz"
+ )
+ return cuts_train
+
+ @lru_cache()
+    def valid_cuts(self) -> CutSet:
+ logging.info("About to get valid cuts")
+ cuts_valid = load_manifest_lazy(
+ self.args.feature_dir / "slu_cuts_valid.jsonl.gz"
+ )
+ return cuts_valid
+
+ @lru_cache()
+    def test_cuts(self) -> CutSet:
+ logging.info("About to get test cuts")
+ cuts_test = load_manifest_lazy(self.args.feature_dir / "slu_cuts_test.jsonl.gz")
+ return cuts_test
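+
+# A minimal usage sketch (see train.py / decode.py for the real entry points):
+#
+#   parser = argparse.ArgumentParser()
+#   SluDataModule.add_arguments(parser)
+#   args = parser.parse_args([])
+#   slu = SluDataModule(args)
+#   train_dl = slu.train_dataloaders()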
diff --git a/egs/fluent_speech_commands/SLU/transducer/subsampling.py b/egs/fluent_speech_commands/SLU/transducer/subsampling.py
new file mode 120000
index 000000000..fd7ca8b30
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/subsampling.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/test_conformer.py b/egs/fluent_speech_commands/SLU/transducer/test_conformer.py
new file mode 120000
index 000000000..3060dd70c
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/test_conformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer/test_conformer.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/test_decoder.py b/egs/fluent_speech_commands/SLU/transducer/test_decoder.py
new file mode 120000
index 000000000..d1bc718ce
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/test_decoder.py
@@ -0,0 +1 @@
+../../../yesno/ASR/transducer/test_decoder.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/test_joiner.py b/egs/fluent_speech_commands/SLU/transducer/test_joiner.py
new file mode 120000
index 000000000..248222a8a
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/test_joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer/test_joiner.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/test_transducer.py b/egs/fluent_speech_commands/SLU/transducer/test_transducer.py
new file mode 120000
index 000000000..df104bad7
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/test_transducer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer/test_transducer.py
\ No newline at end of file
diff --git a/egs/fluent_speech_commands/SLU/transducer/train.py b/egs/fluent_speech_commands/SLU/transducer/train.py
new file mode 100755
index 000000000..a59c0b754
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/train.py
@@ -0,0 +1,625 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+from pathlib import Path
+from shutil import copyfile
+from typing import List, Optional, Tuple
+
+import k2
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+import torch.optim as optim
+from lhotse.utils import fix_random_seed
+from slu_datamodule import SluDataModule
+from torch import Tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+from transducer.conformer import Conformer
+
+# from torch.utils.tensorboard import SummaryWriter
+from transducer.decoder import Decoder
+from transducer.joiner import Joiner
+from transducer.model import Transducer
+
+from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+
+def get_word2id(params):
+ word2id = {}
+
+ # 0 is blank
+ id = 1
+ with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
+ for line in lexicon_file:
+ if len(line.strip()) > 0:
+ word2id[line.split()[0]] = id
+ id += 1
+
+ return word2id
+
+
+def get_labels(texts: List[str], word2id) -> k2.RaggedTensor:
+ """
+ Args:
+ texts:
+ A list of transcripts.
+ Returns:
+ Return a ragged tensor containing the corresponding word ID.
+ """
+ # blank is 0
+ word_ids = []
+ for t in texts:
+ words = t.split()
+ ids = [word2id[w] for w in words]
+ word_ids.append(ids)
+
+ return k2.RaggedTensor(word_ids)
+
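+# Illustrative example (hypothetical word2id, not the real FSC vocabulary):
+#
+#   get_labels(["activate lights", "lights"], {"activate": 1, "lights": 2})
+#   -> a k2.RaggedTensor with rows [1, 2] and [2]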
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=7,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=0,
+ help="""Resume training from from this epoch.
+ If it is positive, it will load checkpoint from
+ tdnn/exp/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="transducer/exp",
+ help="Directory to save results",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument("--lang-dir", type=str, default="data/lm/frames")
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+    All training-related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - lr: It specifies the initial learning rate
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - weight_decay: The weight_decay for the optimizer.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - start_epoch: If it is not zero, load checkpoint `start_epoch-1`
+ and continue training from that checkpoint.
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+                       contains the number of batches trained so far across
+                       epochs.
+
+    - log_interval: Print training loss if `batch_idx % log_interval` is 0
+
+    - valid_interval: Run validation if `batch_idx % valid_interval` is 0
+
+    - reset_interval: Reset statistics if `batch_idx % reset_interval` is 0
+
+
+ """
+ params = AttributeDict(
+ {
+ "lr": 1e-4,
+ "feature_dim": 23,
+ "weight_decay": 1e-6,
+ "start_epoch": 0,
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 100,
+ "reset_interval": 20,
+ "valid_interval": 3000,
+ "exp_dir": Path("transducer/exp"),
+ "lang_dir": Path("data/lm/frames"),
+ # encoder/decoder params
+ "vocab_size": 3, # blank, yes, no
+ "blank_id": 0,
+ "embedding_dim": 32,
+ "hidden_dim": 16,
+ "num_decoder_layers": 4,
+ }
+ )
+
+ vocab_size = 1
+ with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
+ for line in lexicon_file:
+ if (
+ len(line.strip()) > 0
+ ): # and '' not in line and '' not in line and '' not in line:
+ vocab_size += 1
+ params.vocab_size = vocab_size
+
+ return params
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+) -> None:
+ """Load checkpoint from file.
+
+ If params.start_epoch is positive, it will load the checkpoint from
+ `params.start_epoch - 1`. Otherwise, this function does nothing.
+
+ Apart from loading state dict for `model`, `optimizer` and `scheduler`,
+ it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The learning rate scheduler we are using.
+ Returns:
+ Return None.
+ """
+ if params.start_epoch <= 0:
+ return
+
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ scheduler: torch.optim.lr_scheduler._LRScheduler,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict, model: nn.Module, batch: dict, is_training: bool, word2ids
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute RNN-T loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+        The model for training. It is an instance of Transducer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ feature_lens = batch["supervisions"]["num_frames"].to(device)
+
+ texts = [
+ " ".join(a.supervisions[0].custom["frames"])
+ for a in batch["supervisions"]["cut"]
+ ]
+ texts = [
+ " " + a.replace("change language", "change_language") + " "
+ for a in texts
+ ]
+ labels = get_labels(texts, word2ids).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ loss = model(x=feature, x_lens=feature_lens, y=labels)
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ info["frames"] = feature.size(0)
+ info["loss"] = loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: nn.Module,
+ valid_dl: torch.utils.data.DataLoader,
+ word2ids,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process. The validation loss
+ is saved in `params.valid_loss`.
+ """
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ batch=batch,
+ is_training=False,
+ word2ids=word2ids,
+ )
+ assert loss.requires_grad is False
+
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ word2ids,
+ tb_writer: None,
+ world_size: int = 1,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ loss, loss_info = compute_loss(
+ params=params, model=model, batch=batch, is_training=True, word2ids=word2ids
+ )
+        # Summary stats: keep an exponentially decaying average so that
+        # tot_loss mostly reflects the most recent `reset_interval` batches.
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ optimizer.zero_grad()
+ loss.backward()
+ clip_grad_norm_(model.parameters(), 5.0, 2.0)
+ optimizer.step()
+
+ if batch_idx % params.log_interval == 0:
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}"
+ )
+ if batch_idx % params.log_interval == 0:
+
+ if tb_writer is not None:
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+
+ if batch_idx > 0 and batch_idx % params.valid_interval == 0:
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ word2ids=word2ids,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer,
+ "train/valid_",
+ params.batch_idx_train,
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def get_transducer_model(params: AttributeDict):
+ encoder = Conformer(
+ num_features=params.feature_dim,
+ output_dim=params.hidden_dim,
+ )
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ embedding_dim=params.embedding_dim,
+ blank_id=params.blank_id,
+ num_layers=params.num_decoder_layers,
+ hidden_dim=params.hidden_dim,
+ embedding_dropout=0.4,
+ rnn_dropout=0.4,
+ )
+ joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
+ transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
+
+ return transducer
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+
+ params.update(vars(args))
+ params["env_info"] = get_env_info()
+
+ word2ids = get_word2id(params)
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+ logging.info(params)
+
+ # if args.tensorboard and rank == 0:
+ # tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ # else:
+ # tb_writer = None
+ tb_writer = None
+
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ else:
+ device = torch.device("cpu")
+ logging.info(f"device: {device}")
+
+ model = get_transducer_model(params)
+
+ checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+ model.to(device)
+ if world_size > 1:
+ model = DDP(model, device_ids=[rank])
+
+ model.device = device
+
+ optimizer = optim.Adam(
+ model.parameters(),
+ lr=params.lr,
+ weight_decay=params.weight_decay,
+ )
+
+ if checkpoints:
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ slu = SluDataModule(args)
+ train_dl = slu.train_dataloaders()
+
+    # We use the test data for validation here.
+ valid_dl = slu.test_dataloaders()
+
+ for epoch in range(params.start_epoch, params.num_epochs):
+ fix_random_seed(params.seed + epoch)
+ train_dl.sampler.set_epoch(epoch)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ optimizer=optimizer,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ word2ids=word2ids,
+ )
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ optimizer=optimizer,
+ scheduler=None,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def main():
+ parser = get_parser()
+ SluDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/fluent_speech_commands/SLU/transducer/transformer.py b/egs/fluent_speech_commands/SLU/transducer/transformer.py
new file mode 120000
index 000000000..214afed39
--- /dev/null
+++ b/egs/fluent_speech_commands/SLU/transducer/transformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/transformer.py
\ No newline at end of file
diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py
index 231aca7f1..42ed44fdd 100755
--- a/icefall/shared/make_kn_lm.py
+++ b/icefall/shared/make_kn_lm.py
@@ -33,7 +33,7 @@ parser.add_argument(
"-ngram-order",
type=int,
default=4,
- choices=[2, 3, 4, 5, 6, 7],
+ choices=[1, 2, 3, 4, 5, 6, 7],
help="Order of n-gram",
)
parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
@@ -105,7 +105,7 @@ class NgramCounts:
# do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an
# array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict.
def __init__(self, ngram_order, bos_symbol="<s>", eos_symbol="</s>"):
- assert ngram_order >= 2
+ assert ngram_order >= 1
self.ngram_order = ngram_order
self.bos_symbol = bos_symbol
@@ -169,7 +169,10 @@ class NgramCounts:
with open(filename, encoding=default_encoding) as fp:
for line in fp:
line = line.strip(strip_chars)
- self.add_raw_counts_from_line(line)
+ if self.ngram_order == 1:
+ self.add_raw_counts_from_line(line.split()[0])
+ else:
+ self.add_raw_counts_from_line(line)
lines_processed += 1
if lines_processed == 0 or args.verbose > 0:
print(
From b9e6327adfe0def4989ad508d0df8f6347da2c0f Mon Sep 17 00:00:00 2001
From: Teo Wen Shen <36886809+teowenshen@users.noreply.github.com>
Date: Sat, 3 Feb 2024 07:25:27 +0900
Subject: [PATCH 046/123] Fixing torch.ctc err (#1485)
* fixing torch.ctc err
* Move targets & lengths to CPU
---
egs/librispeech/ASR/zipformer/model.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py
index f2f86af47..73009d35c 100644
--- a/egs/librispeech/ASR/zipformer/model.py
+++ b/egs/librispeech/ASR/zipformer/model.py
@@ -164,9 +164,9 @@ class AsrModel(nn.Module):
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
- targets=targets,
- input_lengths=encoder_out_lens,
- target_lengths=target_lengths,
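+ # Keeping targets and lengths on the CPU works around the torch.ctc_loss
+ # error mentioned in the commit message above.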
+ targets=targets.cpu(),
+ input_lengths=encoder_out_lens.cpu(),
+ target_lengths=target_lengths.cpu(),
reduction="sum",
)
return ctc_loss
From a813186f6463b4d3f6d73460dd1a57b2e975b803 Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Mon, 5 Feb 2024 12:47:52 +0800
Subject: [PATCH 047/123] minor fix for docstr and default param. (#1490)
* Update train.py and README.md
---
README.md | 3 +++
egs/aishell/ASR/whisper/train.py | 6 +++---
2 files changed, 6 insertions(+), 3 deletions(-)
diff --git a/README.md b/README.md
index cc817702b..770066166 100644
--- a/README.md
+++ b/README.md
@@ -74,6 +74,9 @@ The [LibriSpeech][librispeech] recipe supports the most comprehensive set of mod
- LSTM-based Predictor
- [Stateless Predictor](https://research.google/pubs/rnn-transducer-with-stateless-prediction-network/)
+#### Whisper
+ - [OpenAI Whisper](https://arxiv.org/abs/2212.04356) (We support fine-tuning on AiShell-1.)
+
If you are willing to contribute to icefall, please refer to [contributing](https://icefall.readthedocs.io/en/latest/contributing/index.html) for more details.
We would like to highlight the performance of some of the recipes here.
diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py
index d16793eb2..073b23713 100755
--- a/egs/aishell/ASR/whisper/train.py
+++ b/egs/aishell/ASR/whisper/train.py
@@ -19,7 +19,7 @@
Usage:
#fine-tuning with deepspeed zero stage 1
-torchrun --nproc-per-node 8 ./whisper/train.py \
+torchrun --nproc_per_node 8 ./whisper/train.py \
--max-duration 200 \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
@@ -28,7 +28,7 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
--deepspeed_config ./whisper/ds_config_zero1.json
# fine-tuning with ddp
-torchrun --nproc-per-node 8 ./whisper/train.py \
+torchrun --nproc_per_node 8 ./whisper/train.py \
--max-duration 200 \
--exp-dir whisper/exp_medium \
--manifest-dir data/fbank_whisper \
@@ -136,7 +136,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
- default="pruned_transducer_stateless7/exp",
+ default="whisper/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
From 777074046d6cb0f9b7ed7a98115299c04b8bcdf1 Mon Sep 17 00:00:00 2001
From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com>
Date: Tue, 6 Feb 2024 18:25:43 +0800
Subject: [PATCH 048/123] Fine-tune recipe for Zipformer (#1484)
1. support finetune zipformer
2. update the usage; set a very large batch count
---
.../local/compute_fbank_gigaspeech_splits.py | 18 +-
.../ASR/tdnn_lstm_ctc/asr_datamodule.py | 17 +-
.../ASR/zipformer/decode_gigaspeech.py | 1114 ++++++++++++
egs/librispeech/ASR/zipformer/finetune.py | 1521 +++++++++++++++++
4 files changed, 2664 insertions(+), 6 deletions(-)
create mode 100755 egs/librispeech/ASR/zipformer/decode_gigaspeech.py
create mode 100755 egs/librispeech/ASR/zipformer/finetune.py
diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
index 1c71be0f9..176eb8a84 100755
--- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
+++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
@@ -51,6 +51,14 @@ def get_parser():
"Determines batch size dynamically.",
)
+ parser.add_argument(
+ "--subset",
+ type=str,
+ default="XL",
+ choices=["XL", "L", "M", "S", "XS"],
+ help="Which subset to work with",
+ )
+
parser.add_argument(
"--num-splits",
type=int,
@@ -76,7 +84,7 @@ def get_parser():
def compute_fbank_gigaspeech_splits(args):
num_splits = args.num_splits
- output_dir = "data/fbank/XL_split"
+ output_dir = f"data/fbank/{args.subset}_split"
output_dir = Path(output_dir)
assert output_dir.exists(), f"{output_dir} does not exist!"
@@ -96,15 +104,15 @@ def compute_fbank_gigaspeech_splits(args):
logging.info(f"device: {device}")
for i in range(start, stop):
- idx = f"{i + 1}".zfill(num_digits)
+ idx = f"{i}".zfill(num_digits)
logging.info(f"Processing {idx}/{num_splits}")
- cuts_path = output_dir / f"cuts_XL.{idx}.jsonl.gz"
+ cuts_path = output_dir / f"cuts_{args.subset}.{idx}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
- raw_cuts_path = output_dir / f"cuts_XL_raw.{idx}.jsonl.gz"
+ raw_cuts_path = output_dir / f"cuts_{args.subset}_raw.{idx}.jsonl.gz"
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
@@ -113,7 +121,7 @@ def compute_fbank_gigaspeech_splits(args):
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
- storage_path=f"{output_dir}/feats_XL_{idx}",
+ storage_path=f"{output_dir}/feats_{args.subset}_{idx}",
num_workers=args.num_workers,
batch_duration=args.batch_duration,
overwrite=True,
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index dd9e9ef1f..814390ad6 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Piotr Żelasko
+# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
@@ -475,3 +475,18 @@ class LibriSpeechAsrDataModule:
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
)
+
+ @lru_cache()
+ def gigaspeech_subset_small_cuts(self) -> CutSet:
+ logging.info("About to get Gigaspeech subset-S cuts")
+ return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz")
+
+ @lru_cache()
+ def gigaspeech_dev_cuts(self) -> CutSet:
+ logging.info("About to get Gigaspeech dev cuts")
+ return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
+
+ @lru_cache()
+ def gigaspeech_test_cuts(self) -> CutSet:
+ logging.info("About to get Gigaspeech test cuts")
+ return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
diff --git a/egs/librispeech/ASR/zipformer/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer/decode_gigaspeech.py
new file mode 100755
index 000000000..3cda337c0
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer/decode_gigaspeech.py
@@ -0,0 +1,1114 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+ modified_beam_search_lm_rescore,
+ modified_beam_search_lm_rescore_LODR,
+ modified_beam_search_lm_shallow_fusion,
+ modified_beam_search_LODR,
+)
+from train import add_model_arguments, get_model, get_params
+
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+conversational_filler = [
+ "UH",
+ "UHH",
+ "UM",
+ "EH",
+ "MM",
+ "HM",
+ "AH",
+ "HUH",
+ "HA",
+ "ER",
+ "OOF",
+ "HEE",
+ "ACH",
+ "EEE",
+ "EW",
+]
+unk_tags = ["<UNK>", "<unk>"]
+gigaspeech_punctuations = [
+ "<COMMA>",
+ "<PERIOD>",
+ "<QUESTIONMARK>",
+ "<EXCLAMATIONPOINT>",
+]
+gigaspeech_garbage_utterance_tags = ["<SIL>", "<NOISE>", "<MUSIC>", "<OTHER>"]
+non_scoring_words = (
+ conversational_filler
+ + unk_tags
+ + gigaspeech_punctuations
+ + gigaspeech_garbage_utterance_tags
+)
+
+
+def asr_text_post_processing(text: str) -> str:
+ # 1. convert to uppercase
+ text = text.upper()
+
+ # 2. remove hyphen
+ # "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART"
+ text = text.replace("-", " ")
+
+ # 3. remove non-scoring words from evaluation
+ remaining_words = []
+ for word in text.split():
+ if word in non_scoring_words:
+ continue
+ remaining_words.append(word)
+
+ return " ".join(remaining_words)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - modified_beam_search_LODR
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding-method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding-method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--use-shallow-fusion",
+ type=str2bool,
+ default=False,
+ help="""Use neural network LM for shallow fusion.
+ If you want to use LODR, you will also need to set this to true
+ """,
+ )
+
+ parser.add_argument(
+ "--lm-type",
+ type=str,
+ default="rnn",
+ help="Type of NN lm",
+ choices=["rnn", "transformer"],
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.3,
+ help="""The scale of the neural network LM
+ Used only when `--use-shallow-fusion` is set to True.
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens-ngram",
+ type=int,
+ default=2,
+ help="""The order of the ngram lm.
+ """,
+ )
+
+ parser.add_argument(
+ "--backoff-id",
+ type=int,
+ default=500,
+ help="ID of the backoff symbol in the ngram LM",
+ )
+
+ parser.add_argument(
+ "--context-score",
+ type=float,
+ default=2,
+ help="""
+ The bonus score of each token for the context biasing words/phrases.
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-file",
+ type=str,
+ default="",
+ help="""
+ The path of the context biasing lists, one word/phrase each line
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+ add_model_arguments(parser)
+
+ return parser
+
+
+def post_processing(
+ results: List[Tuple[str, List[str], List[str]]],
+) -> List[Tuple[str, List[str], List[str]]]:
+ new_results = []
+ for key, ref, hyp in results:
+ new_ref = asr_text_post_processing(" ".join(ref)).split()
+ new_hyp = asr_text_post_processing(" ".join(hyp)).split()
+ new_results.append((key, new_ref, new_hyp))
+ return new_results
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ LM:
+ A neural network language model.
+ ngram_lm:
+ A ngram language model
+ ngram_lm_scale:
+ The scale for the ngram language model.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
+ hyp_tokens = modified_beam_search_lm_shallow_fusion(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_LODR":
+ hyp_tokens = modified_beam_search_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LODR_lm=ngram_lm,
+ LODR_lm_scale=ngram_lm_scale,
+ LM=LM,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_rescore":
+ lm_scale_list = [0.01 * i for i in range(10, 50)]
+ ans_dict = modified_beam_search_lm_rescore(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ lm_scale_list=lm_scale_list,
+ )
+ elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ lm_scale_list = [0.02 * i for i in range(2, 30)]
+ ans_dict = modified_beam_search_lm_rescore_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ LODR_lm=ngram_lm,
+ sp=sp,
+ lm_scale_list=lm_scale_list,
+ )
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ elif "modified_beam_search" in params.decoding_method:
+ prefix = f"beam_size_{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ ):
+ ans = dict()
+ assert ans_dict is not None
+ for key, hyps in ans_dict.items():
+ hyps = [sp.decode(hyp).split() for hyp in hyps]
+ ans[f"{prefix}_{key}"] = hyps
+ return ans
+ else:
+ if params.has_contexts:
+ prefix += f"-context-score-{params.context_score}"
+ return {prefix: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or "beam_7" if a beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains three elements:
+ the cut id, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ word_table=word_table,
+ batch=batch,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = post_processing(results)
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ LmScorer.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if os.path.exists(params.context_file):
+ params.has_contexts = True
+ else:
+ params.has_contexts = False
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ ):
+ if params.has_contexts:
+ params.suffix += f"-context-score-{params.context_score}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_shallow_fusion:
+ params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+ if "LODR" in params.decoding_method:
+ params.suffix += (
+ f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+ )
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> and <unk> are defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ # only load the neural network LM if required
+ if params.use_shallow_fusion or params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_LODR",
+ ):
+ LM = LmScorer(
+ lm_type=params.lm_type,
+ params=params,
+ device=device,
+ lm_scale=params.lm_scale,
+ )
+ LM.to(device)
+ LM.eval()
+ else:
+ LM = None
+
+ # only load N-gram LM when needed
+ if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ try:
+ import kenlm
+ except ImportError:
+ print("Please install kenlm first. You can use")
+ print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
+ print("to install it")
+ import sys
+
+ sys.exit(-1)
+ ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
+ logging.info(f"lm filename: {ngram_file_name}")
+ ngram_lm = kenlm.Model(ngram_file_name)
+ ngram_lm_scale = None # use a list to search
+
+ elif params.decoding_method == "modified_beam_search_LODR":
+ lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+ logging.info(f"Loading token level lm: {lm_filename}")
+ ngram_lm = NgramLm(
+ str(params.lang_dir / lm_filename),
+ backoff_id=params.backoff_id,
+ is_binary=False,
+ )
+ logging.info(f"num states: {ngram_lm.lm.num_states}")
+ ngram_lm_scale = params.ngram_lm_scale
+ else:
+ ngram_lm = None
+ ngram_lm_scale = None
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ if "modified_beam_search" in params.decoding_method:
+ if os.path.exists(params.context_file):
+ contexts = []
+ for line in open(params.context_file).readlines():
+ contexts.append((sp.encode(line.strip()), 0.0))
+ context_graph = ContextGraph(params.context_score)
+ context_graph.build(contexts)
+ else:
+ context_graph = None
+ else:
+ context_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts()
+ gigaspeech_test_cuts = librispeech.gigaspeech_test_cuts()
+
+ dev_dl = librispeech.test_dataloaders(gigaspeech_dev_cuts)
+ test_dl = librispeech.test_dataloaders(gigaspeech_test_cuts)
+
+ test_sets = ["dev", "test"]
+ test_dls = [dev_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py
new file mode 100755
index 000000000..843d103cc
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer/finetune.py
@@ -0,0 +1,1521 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey,
+# Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+# Fine-tune without mux (i.e not mixing with original training data):
+./zipformer/finetune.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --do-finetune 1 \
+ --finetune-ckpt path/to/ckpt \
+ --base-lr 0.0045 \
+ --use-mux 0 \
+ --exp-dir zipformer/exp_finetune \
+ --max-duration 1000
+
+# Fine-tune without mux (i.e mixing with original training data):
+./zipformer/finetune.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --do-finetune 1 \
+ --finetune-ckpt path/to/ckpt \
+ --base-lr 0.0045 \
+ --use-mux 1 \
+ --exp-dir zipformer/exp_finetune \
+ --max-duration 1000
+
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut, CutSet
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # Returns the number of batches we would have used so far if we had used the
+ # reference duration. This is for purposes of set_batch_count().
+ # Note that we add a very large constant here so that the ScheduledFloat
+ # variables take their final (end) values.
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ ) + 100000
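+ # For example (illustrative values): with batch_idx_train=1000,
+ # max_duration=1000, world_size=4 and ref_duration=600, this evaluates to
+ # 1000 * 4000 / 600 + 100000, i.e. about 106667.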
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_finetune_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--do-finetune",
+ type=str2bool,
+ default=True,
+ help="If true, finetune from a pre-trained checkpoint",
+ )
+ parser.add_argument(
+ "--use-mux",
+ type=str2bool,
+ default=False,
+ help="""
+ Whether to mix the fine-tuning data with the original training data.
+ If true, we will mix 5% of the new data with 95% of the original data
+ to fine-tune. This is useful if you want to maintain the performance
+ on the original domain.
+ )
+
+ parser.add_argument(
+ "--init-modules",
+ type=str,
+ default=None,
+ help="""
+ Modules to be initialized. It matches all parameters starting with
+ a specific key. The keys are given as a comma-separated list. If None,
+ all modules will be initialised. For example, if you only want to
+ initialise all parameters starting with "encoder", use "encoder";
+ if you want to initialise parameters starting with encoder or decoder,
+ use "encoder,decoder".
+ """,
+ )
+
+ parser.add_argument(
+ "--finetune-ckpt",
+ type=str,
+ default=None,
+ help="Fine-tuning from which checkpoint (path to a .pt file)",
+ )
+
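+# The --use-mux option above mixes the fine-tuning data with the original
+# training data. A minimal sketch of such mixing with lhotse (the weights and
+# cut-set names below are illustrative assumptions, not the exact code used
+# later in this recipe):
+#
+# from lhotse import CutSet
+# mixed_cuts = CutSet.mux(
+# original_train_cuts, # e.g. the LibriSpeech training cuts
+# finetune_cuts, # the new-domain cuts
+# weights=[0.95, 0.05], # sample mostly from the original data
+# )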
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--causal",
+ type=str2bool,
+ default=False,
+ help="If True, use causal version of model.",
+ )
+
+ parser.add_argument(
+ "--chunk-size",
+ type=str,
+ default="16,32,64,-1",
+ help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+ " Must be just -1 if --causal=False",
+ )
+
+ parser.add_argument(
+ "--left-context-frames",
+ type=str,
+ default="64,128,256,-1",
+ help="Maximum left-contexts for causal training, measured in frames which will "
+ "be converted to a number of chunks. If splitting into chunks, "
+ "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+ )
+
+ parser.add_argument(
+ "--use-transducer",
+ type=str2bool,
+ default=True,
+ help="If True, use Transducer head.",
+ )
+
+ parser.add_argument(
+ "--use-ctc",
+ type=str2bool,
+ default=False,
+ help="If True, use CTC head.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr",
+ type=float,
+ default=0.0045,
+ help="""The base learning rate.
+ It is set to a very small value as we are doing fine-tuning""",
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=100000.0,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. It is set to a very large value here to prevent the lr from decaying too fast
+ during fine-tuning.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=100.0,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ It is set to a very large value here to prevent the lr from decaying too fast
+ during fine-tuning.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)" "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
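+ # For example (illustrative numbers): with average_period=200 and
+ # batch_idx_train=1000, the update above becomes
+ # model_avg = 0.2 * model + 0.8 * model_avg.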
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+ add_finetune_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if `batch_idx % log_interval` is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
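+ # For example: _to_int_tuple("2,2,3,4,3,2") -> (2, 2, 3, 4, 3, 2); this is
+ # how the comma-separated per-stack options above are parsed.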
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+ # encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, (T - 7) // 2, encoder_dims).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> (T - 7) // 2
+ # (2) embedding: num_features -> encoder_dims
+ # In the normal configuration, we will downsample once more at the end
+ # by a factor of 2, and most of the encoder stacks will run at a lower
+ # sampling rate.
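+ # For instance (illustrative numbers): T=1000 input frames become
+ # (1000 - 7) // 2 = 496 frames here, and roughly 248 frames after the
+ # additional factor-of-2 downsampling mentioned above.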
+ encoder_embed = Conv2dSubsampling(
+ in_channels=params.feature_dim,
+ out_channels=_to_int_tuple(params.encoder_dim)[0],
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ )
+ return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Zipformer2(
+ output_downsampling_factor=2,
+ downsampling_factor=_to_int_tuple(params.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+ encoder_dim=_to_int_tuple(params.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(params.query_head_dim),
+ pos_head_dim=_to_int_tuple(params.pos_head_dim),
+ value_head_dim=_to_int_tuple(params.value_head_dim),
+ pos_dim=params.pos_dim,
+ num_heads=_to_int_tuple(params.num_heads),
+ feedforward_dim=_to_int_tuple(params.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ causal=params.causal,
+ chunk_size=_to_int_tuple(params.chunk_size),
+ left_context_frames=_to_int_tuple(params.left_context_frames),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
+ f"but got params.use_transducer={params.use_transducer}, "
+ f"params.use_ctc={params.use_ctc}"
+ )
+
+ encoder_embed = get_encoder_embed(params)
+ encoder = get_encoder_model(params)
+
+ if params.use_transducer:
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+ else:
+ decoder = None
+ joiner = None
+
+ model = AsrModel(
+ encoder_embed=encoder_embed,
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ use_transducer=params.use_transducer,
+ use_ctc=params.use_ctc,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
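
> Editor's note: the filename selection at the top of `load_checkpoint_if_available` can be summarised in a few lines. This is a sketch of that logic only; the experiment directory name is hypothetical.

```python
from pathlib import Path


# Illustration only: which file load_checkpoint_if_available() would pick.
def checkpoint_filename(exp_dir: Path, start_batch: int, start_epoch: int):
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"
    if start_epoch > 1:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None  # fresh start, nothing to load


exp_dir = Path("zipformer/exp")  # hypothetical experiment directory
print(checkpoint_filename(exp_dir, start_batch=0, start_epoch=1))      # None
print(checkpoint_filename(exp_dir, start_batch=0, start_epoch=4))      # zipformer/exp/epoch-3.pt
print(checkpoint_filename(exp_dir, start_batch=12000, start_epoch=4))  # zipformer/exp/checkpoint-12000.pt
```
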
+
+
+def load_model_params(
+ ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True
+):
+ """Load model params from checkpoint
+
+ Args:
+ ckpt (str): Path to the checkpoint
+ model (nn.Module): model to be loaded
+ init_modules (list[str]): List of modules to be initialized
+ strict (bool): Whether to strictly enforce that the keys loaded from the
+ checkpoint match the keys of the model's state_dict.
+
+ """
+ logging.info(f"Loading checkpoint from {ckpt}")
+ checkpoint = torch.load(ckpt, map_location="cpu")
+
+ # if module list is empty, load the whole model from ckpt
+ if not init_modules:
+ if next(iter(checkpoint["model"])).startswith("module."):
+ logging.info("Loading checkpoint saved by DDP")
+
+ dst_state_dict = model.state_dict()
+ src_state_dict = checkpoint["model"]
+ for key in dst_state_dict.keys():
+ src_key = "{}.{}".format("module", key)
+ dst_state_dict[key] = src_state_dict.pop(src_key)
+ assert len(src_state_dict) == 0
+ model.load_state_dict(dst_state_dict, strict=strict)
+ else:
+ model.load_state_dict(checkpoint["model"], strict=strict)
+ else:
+ src_state_dict = checkpoint["model"]
+ dst_state_dict = model.state_dict()
+ for module in init_modules:
+ logging.info(f"Loading parameters starting with prefix {module}")
+ src_keys = [
+ k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
+ ]
+ dst_keys = [
+ k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
+ ]
+ assert set(src_keys) == set(dst_keys) # two sets should match exactly
+ for key in src_keys:
+ dst_state_dict[key] = src_state_dict.pop(key)
+
+ model.load_state_dict(dst_state_dict, strict=strict)
+
+ return None
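
> Editor's note: the `init_modules` branch above copies every tensor whose name starts with the requested module prefix. A toy example of that key selection; the module and parameter names are hypothetical and the "tensors" are plain integers.

```python
# Illustration only: prefix-based selection as done in load_model_params().
src_state_dict = {
    "encoder.layer0.weight": 1,
    "encoder.layer1.weight": 2,
    "decoder.embedding.weight": 3,
}
dst_state_dict = dict(src_state_dict)  # pretend target model with matching keys

init_modules = ["encoder"]
for module in init_modules:
    # Only keys under the "encoder." prefix are transferred.
    src_keys = [k for k in src_state_dict if k.startswith(module.strip() + ".")]
    for key in src_keys:
        dst_state_dict[key] = src_state_dict[key]

print(src_keys)  # ['encoder.layer0.weight', 'encoder.layer1.weight']
```
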
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ Note: there is no explicit warmup argument; the warm-up behaviour is driven
+ by params.batch_idx_train and params.warm_step (see the loss scaling below).
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, ctc_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+ # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
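
> Editor's note: the transducer branch above ramps the pruned loss in and the simple loss down during the first `warm_step` batches. The standalone sketch below reproduces those two schedules; `warm_step=2000` and `s=0.5` are example values, not necessarily the recipe defaults.

```python
# Illustration only: the loss-scale schedules used in compute_loss().
def loss_scales(batch_idx_train: int, warm_step: int = 2000, s: float = 0.5):
    simple_scale = (
        s
        if batch_idx_train >= warm_step
        else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    )
    pruned_scale = (
        1.0
        if batch_idx_train >= warm_step
        else 0.1 + 0.9 * (batch_idx_train / warm_step)
    )
    return simple_scale, pruned_scale


for step in (0, 500, 1000, 2000, 10000):
    print(step, loss_scales(step))
# 0     -> (1.0, 0.1)     simple loss dominates at the start
# 500   -> (0.875, 0.325)
# 1000  -> (0.75, 0.55)
# 2000  -> (0.5, 1.0)     fully warmed up
# 10000 -> (0.5, 1.0)
```
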
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
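
> Editor's note: because the losses are computed with `reduction=sum`, the quantity compared against `best_valid_loss` above is a loss per (subsampled) frame, i.e. `tot_loss["loss"] / tot_loss["frames"]`. A tiny numeric sketch with invented per-batch sums:

```python
# Illustration only: per-frame normalisation of summed losses.
batches = [
    {"loss": 1500.0, "frames": 4000},  # hypothetical per-batch sums
    {"loss": 900.0, "frames": 2500},
]
tot = {"loss": 0.0, "frames": 0}
for b in batches:
    tot["loss"] += b["loss"]
    tot["frames"] += b["frames"]

print(tot["loss"] / tot["frames"])  # ~0.369 loss per frame
```
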
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dls: List[torch.utils.data.DataLoader],
+ valid_sets: List[str],
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dls:
+ Dataloaders for the validation datasets.
+ valid_sets:
+ Names of the validation datasets, used when logging their losses.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction=sum, so the loss is summed over the utterances
+ # in the batch and has not been normalized yet.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ for valid_set, valid_dl in zip(valid_sets, valid_dls):
+ logging.info(f"Computing validation loss on {valid_set}")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(
+ f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}"
+ )
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
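
> Editor's note: the update inside the training loop, `tot_loss = tot_loss * (1 - 1 / params.reset_interval) + loss_info`, is an exponential moving average whose effective memory is roughly `reset_interval` batches. A quick scalar sketch of that behaviour; the loss value and interval are arbitrary.

```python
# Illustration only: the running-average behaviour of tot_loss in train_one_epoch().
reset_interval = 200
tot = 0.0
for step in range(1000):
    loss = 1.0  # pretend every batch contributes a summed loss of 1.0
    tot = tot * (1 - 1 / reset_interval) + loss

# tot converges towards reset_interval * loss; since frames are accumulated
# with the same decay, the reported ratio stays a smoothed per-frame average.
print(round(tot, 1))  # ~198.7, close to reset_interval
```
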
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> is defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ # load model parameters for model fine-tuning
+ if params.do_finetune:
+ assert params.start_epoch == 1, "Fine-tune must start from epoch 1"
+ modules = params.init_modules.split(",") if params.init_modules else None
+ checkpoints = load_model_params(
+ ckpt=params.finetune_ckpt, model=model, init_modules=modules
+ )
+ # Need to update model_avg if we use initialization
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+ else:
+ # resuming training
+ assert params.start_epoch > 1, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ gigaspeech_cuts = librispeech.gigaspeech_subset_small_cuts()
+ if params.use_mux:
+ librispeech_cuts = librispeech.train_all_shuf_cuts()
+ train_cuts = CutSet.mux(
+ gigaspeech_cuts, # num cuts = 688182
+ librispeech_cuts, # num cuts = 843723
+ weights=[688182, 843723],
+ stop_early=True,
+ )
+ else:
+ train_cuts = gigaspeech_cuts
+ logging.info(train_cuts)
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
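
> Editor's note: the `T >= S` check in `remove_short_and_long_utt` above relies on the subsampling formula `T = ((num_frames - 7) // 2 + 1) // 2`. A quick sanity check of that bound; the frame and token counts are invented for illustration.

```python
# Illustration only: the frame/token bound enforced by remove_short_and_long_utt().
def frames_after_subsampling(num_frames: int) -> int:
    return ((num_frames - 7) // 2 + 1) // 2


# A 10 s utterance at a 10 ms frame shift has about 1000 feature frames.
T = frames_after_subsampling(1000)
print(T)  # 248

# Any transcript longer than 248 BPE tokens would violate T >= S for pruned
# RNN-T training, so the filter would drop the cut.
print(frames_after_subsampling(1000) >= 250)  # False -> the cut is excluded
```
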
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = librispeech.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = librispeech.dev_clean_cuts()
+ valid_cuts += librispeech.dev_other_cuts()
+ gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts()
+
+ valid_sets = ["librispeech", "gigaspeech"]
+ valid_dls = [
+ librispeech.valid_dataloaders(valid_cuts),
+ librispeech.valid_dataloaders(gigaspeech_dev_cuts),
+ ]
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dls=valid_dls,
+ valid_sets=valid_sets,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
From 4ed88d948494d060a8c2827584d71bff616c3ca9 Mon Sep 17 00:00:00 2001
From: Tiance Wang
Date: Wed, 7 Feb 2024 10:16:02 +0800
Subject: [PATCH 049/123] Update shared (#1487)
There should be one more ../
---
egs/fluent_speech_commands/SLU/shared | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/egs/fluent_speech_commands/SLU/shared b/egs/fluent_speech_commands/SLU/shared
index 32a374a7f..9115c7e17 120000
--- a/egs/fluent_speech_commands/SLU/shared
+++ b/egs/fluent_speech_commands/SLU/shared
@@ -1 +1 @@
-../../icefall/shared/
\ No newline at end of file
+../../../icefall/shared/
From 711d6bc46297ba799174bae4f3ebb965224ad76f Mon Sep 17 00:00:00 2001
From: Wei Kang
Date: Fri, 9 Feb 2024 10:44:19 +0800
Subject: [PATCH 050/123] Refactor prepare.sh in librispeech (#1493)
* Refactor prepare.sh in librispeech, breaking it into three parts: prepare.sh (basic, minimal requirements for transducer training), prepare_lm.sh (ngram & nnlm stuff), prepare_mmi.sh (for MMI training).
---
egs/librispeech/ASR/RESULTS.md | 2 +-
egs/librispeech/ASR/generate-lm.sh | 20 --
egs/librispeech/ASR/local/train_bpe_model.py | 15 +
egs/librispeech/ASR/prepare.sh | 334 ++++---------------
egs/librispeech/ASR/prepare_lm.sh | 262 +++++++++++++++
egs/librispeech/ASR/prepare_mmi.sh | 45 +++
6 files changed, 393 insertions(+), 285 deletions(-)
delete mode 100755 egs/librispeech/ASR/generate-lm.sh
create mode 100755 egs/librispeech/ASR/prepare_lm.sh
create mode 100755 egs/librispeech/ASR/prepare_mmi.sh
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index ebf5e89c4..ee5422aba 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -1526,7 +1526,7 @@ done
You may also decode using LODR + LM shallow fusion. This decoding method is proposed in .
It subtracts the internal language model score during shallow fusion, which is approximated by a bi-gram model. The bi-gram can be
-generated by `generate-lm.sh`, or you may download it from .
+generated by `prepare_lm.sh` at stage 4, or you may download it from .
The decoding command is as follows:
diff --git a/egs/librispeech/ASR/generate-lm.sh b/egs/librispeech/ASR/generate-lm.sh
deleted file mode 100755
index dacd276d1..000000000
--- a/egs/librispeech/ASR/generate-lm.sh
+++ /dev/null
@@ -1,20 +0,0 @@
-#!/usr/bin/env bash
-
-lang_dir=data/lang_bpe_500
-
-for ngram in 2 3 4 5; do
- if [ ! -f $lang_dir/${ngram}gram.arpa ]; then
- ./shared/make_kn_lm.py \
- -ngram-order ${ngram} \
- -text $lang_dir/transcript_tokens.txt \
- -lm $lang_dir/${ngram}gram.arpa
- fi
-
- if [ ! -f $lang_dir/${ngram}gram.fst.txt ]; then
- python3 -m kaldilm \
- --read-symbol-table="$lang_dir/tokens.txt" \
- --disambig-symbol='#0' \
- --max-order=${ngram} \
- $lang_dir/${ngram}gram.arpa > $lang_dir/${ngram}gram.fst.txt
- fi
-done
diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py
index 43142aee4..5979d5b98 100755
--- a/egs/librispeech/ASR/local/train_bpe_model.py
+++ b/egs/librispeech/ASR/local/train_bpe_model.py
@@ -28,6 +28,7 @@
import argparse
import shutil
from pathlib import Path
+from typing import Dict
import sentencepiece as spm
@@ -57,6 +58,18 @@ def get_args():
return parser.parse_args()
+def generate_tokens(lang_dir: Path):
+ """
+ Generate the tokens.txt from a bpe model.
+ """
+ sp = spm.SentencePieceProcessor()
+ sp.load(str(lang_dir / "bpe.model"))
+ token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
+ with open(lang_dir / "tokens.txt", "w", encoding="utf-8") as f:
+ for sym, i in token2id.items():
+ f.write(f"{sym} {i}\n")
+
+
def main():
args = get_args()
vocab_size = args.vocab_size
@@ -95,6 +108,8 @@ def main():
shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
+ generate_tokens(lang_dir)
+
if __name__ == "__main__":
main()
diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index 4a5072cc0..40dc3260d 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -6,8 +6,21 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=15
-stage=-1
-stop_stage=100
+# run stage 0 to stage 5 by default
+stage=0
+stop_stage=5
+
+# Note: This script only prepares the minimal requirements needed for
+# transducer training with BPE units.
+#
+# If you want to use an n-gram LM or an NNLM, please continue with prepare_lm.sh
+# after this script has finished successfully.
+#
+# This script also contains the steps to generate phone based units, but they
+# will not run automatically; you can generate the phone based units with
+# bash prepare.sh --stage -1 --stop-stage -1
+# bash prepare.sh --stage 6 --stop-stage 6
+
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
@@ -17,6 +30,18 @@ stop_stage=100
# You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it.
# You can download them from https://www.openslr.org/12
#
+# - $dl_dir/musan
+# This directory contains the following directories downloaded from
+# http://www.openslr.org/17/
+#
+# - music
+# - noise
+# - speech
+#
+# The lm directory is not necessary for transducer training with BPE units, but
+# it is needed for phone based modeling. You can download it by running
+# bash prepare.sh --stage -1 --stop-stage -1
+# after which you will see the following files in the directory.
# - $dl_dir/lm
# This directory contains the following files downloaded from
# http://www.openslr.org/resources/11
@@ -28,14 +53,7 @@ stop_stage=100
# - librispeech-vocab.txt
# - librispeech-lexicon.txt
# - librispeech-lm-norm.txt.gz
-#
-# - $dl_dir/musan
-# This directory contains the following directories downloaded from
-# http://www.openslr.org/17/
-#
-# - music
-# - noise
-# - speech
+
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
@@ -60,6 +78,8 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
+log "Running prepare.sh"
+
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
@@ -159,13 +179,49 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
- log "Stage 5: Prepare phone based lang"
+ log "Stage 5: Prepare BPE based lang"
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}
+ mkdir -p $lang_dir
+
+ if [ ! -f $lang_dir/transcript_words.txt ]; then
+ log "Generate data for BPE training"
+ files=$(
+ find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
+ find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
+ find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
+ )
+ for f in ${files[@]}; do
+ cat $f | cut -d " " -f 2-
+ done > $lang_dir/transcript_words.txt
+ fi
+
+ if [ ! -f $lang_dir/bpe.model ]; then
+ ./local/train_bpe_model.py \
+ --lang-dir $lang_dir \
+ --vocab-size $vocab_size \
+ --transcript $lang_dir/transcript_words.txt
+ fi
+ done
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Prepare phone based lang"
lang_dir=data/lang_phone
mkdir -p $lang_dir
- (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) |
- cat - $dl_dir/lm/librispeech-lexicon.txt |
- sort | uniq > $lang_dir/lexicon.txt
+ if [ ! -f $dl_dir/lm/librispeech-lexicon.txt ]; then
+ log "No lexicon file in $dl_dir/lm, please run:"
+ log "prepare.sh --stage -1 --stop-stage -1"
+ exit -1
+ fi
+
+ if [ ! -f $lang_dir/lexicon.txt ]; then
+ (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) |
+ cat - $dl_dir/lm/librispeech-lexicon.txt |
+ sort | uniq > $lang_dir/lexicon.txt
+ fi
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir
@@ -187,253 +243,3 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
$lang_dir/L_disambig.fst
fi
fi
-
-
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
- log "Stage 6: Prepare BPE based lang"
-
- for vocab_size in ${vocab_sizes[@]}; do
- lang_dir=data/lang_bpe_${vocab_size}
- mkdir -p $lang_dir
- # We reuse words.txt from phone based lexicon
- # so that the two can share G.pt later.
- cp data/lang_phone/words.txt $lang_dir
-
- if [ ! -f $lang_dir/transcript_words.txt ]; then
- log "Generate data for BPE training"
- files=$(
- find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
- find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
- find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
- )
- for f in ${files[@]}; do
- cat $f | cut -d " " -f 2-
- done > $lang_dir/transcript_words.txt
- fi
-
- if [ ! -f $lang_dir/bpe.model ]; then
- ./local/train_bpe_model.py \
- --lang-dir $lang_dir \
- --vocab-size $vocab_size \
- --transcript $lang_dir/transcript_words.txt
- fi
-
- if [ ! -f $lang_dir/L_disambig.pt ]; then
- ./local/prepare_lang_bpe.py --lang-dir $lang_dir
-
- log "Validating $lang_dir/lexicon.txt"
- ./local/validate_bpe_lexicon.py \
- --lexicon $lang_dir/lexicon.txt \
- --bpe-model $lang_dir/bpe.model
- fi
-
- if [ ! -f $lang_dir/L.fst ]; then
- log "Converting L.pt to L.fst"
- ./shared/convert-k2-to-openfst.py \
- --olabels aux_labels \
- $lang_dir/L.pt \
- $lang_dir/L.fst
- fi
-
- if [ ! -f $lang_dir/L_disambig.fst ]; then
- log "Converting L_disambig.pt to L_disambig.fst"
- ./shared/convert-k2-to-openfst.py \
- --olabels aux_labels \
- $lang_dir/L_disambig.pt \
- $lang_dir/L_disambig.fst
- fi
- done
-fi
-
-if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
- log "Stage 7: Prepare bigram token-level P for MMI training"
-
- for vocab_size in ${vocab_sizes[@]}; do
- lang_dir=data/lang_bpe_${vocab_size}
-
- if [ ! -f $lang_dir/transcript_tokens.txt ]; then
- ./local/convert_transcript_words_to_tokens.py \
- --lexicon $lang_dir/lexicon.txt \
- --transcript $lang_dir/transcript_words.txt \
- --oov "" \
- > $lang_dir/transcript_tokens.txt
- fi
-
- if [ ! -f $lang_dir/P.arpa ]; then
- ./shared/make_kn_lm.py \
- -ngram-order 2 \
- -text $lang_dir/transcript_tokens.txt \
- -lm $lang_dir/P.arpa
- fi
-
- if [ ! -f $lang_dir/P.fst.txt ]; then
- python3 -m kaldilm \
- --read-symbol-table="$lang_dir/tokens.txt" \
- --disambig-symbol='#0' \
- --max-order=2 \
- $lang_dir/P.arpa > $lang_dir/P.fst.txt
- fi
- done
-fi
-
-if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
- log "Stage 8: Prepare G"
- # We assume you have installed kaldilm, if not, please install
- # it using: pip install kaldilm
-
- mkdir -p data/lm
- if [ ! -f data/lm/G_3_gram.fst.txt ]; then
- # It is used in building HLG
- python3 -m kaldilm \
- --read-symbol-table="data/lang_phone/words.txt" \
- --disambig-symbol='#0' \
- --max-order=3 \
- $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
- fi
-
- if [ ! -f data/lm/G_4_gram.fst.txt ]; then
- # It is used for LM rescoring
- python3 -m kaldilm \
- --read-symbol-table="data/lang_phone/words.txt" \
- --disambig-symbol='#0' \
- --max-order=4 \
- $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
- fi
-
- for vocab_size in ${vocab_sizes[@]}; do
- lang_dir=data/lang_bpe_${vocab_size}
-
- if [ ! -f $lang_dir/HL.fst ]; then
- ./local/prepare_lang_fst.py \
- --lang-dir $lang_dir \
- --ngram-G ./data/lm/G_3_gram.fst.txt
- fi
- done
-fi
-
-if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
- log "Stage 9: Compile HLG"
- ./local/compile_hlg.py --lang-dir data/lang_phone
-
- # Note If ./local/compile_hlg.py throws OOM,
- # please switch to the following command
- #
- # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
-
- for vocab_size in ${vocab_sizes[@]}; do
- lang_dir=data/lang_bpe_${vocab_size}
- ./local/compile_hlg.py --lang-dir $lang_dir
-
- # Note If ./local/compile_hlg.py throws OOM,
- # please switch to the following command
- #
- # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
- done
-fi
-
-# Compile LG for RNN-T fast_beam_search decoding
-if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
- log "Stage 10: Compile LG"
- ./local/compile_lg.py --lang-dir data/lang_phone
-
- for vocab_size in ${vocab_sizes[@]}; do
- lang_dir=data/lang_bpe_${vocab_size}
- ./local/compile_lg.py --lang-dir $lang_dir
- done
-fi
-
-if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
- log "Stage 11: Generate LM training data"
-
- for vocab_size in ${vocab_sizes[@]}; do
- log "Processing vocab_size == ${vocab_size}"
- lang_dir=data/lang_bpe_${vocab_size}
- out_dir=data/lm_training_bpe_${vocab_size}
- mkdir -p $out_dir
-
- ./local/prepare_lm_training_data.py \
- --bpe-model $lang_dir/bpe.model \
- --lm-data $dl_dir/lm/librispeech-lm-norm.txt \
- --lm-archive $out_dir/lm_data.pt
- done
-fi
-
-if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
- log "Stage 12: Generate LM validation data"
-
- for vocab_size in ${vocab_sizes[@]}; do
- log "Processing vocab_size == ${vocab_size}"
- out_dir=data/lm_training_bpe_${vocab_size}
- mkdir -p $out_dir
-
- if [ ! -f $out_dir/valid.txt ]; then
- files=$(
- find "$dl_dir/LibriSpeech/dev-clean" -name "*.trans.txt"
- find "$dl_dir/LibriSpeech/dev-other" -name "*.trans.txt"
- )
- for f in ${files[@]}; do
- cat $f | cut -d " " -f 2-
- done > $out_dir/valid.txt
- fi
-
- lang_dir=data/lang_bpe_${vocab_size}
- ./local/prepare_lm_training_data.py \
- --bpe-model $lang_dir/bpe.model \
- --lm-data $out_dir/valid.txt \
- --lm-archive $out_dir/lm_data-valid.pt
- done
-fi
-
-if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
- log "Stage 13: Generate LM test data"
-
- for vocab_size in ${vocab_sizes[@]}; do
- log "Processing vocab_size == ${vocab_size}"
- out_dir=data/lm_training_bpe_${vocab_size}
- mkdir -p $out_dir
-
- if [ ! -f $out_dir/test.txt ]; then
- files=$(
- find "$dl_dir/LibriSpeech/test-clean" -name "*.trans.txt"
- find "$dl_dir/LibriSpeech/test-other" -name "*.trans.txt"
- )
- for f in ${files[@]}; do
- cat $f | cut -d " " -f 2-
- done > $out_dir/test.txt
- fi
-
- lang_dir=data/lang_bpe_${vocab_size}
- ./local/prepare_lm_training_data.py \
- --bpe-model $lang_dir/bpe.model \
- --lm-data $out_dir/test.txt \
- --lm-archive $out_dir/lm_data-test.pt
- done
-fi
-
-if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
- log "Stage 14: Sort LM training data"
- # Sort LM training data by sentence length in descending order
- # for ease of training.
- #
- # Sentence length equals to the number of BPE tokens
- # in a sentence.
-
- for vocab_size in ${vocab_sizes[@]}; do
- out_dir=data/lm_training_bpe_${vocab_size}
- mkdir -p $out_dir
- ./local/sort_lm_training_data.py \
- --in-lm-data $out_dir/lm_data.pt \
- --out-lm-data $out_dir/sorted_lm_data.pt \
- --out-statistics $out_dir/statistics.txt
-
- ./local/sort_lm_training_data.py \
- --in-lm-data $out_dir/lm_data-valid.pt \
- --out-lm-data $out_dir/sorted_lm_data-valid.pt \
- --out-statistics $out_dir/statistics-valid.txt
-
- ./local/sort_lm_training_data.py \
- --in-lm-data $out_dir/lm_data-test.pt \
- --out-lm-data $out_dir/sorted_lm_data-test.pt \
- --out-statistics $out_dir/statistics-test.txt
- done
-fi
diff --git a/egs/librispeech/ASR/prepare_lm.sh b/egs/librispeech/ASR/prepare_lm.sh
new file mode 100755
index 000000000..a8eb5ca78
--- /dev/null
+++ b/egs/librispeech/ASR/prepare_lm.sh
@@ -0,0 +1,262 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+# This script generates the n-gram LM / NNLM and related files needed for decoding.
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+# - $dl_dir/lm
+# This directory contains the following files downloaded from
+# http://www.openslr.org/resources/11
+#
+# - 3-gram.pruned.1e-7.arpa.gz
+# - 3-gram.pruned.1e-7.arpa
+# - 4-gram.arpa.gz
+# - 4-gram.arpa
+# - librispeech-vocab.txt
+# - librispeech-lexicon.txt
+# - librispeech-lm-norm.txt.gz
+#
+
+. prepare.sh --stage -1 --stop-stage 6 || exit 1
+
+log "Running prepare_lm.sh"
+
+stage=0
+stop_stage=100
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Prepare BPE based lexicon."
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}
+ # We reuse words.txt from phone based lexicon
+ # so that the two can share G.pt later.
+ cp data/lang_phone/words.txt $lang_dir
+
+ if [ ! -f $lang_dir/L_disambig.pt ]; then
+ ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+
+ log "Validating $lang_dir/lexicon.txt"
+ ./local/validate_bpe_lexicon.py \
+ --lexicon $lang_dir/lexicon.txt \
+ --bpe-model $lang_dir/bpe.model
+ fi
+
+ if [ ! -f $lang_dir/L.fst ]; then
+ log "Converting L.pt to L.fst"
+ ./shared/convert-k2-to-openfst.py \
+ --olabels aux_labels \
+ $lang_dir/L.pt \
+ $lang_dir/L.fst
+ fi
+
+ if [ ! -f $lang_dir/L_disambig.fst ]; then
+ log "Converting L_disambig.pt to L_disambig.fst"
+ ./shared/convert-k2-to-openfst.py \
+ --olabels aux_labels \
+ $lang_dir/L_disambig.pt \
+ $lang_dir/L_disambig.fst
+ fi
+ done
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare word level G"
+ # We assume you have installed kaldilm, if not, please install
+ # it using: pip install kaldilm
+
+ mkdir -p data/lm
+ if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+ # It is used in building HLG
+ python3 -m kaldilm \
+ --read-symbol-table="data/lang_phone/words.txt" \
+ --disambig-symbol='#0' \
+ --max-order=3 \
+ $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
+ fi
+
+ if [ ! -f data/lm/G_4_gram.fst.txt ]; then
+ # It is used for LM rescoring
+ python3 -m kaldilm \
+ --read-symbol-table="data/lang_phone/words.txt" \
+ --disambig-symbol='#0' \
+ --max-order=4 \
+ $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
+ fi
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}
+
+ if [ ! -f $lang_dir/HL.fst ]; then
+ ./local/prepare_lang_fst.py \
+ --lang-dir $lang_dir \
+ --ngram-G ./data/lm/G_3_gram.fst.txt
+ fi
+ done
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Compile HLG"
+ ./local/compile_hlg.py --lang-dir data/lang_phone
+
+ # Note If ./local/compile_hlg.py throws OOM,
+ # please switch to the following command
+ #
+ # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}
+ ./local/compile_hlg.py --lang-dir $lang_dir
+
+ # Note If ./local/compile_hlg.py throws OOM,
+ # please switch to the following command
+ #
+ # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir
+ done
+fi
+
+# Compile LG for RNN-T fast_beam_search decoding
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Compile LG"
+ ./local/compile_lg.py --lang-dir data/lang_phone
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}
+ ./local/compile_lg.py --lang-dir $lang_dir
+ done
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Prepare token level ngram G"
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}
+
+ if [ ! -f $lang_dir/transcript_tokens.txt ]; then
+ ./local/convert_transcript_words_to_tokens.py \
+ --lexicon $lang_dir/lexicon.txt \
+ --transcript $lang_dir/transcript_words.txt \
+ --oov "" \
+ > $lang_dir/transcript_tokens.txt
+ fi
+
+ for ngram in 2 3 4 5; do
+ if [ ! -f $lang_dir/${ngram}gram.arpa ]; then
+ ./shared/make_kn_lm.py \
+ -ngram-order ${ngram} \
+ -text $lang_dir/transcript_tokens.txt \
+ -lm $lang_dir/${ngram}gram.arpa
+ fi
+
+ if [ ! -f $lang_dir/${ngram}gram.fst.txt ]; then
+ python3 -m kaldilm \
+ --read-symbol-table="$lang_dir/tokens.txt" \
+ --disambig-symbol='#0' \
+ --max-order=${ngram} \
+ $lang_dir/${ngram}gram.arpa > $lang_dir/${ngram}gram.fst.txt
+ fi
+ done
+ done
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Generate NNLM training data"
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ log "Processing vocab_size == ${vocab_size}"
+ lang_dir=data/lang_bpe_${vocab_size}
+ out_dir=data/lm_training_bpe_${vocab_size}
+ mkdir -p $out_dir
+
+ ./local/prepare_lm_training_data.py \
+ --bpe-model $lang_dir/bpe.model \
+ --lm-data $dl_dir/lm/librispeech-lm-norm.txt \
+ --lm-archive $out_dir/lm_data.pt
+ done
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Generate NNLM validation data"
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ log "Processing vocab_size == ${vocab_size}"
+ out_dir=data/lm_training_bpe_${vocab_size}
+ mkdir -p $out_dir
+
+ if [ ! -f $out_dir/valid.txt ]; then
+ files=$(
+ find "$dl_dir/LibriSpeech/dev-clean" -name "*.trans.txt"
+ find "$dl_dir/LibriSpeech/dev-other" -name "*.trans.txt"
+ )
+ for f in ${files[@]}; do
+ cat $f | cut -d " " -f 2-
+ done > $out_dir/valid.txt
+ fi
+
+ lang_dir=data/lang_bpe_${vocab_size}
+ ./local/prepare_lm_training_data.py \
+ --bpe-model $lang_dir/bpe.model \
+ --lm-data $out_dir/valid.txt \
+ --lm-archive $out_dir/lm_data-valid.pt
+ done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+ log "Stage 7: Generate NNLM test data"
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ log "Processing vocab_size == ${vocab_size}"
+ out_dir=data/lm_training_bpe_${vocab_size}
+ mkdir -p $out_dir
+
+ if [ ! -f $out_dir/test.txt ]; then
+ files=$(
+ find "$dl_dir/LibriSpeech/test-clean" -name "*.trans.txt"
+ find "$dl_dir/LibriSpeech/test-other" -name "*.trans.txt"
+ )
+ for f in ${files[@]}; do
+ cat $f | cut -d " " -f 2-
+ done > $out_dir/test.txt
+ fi
+
+ lang_dir=data/lang_bpe_${vocab_size}
+ ./local/prepare_lm_training_data.py \
+ --bpe-model $lang_dir/bpe.model \
+ --lm-data $out_dir/test.txt \
+ --lm-archive $out_dir/lm_data-test.pt
+ done
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+ log "Stage 8: Sort NNLM training data"
+ # Sort LM training data by sentence length in descending order
+ # for ease of training.
+ #
+ # Sentence length equals the number of BPE tokens
+ # in a sentence.
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ out_dir=data/lm_training_bpe_${vocab_size}
+ mkdir -p $out_dir
+ ./local/sort_lm_training_data.py \
+ --in-lm-data $out_dir/lm_data.pt \
+ --out-lm-data $out_dir/sorted_lm_data.pt \
+ --out-statistics $out_dir/statistics.txt
+
+ ./local/sort_lm_training_data.py \
+ --in-lm-data $out_dir/lm_data-valid.pt \
+ --out-lm-data $out_dir/sorted_lm_data-valid.pt \
+ --out-statistics $out_dir/statistics-valid.txt
+
+ ./local/sort_lm_training_data.py \
+ --in-lm-data $out_dir/lm_data-test.pt \
+ --out-lm-data $out_dir/sorted_lm_data-test.pt \
+ --out-statistics $out_dir/statistics-test.txt
+ done
+fi
diff --git a/egs/librispeech/ASR/prepare_mmi.sh b/egs/librispeech/ASR/prepare_mmi.sh
new file mode 100755
index 000000000..d8a6e0caf
--- /dev/null
+++ b/egs/librispeech/ASR/prepare_mmi.sh
@@ -0,0 +1,45 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+
+. prepare.sh --stage -1 --stop-stage 6 || exit 1
+
+log "Running prepare_mmi.sh"
+
+stage=0
+stop_stage=100
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Prepare bigram token-level P for MMI training"
+
+ for vocab_size in ${vocab_sizes[@]}; do
+ lang_dir=data/lang_bpe_${vocab_size}
+
+ if [ ! -f $lang_dir/transcript_tokens.txt ]; then
+ ./local/convert_transcript_words_to_tokens.py \
+ --lexicon $lang_dir/lexicon.txt \
+ --transcript $lang_dir/transcript_words.txt \
+ --oov "" \
+ > $lang_dir/transcript_tokens.txt
+ fi
+
+ if [ ! -f $lang_dir/P.arpa ]; then
+ ./shared/make_kn_lm.py \
+ -ngram-order 2 \
+ -text $lang_dir/transcript_tokens.txt \
+ -lm $lang_dir/P.arpa
+ fi
+
+ if [ ! -f $lang_dir/P.fst.txt ]; then
+ python3 -m kaldilm \
+ --read-symbol-table="$lang_dir/tokens.txt" \
+ --disambig-symbol='#0' \
+ --max-order=2 \
+ $lang_dir/P.arpa > $lang_dir/P.fst.txt
+ fi
+ done
+fi
From d9ae8c02a0abdeddc5a4cf9fad72293eda134de3 Mon Sep 17 00:00:00 2001
From: safarisadegh <37085305+safarisadegh@users.noreply.github.com>
Date: Fri, 9 Feb 2024 10:35:01 +0330
Subject: [PATCH 051/123] Update README.md (#1497)
---
icefall/ctc/README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/icefall/ctc/README.md b/icefall/ctc/README.md
index 0096bc096..1e342f6a3 100644
--- a/icefall/ctc/README.md
+++ b/icefall/ctc/README.md
@@ -12,6 +12,6 @@ pip install kaldifst kaldi-decoder
```
to install the dependencies.
-[kaldi-decoder]: https://github.com/i2-fsa/kaldi-decoder
+[kaldi-decoder]: https://github.com/k2-fsa/kaldi-decoder
[kaldifst]: https://github.com/k2-fsa/kaldifst
[k2]: https://github.com/k2-fsa/k2
From 06b356a610ec0bcef8982f011ba2a46cd8ca29b5 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sun, 18 Feb 2024 12:05:38 +0800
Subject: [PATCH 052/123] Update cpu docker images to support torch 2.2.0
(#1499)
---
.github/scripts/docker/Dockerfile | 1 +
.../scripts/docker/generate_build_matrix.py | 21 ++++++++++++-------
2 files changed, 14 insertions(+), 8 deletions(-)
diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile
index f6a088af1..ee0099911 100644
--- a/.github/scripts/docker/Dockerfile
+++ b/.github/scripts/docker/Dockerfile
@@ -11,6 +11,7 @@ ARG _KALDIFEAT_VERSION="${KALDIFEAT_VERSION}+cpu.torch${TORCH_VERSION}"
RUN apt-get update -y && \
apt-get install -qq -y \
+ cmake \
ffmpeg \
git \
git-lfs \
diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py
index bdde97647..f0690f8bf 100755
--- a/.github/scripts/docker/generate_build_matrix.py
+++ b/.github/scripts/docker/generate_build_matrix.py
@@ -6,8 +6,8 @@ import json
def version_gt(a, b):
- a_major, a_minor = a.split(".")[:2]
- b_major, b_minor = b.split(".")[:2]
+ a_major, a_minor = list(map(int, a.split(".")))[:2]
+ b_major, b_minor = list(map(int, b.split(".")))[:2]
if a_major > b_major:
return True
@@ -18,8 +18,8 @@ def version_gt(a, b):
def version_ge(a, b):
- a_major, a_minor = a.split(".")[:2]
- b_major, b_minor = b.split(".")[:2]
+ a_major, a_minor = list(map(int, a.split(".")))[:2]
+ b_major, b_minor = list(map(int, b.split(".")))[:2]
if a_major > b_major:
return True
@@ -43,11 +43,12 @@ def get_torchaudio_version(torch_version):
def get_matrix():
- k2_version = "1.24.4.dev20231220"
- kaldifeat_version = "1.25.3.dev20231221"
- version = "1.2"
- python_version = ["3.8", "3.9", "3.10", "3.11"]
+ k2_version = "1.24.4.dev20240211"
+ kaldifeat_version = "1.25.4.dev20240210"
+ version = "1.3"
+ python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"]
torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
+ torch_version += ["2.2.0"]
matrix = []
for p in python_version:
@@ -57,6 +58,10 @@ def get_matrix():
if version_gt(p, "3.10") and not version_gt(t, "2.0"):
continue
+ # only torch>=2.2.0 supports python 3.12
+ if version_gt(p, "3.11") and not version_gt(t, "2.1"):
+ continue
+
matrix.append(
{
"k2-version": k2_version,
From 17688476e5cbdba92c682d3a75e3941b647573a7 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sun, 18 Feb 2024 14:56:04 +0800
Subject: [PATCH 053/123] Provide docker images for torch 2.2.0 (#1501)
---
.github/workflows/build-docker-image.yml | 2 +-
.github/workflows/run-docker-image.yml | 9 ++-
docker/torch1.12.1-cuda11.3.dockerfile | 4 +-
docker/torch1.13.0-cuda11.6.dockerfile | 4 +-
docker/torch1.9.0-cuda10.2.dockerfile | 4 +-
docker/torch2.0.0-cuda11.7.dockerfile | 4 +-
docker/torch2.1.0-cuda11.8.dockerfile | 4 +-
docker/torch2.1.0-cuda12.1.dockerfile | 4 +-
docker/torch2.2.0-cuda11.8.dockerfile | 70 ++++++++++++++++++++++++
docker/torch2.2.0-cuda12.1.dockerfile | 70 ++++++++++++++++++++++++
docs/source/docker/intro.rst | 2 +
11 files changed, 163 insertions(+), 14 deletions(-)
create mode 100644 docker/torch2.2.0-cuda11.8.dockerfile
create mode 100644 docker/torch2.2.0-cuda12.1.dockerfile
diff --git a/.github/workflows/build-docker-image.yml b/.github/workflows/build-docker-image.yml
index e5d96dcdf..d5081f7d8 100644
--- a/.github/workflows/build-docker-image.yml
+++ b/.github/workflows/build-docker-image.yml
@@ -16,7 +16,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
- image: ["torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
+ image: ["torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
steps:
# refer to https://github.com/actions/checkout
diff --git a/.github/workflows/run-docker-image.yml b/.github/workflows/run-docker-image.yml
index d048923b6..65ba2cd64 100644
--- a/.github/workflows/run-docker-image.yml
+++ b/.github/workflows/run-docker-image.yml
@@ -14,13 +14,20 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
- image: ["torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
+ image: ["torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
steps:
# refer to https://github.com/actions/checkout
- uses: actions/checkout@v2
with:
fetch-depth: 0
+ - name: Free space
+ shell: bash
+ run: |
+ df -h
+ rm -rf /opt/hostedtoolcache
+ df -h
+
- name: Run the build process with Docker
uses: addnab/docker-run-action@v3
with:
diff --git a/docker/torch1.12.1-cuda11.3.dockerfile b/docker/torch1.12.1-cuda11.3.dockerfile
index deb5715cc..cb885e59e 100644
--- a/docker/torch1.12.1-cuda11.3.dockerfile
+++ b/docker/torch1.12.1-cuda11.3.dockerfile
@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.7
-ARG K2_VERSION="1.24.4.dev20230725+cuda11.3.torch1.12.1"
-ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.3.torch1.12.1"
+ARG K2_VERSION="1.24.4.dev20240211+cuda11.3.torch1.12.1"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.3.torch1.12.1"
ARG TORCHAUDIO_VERSION="0.12.1+cu113"
LABEL authors="Fangjun Kuang "
diff --git a/docker/torch1.13.0-cuda11.6.dockerfile b/docker/torch1.13.0-cuda11.6.dockerfile
index afc6c1b84..e238d87aa 100644
--- a/docker/torch1.13.0-cuda11.6.dockerfile
+++ b/docker/torch1.13.0-cuda11.6.dockerfile
@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.9
-ARG K2_VERSION="1.24.4.dev20231021+cuda11.6.torch1.13.0"
-ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.6.torch1.13.0"
+ARG K2_VERSION="1.24.4.dev20240211+cuda11.6.torch1.13.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.6.torch1.13.0"
ARG TORCHAUDIO_VERSION="0.13.0+cu116"
LABEL authors="Fangjun Kuang "
diff --git a/docker/torch1.9.0-cuda10.2.dockerfile b/docker/torch1.9.0-cuda10.2.dockerfile
index 9ff225b54..26d45cafc 100644
--- a/docker/torch1.9.0-cuda10.2.dockerfile
+++ b/docker/torch1.9.0-cuda10.2.dockerfile
@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.7
-ARG K2_VERSION="1.24.3.dev20230726+cuda10.2.torch1.9.0"
-ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda10.2.torch1.9.0"
+ARG K2_VERSION="1.24.4.dev20240211+cuda10.2.torch1.9.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda10.2.torch1.9.0"
ARG TORCHAUDIO_VERSION="0.9.0"
LABEL authors="Fangjun Kuang "
diff --git a/docker/torch2.0.0-cuda11.7.dockerfile b/docker/torch2.0.0-cuda11.7.dockerfile
index db8076560..02906e53b 100644
--- a/docker/torch2.0.0-cuda11.7.dockerfile
+++ b/docker/torch2.0.0-cuda11.7.dockerfile
@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
-ARG K2_VERSION="1.24.4.dev20231021+cuda11.7.torch2.0.0"
-ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.7.torch2.0.0"
+ARG K2_VERSION="1.24.4.dev20240211+cuda11.7.torch2.0.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.7.torch2.0.0"
ARG TORCHAUDIO_VERSION="2.0.0+cu117"
LABEL authors="Fangjun Kuang "
diff --git a/docker/torch2.1.0-cuda11.8.dockerfile b/docker/torch2.1.0-cuda11.8.dockerfile
index b006b0d96..c87305922 100644
--- a/docker/torch2.1.0-cuda11.8.dockerfile
+++ b/docker/torch2.1.0-cuda11.8.dockerfile
@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
-ARG K2_VERSION="1.24.4.dev20231021+cuda11.8.torch2.1.0"
-ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda11.8.torch2.1.0"
+ARG K2_VERSION="1.24.4.dev20240211+cuda11.8.torch2.1.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.8.torch2.1.0"
ARG TORCHAUDIO_VERSION="2.1.0+cu118"
LABEL authors="Fangjun Kuang "
diff --git a/docker/torch2.1.0-cuda12.1.dockerfile b/docker/torch2.1.0-cuda12.1.dockerfile
index 1b078dc22..f4c297678 100644
--- a/docker/torch2.1.0-cuda12.1.dockerfile
+++ b/docker/torch2.1.0-cuda12.1.dockerfile
@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
-ARG K2_VERSION="1.24.4.dev20231021+cuda12.1.torch2.1.0"
-ARG KALDIFEAT_VERSION="1.25.1.dev20231022+cuda12.1.torch2.1.0"
+ARG K2_VERSION="1.24.4.dev20240211+cuda12.1.torch2.1.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda12.1.torch2.1.0"
ARG TORCHAUDIO_VERSION="2.1.0+cu121"
LABEL authors="Fangjun Kuang "
diff --git a/docker/torch2.2.0-cuda11.8.dockerfile b/docker/torch2.2.0-cuda11.8.dockerfile
new file mode 100644
index 000000000..c59661c27
--- /dev/null
+++ b/docker/torch2.2.0-cuda11.8.dockerfile
@@ -0,0 +1,70 @@
+FROM pytorch/pytorch:2.2.0-cuda11.8-cudnn8-devel
+
+ENV LC_ALL C.UTF-8
+
+ARG DEBIAN_FRONTEND=noninteractive
+
+# python 3.10
+ARG K2_VERSION="1.24.4.dev20240211+cuda11.8.torch2.2.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.8.torch2.2.0"
+ARG TORCHAUDIO_VERSION="2.2.0+cu118"
+
+LABEL authors="Fangjun Kuang "
+LABEL k2_version=${K2_VERSION}
+LABEL kaldifeat_version=${KALDIFEAT_VERSION}
+LABEL github_repo="https://github.com/k2-fsa/icefall"
+
+RUN apt-get update && \
+ apt-get install -y --no-install-recommends \
+ curl \
+ vim \
+ libssl-dev \
+ autoconf \
+ automake \
+ bzip2 \
+ ca-certificates \
+ ffmpeg \
+ g++ \
+ gfortran \
+ git \
+ libtool \
+ make \
+ patch \
+ sox \
+ subversion \
+ unzip \
+ valgrind \
+ wget \
+ zlib1g-dev \
+ && rm -rf /var/lib/apt/lists/*
+
+# Install dependencies
+RUN pip install --no-cache-dir \
+ torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
+ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
+ git+https://github.com/lhotse-speech/lhotse \
+ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
+ kaldi_native_io \
+ kaldialign \
+ kaldifst \
+ kaldilm \
+      "sentencepiece>=0.1.96" \
+ tensorboard \
+ typeguard \
+ dill \
+ onnx \
+ onnxruntime \
+ onnxmltools \
+ multi_quantization \
+ typeguard \
+ numpy \
+ pytest \
+ graphviz
+
+RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
+ cd /workspace/icefall && \
+ pip install --no-cache-dir -r requirements.txt
+
+ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
+
+WORKDIR /workspace/icefall
diff --git a/docker/torch2.2.0-cuda12.1.dockerfile b/docker/torch2.2.0-cuda12.1.dockerfile
new file mode 100644
index 000000000..2c484efd5
--- /dev/null
+++ b/docker/torch2.2.0-cuda12.1.dockerfile
@@ -0,0 +1,70 @@
+FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel
+
+ENV LC_ALL C.UTF-8
+
+ARG DEBIAN_FRONTEND=noninteractive
+
+# python 3.10
+ARG K2_VERSION="1.24.4.dev20240211+cuda12.1.torch2.2.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda12.1.torch2.2.0"
+ARG TORCHAUDIO_VERSION="2.2.0+cu121"
+
+LABEL authors="Fangjun Kuang "
+LABEL k2_version=${K2_VERSION}
+LABEL kaldifeat_version=${KALDIFEAT_VERSION}
+LABEL github_repo="https://github.com/k2-fsa/icefall"
+
+RUN apt-get update && \
+ apt-get install -y --no-install-recommends \
+ curl \
+ vim \
+ libssl-dev \
+ autoconf \
+ automake \
+ bzip2 \
+ ca-certificates \
+ ffmpeg \
+ g++ \
+ gfortran \
+ git \
+ libtool \
+ make \
+ patch \
+ sox \
+ subversion \
+ unzip \
+ valgrind \
+ wget \
+ zlib1g-dev \
+ && rm -rf /var/lib/apt/lists/*
+
+# Install dependencies
+RUN pip install --no-cache-dir \
+ torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
+ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
+ git+https://github.com/lhotse-speech/lhotse \
+ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
+ kaldi_native_io \
+ kaldialign \
+ kaldifst \
+ kaldilm \
+      "sentencepiece>=0.1.96" \
+ tensorboard \
+ typeguard \
+ dill \
+ onnx \
+ onnxruntime \
+ onnxmltools \
+ multi_quantization \
+ typeguard \
+ numpy \
+ pytest \
+ graphviz
+
+RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
+ cd /workspace/icefall && \
+ pip install --no-cache-dir -r requirements.txt
+
+ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
+
+WORKDIR /workspace/icefall
diff --git a/docs/source/docker/intro.rst b/docs/source/docker/intro.rst
index cbd300d9b..149970eff 100644
--- a/docs/source/docker/intro.rst
+++ b/docs/source/docker/intro.rst
@@ -34,6 +34,8 @@ which will give you something like below:
.. code-block:: bash
+ "torch2.2.0-cuda12.1"
+ "torch2.2.0-cuda11.8"
"torch2.1.0-cuda12.1"
"torch2.1.0-cuda11.8"
"torch2.0.0-cuda11.7"
From 7eb360d0d5f3eb03292d3ff4596a8d50c9765888 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sun, 18 Feb 2024 20:32:40 +0800
Subject: [PATCH 054/123] Fix cpu docker images for torch 2.2.0 (#1502)
---
.github/scripts/docker/generate_build_matrix.py | 4 ++--
.github/workflows/yesno.yml | 3 +++
2 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py
index f0690f8bf..425afac2b 100755
--- a/.github/scripts/docker/generate_build_matrix.py
+++ b/.github/scripts/docker/generate_build_matrix.py
@@ -43,8 +43,8 @@ def get_torchaudio_version(torch_version):
def get_matrix():
- k2_version = "1.24.4.dev20240211"
- kaldifeat_version = "1.25.4.dev20240210"
+ k2_version = "1.24.4.dev20240218"
+ kaldifeat_version = "1.25.4.dev20240218"
version = "1.3"
python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"]
torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
diff --git a/.github/workflows/yesno.yml b/.github/workflows/yesno.yml
index 182300dfa..de822b33f 100644
--- a/.github/workflows/yesno.yml
+++ b/.github/workflows/yesno.yml
@@ -59,4 +59,7 @@ jobs:
cd /icefall
git config --global --add safe.directory /icefall
+ python3 -m torch.utils.collect_env
+ python3 -m k2.version
+
.github/scripts/yesno/ASR/run.sh
From db4d66c0e39a06464f5c316c727ed76babeb10eb Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Mon, 19 Feb 2024 16:13:09 +0800
Subject: [PATCH 055/123] Fixed softlink for `ljspeech` recipe (#1503)
---
egs/ljspeech/TTS/shared | 1 +
egs/ljspeech/TTS/shared/parse_options.sh | 1 -
2 files changed, 1 insertion(+), 1 deletion(-)
create mode 120000 egs/ljspeech/TTS/shared
delete mode 120000 egs/ljspeech/TTS/shared/parse_options.sh
diff --git a/egs/ljspeech/TTS/shared b/egs/ljspeech/TTS/shared
new file mode 120000
index 000000000..4c5e91438
--- /dev/null
+++ b/egs/ljspeech/TTS/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file
diff --git a/egs/ljspeech/TTS/shared/parse_options.sh b/egs/ljspeech/TTS/shared/parse_options.sh
deleted file mode 120000
index e4665e7de..000000000
--- a/egs/ljspeech/TTS/shared/parse_options.sh
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/shared/parse_options.sh
\ No newline at end of file
From b3e2044068001a24bc9293f5b7063377173631d3 Mon Sep 17 00:00:00 2001
From: Zengwei Yao
Date: Mon, 19 Feb 2024 19:33:32 +0800
Subject: [PATCH 056/123] minor fix of vits/tokenizer.py (#1504)
* minor fix of vits/tokenizer.py
---
egs/ljspeech/TTS/vits/tokenizer.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py
index 70f1240b4..b0afc6a04 100644
--- a/egs/ljspeech/TTS/vits/tokenizer.py
+++ b/egs/ljspeech/TTS/vits/tokenizer.py
@@ -74,7 +74,7 @@ class Tokenizer(object):
if intersperse_blank:
token_ids = intersperse(token_ids, self.blank_id)
- token_ids_list.append(token_ids)
+ token_ids_list.append(token_ids)
return token_ids_list
@@ -103,6 +103,7 @@ class Tokenizer(object):
if intersperse_blank:
token_ids = intersperse(token_ids, self.blank_id)
- token_ids_list.append(token_ids)
+
+ token_ids_list.append(token_ids)
return token_ids_list
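
The fix above matters because the ``append`` must run for every utterance, not only when blanks are interspersed. A minimal standalone sketch of the corrected control flow (``intersperse`` here is a hypothetical stand-in for the helper used in the recipe):

    from typing import List

    def intersperse(ids: List[int], item: int) -> List[int]:
        # hypothetical stand-in: [a, b] -> [item, a, item, b, item]
        out = [item] * (2 * len(ids) + 1)
        out[1::2] = ids
        return out

    def texts_to_token_ids(
        texts_as_ids: List[List[int]],
        blank_id: int = 0,
        intersperse_blank: bool = True,
    ) -> List[List[int]]:
        token_ids_list = []
        for token_ids in texts_as_ids:
            if intersperse_blank:
                token_ids = intersperse(token_ids, blank_id)
            # appended unconditionally; this is what the patch fixes
            token_ids_list.append(token_ids)
        return token_ids_list

    print(texts_to_token_ids([[5, 7]], intersperse_blank=False))  # [[5, 7]]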
From e59fa38e86bd05241daa4217d66eaa0e36825547 Mon Sep 17 00:00:00 2001
From: Karel Vesely
Date: Tue, 20 Feb 2024 03:40:15 +0100
Subject: [PATCH 057/123] docs: minor fixes of LM rescoring texts (#1498)
---
.../decoding-with-langugage-models/LODR.rst | 6 ++---
.../shallow-fusion.rst | 24 +++++++++----------
2 files changed, 15 insertions(+), 15 deletions(-)
diff --git a/docs/source/decoding-with-langugage-models/LODR.rst b/docs/source/decoding-with-langugage-models/LODR.rst
index b6b6e8cbb..d4b6f7065 100644
--- a/docs/source/decoding-with-langugage-models/LODR.rst
+++ b/docs/source/decoding-with-langugage-models/LODR.rst
@@ -30,7 +30,7 @@ of langugae model integration.
First, let's have a look at some background information. As the predecessor of LODR, Density Ratio (DR) is first proposed `here `_
to address the language information mismatch between the training
corpus (source domain) and the testing corpus (target domain). Assuming that the source domain and the test domain
-are acoustically similar, DR derives the following formular for decoding with Bayes' theorem:
+are acoustically similar, DR derives the following formula for decoding with Bayes' theorem:
.. math::
@@ -41,7 +41,7 @@ are acoustically similar, DR derives the following formular for decoding with Ba
where :math:`\lambda_1` and :math:`\lambda_2` are the weights of LM scores for target domain and source domain respectively.
-Here, the source domain LM is trained on the training corpus. The only difference in the above formular compared to
+Here, the source domain LM is trained on the training corpus. The only difference in the above formula compared to
shallow fusion is the subtraction of the source domain LM.
Some works treat the predictor and the joiner of the neural transducer as its internal LM. However, the LM is
@@ -58,7 +58,7 @@ during decoding for transducer model:
In LODR, an additional bi-gram LM estimated on the source domain (e.g training corpus) is required. Compared to DR,
the only difference lies in the choice of source domain LM. According to the original `paper `_,
-LODR achieves similar performance compared DR in both intra-domain and cross-domain settings.
+LODR achieves similar performance compared to DR in both intra-domain and cross-domain settings.
As a bi-gram is much faster to evaluate, LODR is usually much faster.
Now, we will show you how to use LODR in ``icefall``.
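
In its typical form, the DR combination described above can be written as follows; this is a
generic sketch using the weights :math:`\lambda_1` and :math:`\lambda_2` from the paragraph
above, and the exact notation in the tutorial's math block may differ:

.. math::

    \text{score}(y) = \log p_{\text{ASR}}(y \mid x)
                    + \lambda_1 \log p_{\text{Target}}(y)
                    - \lambda_2 \log p_{\text{Source}}(y)

LODR keeps the same form but estimates the subtracted source-domain term with a low-order
(bi-gram) LM trained on the transcripts of the training corpus.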
diff --git a/docs/source/decoding-with-langugage-models/shallow-fusion.rst b/docs/source/decoding-with-langugage-models/shallow-fusion.rst
index 684fefeb4..8b2586730 100644
--- a/docs/source/decoding-with-langugage-models/shallow-fusion.rst
+++ b/docs/source/decoding-with-langugage-models/shallow-fusion.rst
@@ -9,9 +9,9 @@ to improve the word-error-rate of a transducer model.
.. note::
- This tutorial is based on the recipe
+ This tutorial is based on the recipe
`pruned_transducer_stateless7_streaming `_,
- which is a streaming transducer model trained on `LibriSpeech`_.
+ which is a streaming transducer model trained on `LibriSpeech`_.
However, you can easily apply shallow fusion to other recipes.
If you encounter any problems, please open an issue here `icefall `_.
@@ -69,11 +69,11 @@ Training a language model usually takes a long time, we can download a pre-train
.. code-block:: bash
$ # download the external LM
- $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
+ $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
$ # create a symbolic link so that the checkpoint can be loaded
$ pushd icefall-librispeech-rnn-lm/exp
$ git lfs pull --include "pretrained.pt"
- $ ln -s pretrained.pt epoch-99.pt
+ $ ln -s pretrained.pt epoch-99.pt
$ popd
.. note::
@@ -85,7 +85,7 @@ Training a language model usually takes a long time, we can download a pre-train
To use shallow fusion for decoding, we can execute the following command:
.. code-block:: bash
-
+
$ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
$ lm_dir=./icefall-librispeech-rnn-lm/exp
$ lm_scale=0.29
@@ -133,16 +133,16 @@ The decoding result obtained with the above command are shown below.
$ For test-other, WER of different settings are:
$ beam_size_4 7.08 best for test-other
-The improvement of shallow fusion is very obvious! The relative WER reduction on test-other is around 10.5%.
+The improvement of shallow fusion is very obvious! The relative WER reduction on test-other is around 10.5%.
A few parameters can be tuned to further boost the performance of shallow fusion:
-- ``--lm-scale``
+- ``--lm-scale``
- Controls the scale of the LM. If too small, the external language model may not be fully utilized; if too large,
- the LM score may dominant during decoding, leading to bad WER. A typical value of this is around 0.3.
+ Controls the scale of the LM. If too small, the external language model may not be fully utilized; if too large,
+ the LM score might be dominant during decoding, leading to bad WER. A typical value of this is around 0.3.
+
+- ``--beam-size``
-- ``--beam-size``
-
The number of active paths in the search beam. It controls the trade-off between decoding efficiency and accuracy.
Here, we also show how `--beam-size` effect the WER and decoding time:
@@ -176,4 +176,4 @@ As we see, a larger beam size during shallow fusion improves the WER, but is als
-
+
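
To make the role of ``--lm-scale`` concrete, here is a minimal sketch of how shallow fusion typically combines scores during beam search (illustrative only; the actual icefall implementation lives in ``beam_search.py`` and differs in detail):

.. code-block:: python

    import math

    def fused_score(asr_logprob: float, lm_logprob: float, lm_scale: float = 0.29) -> float:
        # shallow fusion: add the scaled external-LM log-probability to the ASR score
        return asr_logprob + lm_scale * lm_logprob

    # a token favoured by the external LM gets a boost proportional to lm_scale
    print(fused_score(math.log(0.2), math.log(0.5)))

If ``lm_scale`` is set too large, the LM term dominates the sum, which is exactly the failure mode described above.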
From 027302c902ce9ab44754d42a56cf1eba9a075be9 Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Tue, 20 Feb 2024 14:38:51 +0800
Subject: [PATCH 058/123] minor fix for param. names (#1495)
---
icefall/lm_wrapper.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/icefall/lm_wrapper.py b/icefall/lm_wrapper.py
index 5e2783a47..26839c61c 100644
--- a/icefall/lm_wrapper.py
+++ b/icefall/lm_wrapper.py
@@ -159,7 +159,7 @@ class LmScorer(torch.nn.Module):
"""
if lm_type == "rnn":
model = RnnLmModel(
- vocab_size=params.vocab_size,
+ vocab_size=params.lm_vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
@@ -183,7 +183,7 @@ class LmScorer(torch.nn.Module):
elif lm_type == "transformer":
model = TransformerLM(
- vocab_size=params.vocab_size,
+ vocab_size=params.lm_vocab_size,
d_model=params.transformer_lm_encoder_dim,
embedding_dim=params.transformer_lm_embedding_dim,
dim_feedforward=params.transformer_lm_dim_feedforward,
From c19b4147789f306efed754dd0ed8f651017a7484 Mon Sep 17 00:00:00 2001
From: Wei Kang
Date: Wed, 21 Feb 2024 08:04:16 +0800
Subject: [PATCH 059/123] Update docker (adding pypinyin) (#1513)
Update docker (adding pypinyin)
---
.github/scripts/docker/Dockerfile | 1 +
.github/scripts/docker/generate_build_matrix.py | 2 +-
2 files changed, 2 insertions(+), 1 deletion(-)
diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile
index ee0099911..4adb7ab5c 100644
--- a/.github/scripts/docker/Dockerfile
+++ b/.github/scripts/docker/Dockerfile
@@ -51,6 +51,7 @@ RUN pip install --no-cache-dir \
onnxruntime \
pytest \
sentencepiece>=0.1.96 \
+ pypinyin==0.50.0 \
six \
tensorboard \
typeguard
diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py
index 425afac2b..ed01bd740 100755
--- a/.github/scripts/docker/generate_build_matrix.py
+++ b/.github/scripts/docker/generate_build_matrix.py
@@ -45,7 +45,7 @@ def get_torchaudio_version(torch_version):
def get_matrix():
k2_version = "1.24.4.dev20240218"
kaldifeat_version = "1.25.4.dev20240218"
- version = "1.3"
+ version = "1.4"
python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"]
torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
torch_version += ["2.2.0"]
From 13daf73468da70f5db49c928969fce6c6edc041a Mon Sep 17 00:00:00 2001
From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com>
Date: Wed, 21 Feb 2024 18:06:27 +0800
Subject: [PATCH 060/123] docs for finetune zipformer (#1509)
---
.../from_supervised/finetune_zipformer.rst | 140 ++++++++++++++++++
docs/source/recipes/Finetune/index.rst | 15 ++
docs/source/recipes/index.rst | 1 +
3 files changed, 156 insertions(+)
create mode 100644 docs/source/recipes/Finetune/from_supervised/finetune_zipformer.rst
create mode 100644 docs/source/recipes/Finetune/index.rst
diff --git a/docs/source/recipes/Finetune/from_supervised/finetune_zipformer.rst b/docs/source/recipes/Finetune/from_supervised/finetune_zipformer.rst
new file mode 100644
index 000000000..7ca4eb811
--- /dev/null
+++ b/docs/source/recipes/Finetune/from_supervised/finetune_zipformer.rst
@@ -0,0 +1,140 @@
+Finetune from a supervised pre-trained Zipformer model
+======================================================
+
+This tutorial shows you how to fine-tune a supervised pre-trained **Zipformer**
+transducer model on a new dataset.
+
+.. HINT::
+
+ We assume you have read the page :ref:`install icefall` and have setup
+ the environment for ``icefall``.
+
+.. HINT::
+
+ We recommend that you use a GPU or several GPUs to run this recipe.
+
+
+For illustration purposes, we fine-tune the Zipformer transducer model
+pre-trained on `LibriSpeech`_ on the small subset of `GigaSpeech`_. You can use your
+own data for fine-tuning as long as you create a manifest for your new dataset.
+
+Data preparation
+----------------
+
+Please follow the instructions in the `GigaSpeech recipe `_
+to prepare the fine-tuning data used in this tutorial. Only the small subset of GigaSpeech is required.
+
+
+Model preparation
+-----------------
+
+We use the Zipformer model trained on the full LibriSpeech corpus (960 hours) as the initialization. The
+checkpoint of the model can be downloaded via the following command:
+
+.. code-block:: bash
+
+ $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+ $ cd icefall-asr-librispeech-zipformer-2023-05-15/exp
+ $ git lfs pull --include "pretrained.pt"
+ $ ln -s pretrained.pt epoch-99.pt
+ $ cd ../data/lang_bpe_500
+ $ git lfs pull --include bpe.model
+ $ cd ../../..
+
+Before fine-tuning, let's test the model's WER on the new domain. The following command performs
+decoding on the GigaSpeech test sets:
+
+.. code-block:: bash
+
+ ./zipformer/decode_gigaspeech.py \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir icefall-asr-librispeech-zipformer-2023-05-15/exp \
+ --use-averaged-model 0 \
+ --max-duration 1000 \
+ --decoding-method greedy_search
+
+You should see the following numbers:
+
+.. code-block::
+
+ For dev, WER of different settings are:
+ greedy_search 20.06 best for dev
+
+ For test, WER of different settings are:
+ greedy_search 19.27 best for test
+
+
+Fine-tune
+---------
+
+Since LibriSpeech and GigaSpeech are both English datasets, we can initialize the whole
+Zipformer model with the checkpoint downloaded in the previous step (otherwise we would need to
+initialize the stateless decoder and joiner from scratch because of the mismatch in the output
+vocabulary). The following command starts a fine-tuning experiment:
+
+.. code-block:: bash
+
+ $ use_mux=0
+ $ do_finetune=1
+
+ $ ./zipformer/finetune.py \
+ --world-size 2 \
+ --num-epochs 20 \
+ --start-epoch 1 \
+ --exp-dir zipformer/exp_giga_finetune${do_finetune}_mux${use_mux} \
+ --use-fp16 1 \
+ --base-lr 0.0045 \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --do-finetune $do_finetune \
+ --use-mux $use_mux \
+ --master-port 13024 \
+ --finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \
+ --max-duration 1000
+
+The following arguments are related to fine-tuning:
+
+- ``--base-lr``
+  The learning rate used for fine-tuning. We suggest using a **small** learning rate for fine-tuning;
+  otherwise the model may forget the initialization very quickly. A reasonable value is around
+  1/10 of the original learning rate, i.e., 0.0045.
+
+- ``--do-finetune``
+ If True, do fine-tuning by initializing the model from a pre-trained checkpoint.
+  **Note that if you want to resume your fine-tuning experiment from a certain epoch, you
+  need to set this to False.**
+
+- ``--finetune-ckpt``
+ The path to the pre-trained checkpoint (used for initialization).
+
+- ``--use-mux``
+  If True, mix the fine-tuning data with the original training data using `CutSet.mux `_
+  (a sketch follows this list). This helps maintain the model's performance on the original
+  domain if the original training data is available. **If you don't have the original training
+  data, please set it to False.**
+
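+Below is a rough sketch of what ``--use-mux`` does under the hood with lhotse. The manifest
+paths and weights here are illustrative assumptions; the actual call site in ``finetune.py``
+may differ in detail.
+
+.. code-block:: python
+
+    from lhotse import CutSet
+
+    # lazily opened manifests (paths are placeholders)
+    original_cuts = CutSet.from_file("data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz")
+    finetune_cuts = CutSet.from_file("data/fbank/gigaspeech_cuts_S.jsonl.gz")
+
+    # interleave the two sources; weights control the sampling ratio
+    mixed_cuts = CutSet.mux(original_cuts, finetune_cuts, weights=[0.5, 0.5])
+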
+After fine-tuning, let's test the WERs. You can do this via the following command:
+
+.. code-block:: bash
+
+ $ use_mux=0
+ $ do_finetune=1
+ $ ./zipformer/decode_gigaspeech.py \
+ --epoch 20 \
+ --avg 10 \
+ --exp-dir zipformer/exp_giga_finetune${do_finetune}_mux${use_mux} \
+ --use-averaged-model 1 \
+ --max-duration 1000 \
+ --decoding-method greedy_search
+
+You should see numbers similar to the ones below:
+
+.. code-block:: text
+
+ For dev, WER of different settings are:
+ greedy_search 13.47 best for dev
+
+ For test, WER of different settings are:
+ greedy_search 13.66 best for test
+
+Compared to the original checkpoint, the fine-tuned model achieves much lower WERs
+on the GigaSpeech test sets.
diff --git a/docs/source/recipes/Finetune/index.rst b/docs/source/recipes/Finetune/index.rst
new file mode 100644
index 000000000..e62b8980f
--- /dev/null
+++ b/docs/source/recipes/Finetune/index.rst
@@ -0,0 +1,15 @@
+Fine-tune a pre-trained model
+=============================
+
+After pre-training on publicly available datasets, the ASR model is already capable of
+performing general speech recognition with relatively high accuracy. However, the accuracy
+may still be low on domains that are quite different from the original training
+set. In this case, we can fine-tune the model with a small amount of additional labelled
+data to improve its performance on the new domains.
+
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Table of Contents
+
+ from_supervised/finetune_zipformer
diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst
index 8df61f0d0..52795d452 100644
--- a/docs/source/recipes/index.rst
+++ b/docs/source/recipes/index.rst
@@ -17,3 +17,4 @@ We may add recipes for other tasks as well in the future.
Streaming-ASR/index
RNN-LM/index
TTS/index
+ Finetune/index
From aac7df064a6d1529f3bf4acccc6c550bd260b7b3 Mon Sep 17 00:00:00 2001
From: Wei Kang
Date: Thu, 22 Feb 2024 15:31:20 +0800
Subject: [PATCH 061/123] Recipes for open vocabulary keyword spotting (#1428)
* English recipe on gigaspeech; Chinese recipe on wenetspeech
---
...ev_test.py => compute_fbank_gigaspeech.py} | 14 +-
.../local/compute_fbank_gigaspeech_splits.py | 16 +-
.../ASR/local/preprocess_gigaspeech.py | 38 +-
egs/gigaspeech/ASR/prepare.sh | 61 +-
.../ASR/zipformer/asr_datamodule.py | 5 +-
egs/gigaspeech/ASR/zipformer/train.py | 21 +-
egs/gigaspeech/KWS/RESULTS.md | 49 +
egs/gigaspeech/KWS/prepare.sh | 85 +
egs/gigaspeech/KWS/run.sh | 197 +++
egs/gigaspeech/KWS/shared | 1 +
.../KWS/zipformer/asr_datamodule.py | 477 ++++++
egs/gigaspeech/KWS/zipformer/beam_search.py | 1 +
egs/gigaspeech/KWS/zipformer/decode-asr.py | 1066 +++++++++++++
egs/gigaspeech/KWS/zipformer/decode.py | 689 ++++++++
egs/gigaspeech/KWS/zipformer/decoder.py | 1 +
.../KWS/zipformer/encoder_interface.py | 1 +
.../KWS/zipformer/export-onnx-streaming.py | 1 +
egs/gigaspeech/KWS/zipformer/export.py | 1 +
egs/gigaspeech/KWS/zipformer/finetune.py | 644 ++++++++
.../KWS/zipformer/gigaspeech_scoring.py | 1 +
egs/gigaspeech/KWS/zipformer/joiner.py | 1 +
egs/gigaspeech/KWS/zipformer/model.py | 1 +
egs/gigaspeech/KWS/zipformer/optim.py | 1 +
egs/gigaspeech/KWS/zipformer/scaling.py | 1 +
egs/gigaspeech/KWS/zipformer/subsampling.py | 1 +
egs/gigaspeech/KWS/zipformer/train.py | 1367 ++++++++++++++++
egs/gigaspeech/KWS/zipformer/zipformer.py | 1 +
.../beam_search.py | 241 +++
.../ASR/tdnn_lstm_ctc/asr_datamodule.py | 10 +-
.../ASR/pruned_transducer_stateless5/train.py | 8 +-
.../local/prepare_dataset_from_kaldi_dir.py | 142 ++
egs/wenetspeech/ASR/local/prepare_pinyin.py | 275 ++++
egs/wenetspeech/ASR/prepare.sh | 16 +-
egs/wenetspeech/KWS/RESULTS.md | 58 +
egs/wenetspeech/KWS/prepare.sh | 90 ++
egs/wenetspeech/KWS/run.sh | 201 +++
egs/wenetspeech/KWS/shared | 1 +
.../KWS/zipformer/asr_datamodule.py | 459 ++++++
egs/wenetspeech/KWS/zipformer/beam_search.py | 1 +
egs/wenetspeech/KWS/zipformer/decode-asr.py | 767 +++++++++
egs/wenetspeech/KWS/zipformer/decode.py | 737 +++++++++
egs/wenetspeech/KWS/zipformer/decoder.py | 1 +
.../KWS/zipformer/encoder_interface.py | 1 +
.../KWS/zipformer/export-onnx-streaming.py | 1 +
egs/wenetspeech/KWS/zipformer/export.py | 1 +
egs/wenetspeech/KWS/zipformer/finetune.py | 814 ++++++++++
egs/wenetspeech/KWS/zipformer/joiner.py | 1 +
egs/wenetspeech/KWS/zipformer/model.py | 1 +
egs/wenetspeech/KWS/zipformer/optim.py | 1 +
egs/wenetspeech/KWS/zipformer/scaling.py | 1 +
.../KWS/zipformer/scaling_converter.py | 1 +
egs/wenetspeech/KWS/zipformer/subsampling.py | 1 +
egs/wenetspeech/KWS/zipformer/train.py | 1401 +++++++++++++++++
egs/wenetspeech/KWS/zipformer/zipformer.py | 1 +
icefall/char_graph_compiler.py | 35 +-
icefall/context_graph.py | 218 ++-
icefall/utils.py | 96 ++
57 files changed, 10203 insertions(+), 120 deletions(-)
rename egs/gigaspeech/ASR/local/{compute_fbank_gigaspeech_dev_test.py => compute_fbank_gigaspeech.py} (87%)
create mode 100644 egs/gigaspeech/KWS/RESULTS.md
create mode 100755 egs/gigaspeech/KWS/prepare.sh
create mode 100755 egs/gigaspeech/KWS/run.sh
create mode 120000 egs/gigaspeech/KWS/shared
create mode 100644 egs/gigaspeech/KWS/zipformer/asr_datamodule.py
create mode 120000 egs/gigaspeech/KWS/zipformer/beam_search.py
create mode 100755 egs/gigaspeech/KWS/zipformer/decode-asr.py
create mode 100755 egs/gigaspeech/KWS/zipformer/decode.py
create mode 120000 egs/gigaspeech/KWS/zipformer/decoder.py
create mode 120000 egs/gigaspeech/KWS/zipformer/encoder_interface.py
create mode 120000 egs/gigaspeech/KWS/zipformer/export-onnx-streaming.py
create mode 120000 egs/gigaspeech/KWS/zipformer/export.py
create mode 100755 egs/gigaspeech/KWS/zipformer/finetune.py
create mode 120000 egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py
create mode 120000 egs/gigaspeech/KWS/zipformer/joiner.py
create mode 120000 egs/gigaspeech/KWS/zipformer/model.py
create mode 120000 egs/gigaspeech/KWS/zipformer/optim.py
create mode 120000 egs/gigaspeech/KWS/zipformer/scaling.py
create mode 120000 egs/gigaspeech/KWS/zipformer/subsampling.py
create mode 100755 egs/gigaspeech/KWS/zipformer/train.py
create mode 120000 egs/gigaspeech/KWS/zipformer/zipformer.py
create mode 100644 egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py
create mode 100755 egs/wenetspeech/ASR/local/prepare_pinyin.py
create mode 100644 egs/wenetspeech/KWS/RESULTS.md
create mode 100755 egs/wenetspeech/KWS/prepare.sh
create mode 100755 egs/wenetspeech/KWS/run.sh
create mode 120000 egs/wenetspeech/KWS/shared
create mode 100644 egs/wenetspeech/KWS/zipformer/asr_datamodule.py
create mode 120000 egs/wenetspeech/KWS/zipformer/beam_search.py
create mode 100755 egs/wenetspeech/KWS/zipformer/decode-asr.py
create mode 100755 egs/wenetspeech/KWS/zipformer/decode.py
create mode 120000 egs/wenetspeech/KWS/zipformer/decoder.py
create mode 120000 egs/wenetspeech/KWS/zipformer/encoder_interface.py
create mode 120000 egs/wenetspeech/KWS/zipformer/export-onnx-streaming.py
create mode 120000 egs/wenetspeech/KWS/zipformer/export.py
create mode 100755 egs/wenetspeech/KWS/zipformer/finetune.py
create mode 120000 egs/wenetspeech/KWS/zipformer/joiner.py
create mode 120000 egs/wenetspeech/KWS/zipformer/model.py
create mode 120000 egs/wenetspeech/KWS/zipformer/optim.py
create mode 120000 egs/wenetspeech/KWS/zipformer/scaling.py
create mode 120000 egs/wenetspeech/KWS/zipformer/scaling_converter.py
create mode 120000 egs/wenetspeech/KWS/zipformer/subsampling.py
create mode 100755 egs/wenetspeech/KWS/zipformer/train.py
create mode 120000 egs/wenetspeech/KWS/zipformer/zipformer.py
diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py
similarity index 87%
rename from egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py
rename to egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py
index 07beeb1f0..9e0df0989 100755
--- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py
+++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py
@@ -30,15 +30,15 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
-def compute_fbank_gigaspeech_dev_test():
+def compute_fbank_gigaspeech():
in_out_dir = Path("data/fbank")
# number of workers in dataloader
num_workers = 20
# number of seconds in a batch
- batch_duration = 600
+ batch_duration = 1000
- subsets = ("DEV", "TEST")
+ subsets = ("L", "M", "S", "XS", "DEV", "TEST")
device = torch.device("cpu")
if torch.cuda.is_available():
@@ -48,12 +48,12 @@ def compute_fbank_gigaspeech_dev_test():
logging.info(f"device: {device}")
for partition in subsets:
- cuts_path = in_out_dir / f"cuts_{partition}.jsonl.gz"
+ cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
- raw_cuts_path = in_out_dir / f"cuts_{partition}_raw.jsonl.gz"
+ raw_cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
@@ -62,7 +62,7 @@ def compute_fbank_gigaspeech_dev_test():
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
- storage_path=f"{in_out_dir}/feats_{partition}",
+ storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}",
num_workers=num_workers,
batch_duration=batch_duration,
overwrite=True,
@@ -80,7 +80,7 @@ def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
- compute_fbank_gigaspeech_dev_test()
+ compute_fbank_gigaspeech()
if __name__ == "__main__":
diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
index 176eb8a84..51cd59078 100755
--- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
+++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py
@@ -51,14 +51,6 @@ def get_parser():
"Determines batch size dynamically.",
)
- parser.add_argument(
- "--subset",
- type=str,
- default="XL",
- choices=["XL", "L", "M", "S", "XS"],
- help="Which subset to work with",
- )
-
parser.add_argument(
"--num-splits",
type=int,
@@ -84,7 +76,7 @@ def get_parser():
def compute_fbank_gigaspeech_splits(args):
num_splits = args.num_splits
- output_dir = f"data/fbank/{args.subset}_split"
+ output_dir = f"data/fbank/XL_split"
output_dir = Path(output_dir)
assert output_dir.exists(), f"{output_dir} does not exist!"
@@ -107,12 +99,12 @@ def compute_fbank_gigaspeech_splits(args):
idx = f"{i}".zfill(num_digits)
logging.info(f"Processing {idx}/{num_splits}")
- cuts_path = output_dir / f"cuts_{args.subset}.{idx}.jsonl.gz"
+ cuts_path = output_dir / f"gigaspeech_cuts_XL.{idx}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
- raw_cuts_path = output_dir / f"cuts_{args.subset}_raw.{idx}.jsonl.gz"
+ raw_cuts_path = output_dir / f"gigaspeech_cuts_XL_raw.{idx}.jsonl.gz"
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
@@ -121,7 +113,7 @@ def compute_fbank_gigaspeech_splits(args):
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
- storage_path=f"{output_dir}/feats_{args.subset}_{idx}",
+ storage_path=f"{output_dir}/gigaspeech_feats_{idx}",
num_workers=args.num_workers,
batch_duration=args.batch_duration,
overwrite=True,
diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
index 31abe7fff..b6603f80d 100755
--- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
+++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py
@@ -16,17 +16,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import argparse
import logging
import re
from pathlib import Path
from lhotse import CutSet, SupervisionSegment
from lhotse.recipes.utils import read_manifests_if_cached
+from icefall.utils import str2bool
# Similar text filtering and normalization procedure as in:
# https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--perturb-speed",
+ type=str2bool,
+ default=False,
+ help="Whether to use speed perturbation.",
+ )
+
+ return parser.parse_args()
+
+
def normalize_text(
utt: str,
punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
@@ -42,7 +56,7 @@ def has_no_oov(
return oov_pattern.search(sup.text) is None
-def preprocess_giga_speech():
+def preprocess_giga_speech(args):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
output_dir.mkdir(exist_ok=True)
@@ -51,6 +65,10 @@ def preprocess_giga_speech():
"DEV",
"TEST",
"XL",
+ "L",
+ "M",
+ "S",
+ "XS",
)
logging.info("Loading manifest (may take 4 minutes)")
@@ -71,7 +89,7 @@ def preprocess_giga_speech():
for partition, m in manifests.items():
logging.info(f"Processing {partition}")
- raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
+ raw_cuts_path = output_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
if raw_cuts_path.is_file():
logging.info(f"{partition} already exists - skipping")
continue
@@ -94,11 +112,14 @@ def preprocess_giga_speech():
# Run data augmentation that needs to be done in the
# time domain.
if partition not in ["DEV", "TEST"]:
- logging.info(
- f"Speed perturb for {partition} with factors 0.9 and 1.1 "
- "(Perturbing may take 8 minutes and saving may take 20 minutes)"
- )
- cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+ if args.perturb_speed:
+ logging.info(
+ f"Speed perturb for {partition} with factors 0.9 and 1.1 "
+ "(Perturbing may take 8 minutes and saving may take 20 minutes)"
+ )
+ cut_set = (
+ cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+ )
logging.info(f"Saving to {raw_cuts_path}")
cut_set.to_file(raw_cuts_path)
@@ -107,7 +128,8 @@ def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
- preprocess_giga_speech()
+ args = get_args()
+ preprocess_giga_speech(args)
if __name__ == "__main__":
diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh
index a23b708d7..5e54b669a 100755
--- a/egs/gigaspeech/ASR/prepare.sh
+++ b/egs/gigaspeech/ASR/prepare.sh
@@ -99,7 +99,14 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
exit 1;
fi
# Download XL, DEV and TEST sets by default.
- lhotse download gigaspeech --subset auto --host tsinghua \
+ lhotse download gigaspeech --subset XL \
+ --subset L \
+ --subset M \
+ --subset S \
+ --subset XS \
+ --subset DEV \
+ --subset TEST \
+ --host tsinghua \
$dl_dir/password $dl_dir/GigaSpeech
fi
@@ -118,7 +125,14 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
# We assume that you have downloaded the GigaSpeech corpus
# to $dl_dir/GigaSpeech
mkdir -p data/manifests
- lhotse prepare gigaspeech --subset auto -j $nj \
+ lhotse prepare gigaspeech --subset XL \
+ --subset L \
+ --subset M \
+ --subset S \
+ --subset XS \
+ --subset DEV \
+ --subset TEST \
+ -j $nj \
$dl_dir/GigaSpeech data/manifests
fi
@@ -139,8 +153,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
- log "Stage 4: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)"
- python3 ./local/compute_fbank_gigaspeech_dev_test.py
+ log "Stage 4: Compute features for L, M, S, XS, DEV and TEST subsets of GigaSpeech."
+ python3 ./local/compute_fbank_gigaspeech.py
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
@@ -176,18 +190,9 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
- log "Stage 9: Prepare phone based lang"
+ log "Stage 9: Prepare transcript_words.txt and words.txt"
lang_dir=data/lang_phone
mkdir -p $lang_dir
-
- (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) |
- cat - $dl_dir/lm/lexicon.txt |
- sort | uniq > $lang_dir/lexicon.txt
-
- if [ ! -f $lang_dir/L_disambig.pt ]; then
- ./local/prepare_lang.py --lang-dir $lang_dir
- fi
-
if [ ! -f $lang_dir/transcript_words.txt ]; then
gunzip -c "data/manifests/gigaspeech_supervisions_XL.jsonl.gz" \
| jq '.text' \
@@ -238,7 +243,21 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
fi
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
- log "Stage 10: Prepare BPE based lang"
+ log "Stage 10: Prepare phone based lang"
+ lang_dir=data/lang_phone
+ mkdir -p $lang_dir
+
+ (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) |
+ cat - $dl_dir/lm/lexicon.txt |
+ sort | uniq > $lang_dir/lexicon.txt
+
+ if [ ! -f $lang_dir/L_disambig.pt ]; then
+ ./local/prepare_lang.py --lang-dir $lang_dir
+ fi
+fi
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+ log "Stage 11: Prepare BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
@@ -260,8 +279,8 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
done
fi
-if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
- log "Stage 11: Prepare bigram P"
+if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
+ log "Stage 12: Prepare bigram P"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
@@ -291,8 +310,8 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
done
fi
-if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
- log "Stage 12: Prepare G"
+if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
+ log "Stage 13: Prepare G"
# We assume you have installed kaldilm, if not, please install
# it using: pip install kaldilm
@@ -317,8 +336,8 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
fi
fi
-if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
- log "Stage 13: Compile HLG"
+if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
+ log "Stage 14: Compile HLG"
./local/compile_hlg.py --lang-dir data/lang_phone
for vocab_size in ${vocab_sizes[@]}; do
diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py
index 850ab7c10..0501461cd 100644
--- a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py
+++ b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py
@@ -105,7 +105,7 @@ class GigaSpeechAsrDataModule:
group.add_argument(
"--num-buckets",
type=int,
- default=30,
+ default=100,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
@@ -368,6 +368,8 @@ class GigaSpeechAsrDataModule:
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
+ num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
shuffle=False,
)
logging.info("About to create dev dataloader")
@@ -417,6 +419,7 @@ class GigaSpeechAsrDataModule:
logging.info(
f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode"
)
+
cuts_train = lhotse.combine(
lhotse.load_manifest_lazy(p) for p in sorted_filenames
)
diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py
index d93cc221c..c5335562c 100755
--- a/egs/gigaspeech/ASR/zipformer/train.py
+++ b/egs/gigaspeech/ASR/zipformer/train.py
@@ -416,6 +416,17 @@ def get_parser():
help="Accumulate stats on activations, print them and exit.",
)
+ parser.add_argument(
+ "--scan-for-oom-batches",
+ type=str2bool,
+ default=False,
+ help="""
+        Whether to scan for OOM batches before training. This is helpful for
+        finding a suitable max_duration; you only need to run it once.
+        Caution: this is a little time consuming.
+ """,
+ )
+
parser.add_argument(
"--inf-check",
type=str2bool,
@@ -1171,9 +1182,16 @@ def run(rank, world_size, args):
if params.inf_check:
register_inf_check_hooks(model)
+ def remove_short_utt(c: Cut):
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ return T > 0
+
gigaspeech = GigaSpeechAsrDataModule(args)
train_cuts = gigaspeech.train_cuts()
+ train_cuts = train_cuts.filter(remove_short_utt)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
@@ -1187,9 +1205,10 @@ def run(rank, world_size, args):
)
valid_cuts = gigaspeech.dev_cuts()
+ valid_cuts = valid_cuts.filter(remove_short_utt)
valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
- if not params.print_diagnostics:
+ if not params.print_diagnostics and params.scan_for_oom_batches:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
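
The ``remove_short_utt`` filter added above drops cuts that would subsample to zero encoder frames: with the expression shown, any cut with fewer than 9 feature frames is removed. A quick standalone check (plain Python, no icefall imports):

    def keeps(num_frames: int) -> bool:
        # same subsampling expression as remove_short_utt above
        T = ((num_frames - 7) // 2 + 1) // 2
        return T > 0

    print(keeps(8))  # False: the cut would be filtered out
    print(keeps(9))  # True: the shortest length that survives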
diff --git a/egs/gigaspeech/KWS/RESULTS.md b/egs/gigaspeech/KWS/RESULTS.md
new file mode 100644
index 000000000..992240e14
--- /dev/null
+++ b/egs/gigaspeech/KWS/RESULTS.md
@@ -0,0 +1,49 @@
+# Results
+
+## zipformer transducer model
+
+This is a tiny general-purpose ASR model with around 3.3M parameters; see this PR https://github.com/k2-fsa/icefall/pull/1428 for how to train it and for other details.
+
+The modeling units are 500 BPE tokens trained on GigaSpeech transcripts.
+
+The positive test sets are from https://github.com/pkufool/open-commands and the negative test set is the GigaSpeech test set (about 40 hours of audio).
+
+The whole pipeline, containing the training, decoding and fine-tuning commands, is in `run.sh`.
+
+The models have been uploaded to [github](https://github.com/pkufool/keyword-spotting-models/releases/download/v0.11/icefall-kws-zipformer-gigaspeech-20240219.tar.gz).
+
+Here are the results on a small test set of 20 commands. We list the results for every command; for
+each metric there are two columns, one for the original model trained on the GigaSpeech XL subset, the other
+for the model fine-tuned on the commands dataset.
+
+Commands | FN in positive set |FN in positive set | Recall | Recall | FP in negative set | FP in negative set| False alarm (time / hour) 40 hours | False alarm (time / hour) 40 hours |
+-- | -- | -- | -- | --| -- | -- | -- | --
+ | original | finetune | original | finetune | original | finetune | original | finetune
+All | 43/307 | 4/307 | 86% | 98.7% | 1 | 24 | 0.025 | 0.6
+Lights on | 6/17 | 0/17 | 64.7% | 100% | 1 | 9 | 0.025 | 0.225
+Heat up | 5/14 | 1/14 | 64.3% | 92.9% | 0 | 1 | 0 | 0.025
+Volume down | 4/18 | 0/18 | 77.8% | 100% | 0 | 2 | 0 | 0.05
+Volume max | 4/17 | 0/17 | 76.5% | 100% | 0 | 0 | 0 | 0
+Volume mute | 4/16 | 0/16 | 75.0% | 100% | 0 | 0 | 0 | 0
+Too quiet | 3/17 | 0/17 | 82.4% | 100% | 0 | 4 | 0 | 0.1
+Lights off | 3/17 | 0/17 | 82.4% | 100% | 0 | 2 | 0 | 0.05
+Play music | 2/14 | 0/14 | 85.7% | 100% | 0 | 0 | 0 | 0
+Bring newspaper | 2/13 | 1/13 | 84.6% | 92.3% | 0 | 0 | 0 | 0
+Heat down | 2/16 | 2/16 | 87.5% | 87.5% | 0 | 1 | 0 | 0.025
+Volume up | 2/18 | 0/18 | 88.9% | 100% | 0 | 1 | 0 | 0.025
+Too loud | 1/13 | 0/13 | 92.3% | 100% | 0 | 0 | 0 | 0
+Resume music | 1/14 | 0/14 | 92.9% | 100% | 0 | 0 | 0 | 0
+Bring shoes | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0
+Switch language | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0
+Pause music | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0
+Bring socks | 1/12 | 0/12 | 91.7% | 100% | 0 | 0 | 0 | 0
+Stop music | 0/15 | 0/15 | 100% | 100% | 0 | 0 | 0 | 0
+Turn it up | 0/15 | 0/15 | 100% | 100% | 0 | 3 | 0 | 0.075
+Turn it down | 0/16 | 0/16 | 100% | 100% | 0 | 1 | 0 | 0.025
+
+These are the results on the large test set, which has more than 200 commands; that is too many to list individually, so only the overall result is given here.
+
+Commands | FN in positive set | FN in positive set | Recall | Recall | FP in negative set | FP in negative set | False alarm (time / hour)23 hours | False alarm (time / hour)23 hours
+-- | -- | -- | -- | -- | -- | -- | -- | --
+ | original | finetune | original | finetune | original | finetune | original | finetune
+All | 622/3994 | 79/3994 | 83.6% | 97.9% | 18/19930 | 52/19930 | 0.45 | 1.3
diff --git a/egs/gigaspeech/KWS/prepare.sh b/egs/gigaspeech/KWS/prepare.sh
new file mode 100755
index 000000000..0b098190d
--- /dev/null
+++ b/egs/gigaspeech/KWS/prepare.sh
@@ -0,0 +1,85 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+nj=15
+stage=0
+stop_stage=100
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Prepare gigaspeech dataset."
+ mkdir -p data/fbank
+ if [ ! -e data/fbank/.gigaspeech.done ]; then
+ pushd ../ASR
+ ./prepare.sh --stage 0 --stop-stage 9
+ ./prepare.sh --stage 11 --stop-stage 11
+ popd
+ pushd data/fbank
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_DEV.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_DEV.lca) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_TEST.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_TEST.lca) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_L.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_L.lca) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_M.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_M.lca) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_S.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_S.lca) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_XS.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_XS.lca) .
+ ln -svf $(realpath ../ASR/data/fbank/XL_split) .
+ ln -svf $(realpath ../ASR/data/fbank/musan_cuts.jsonl.gz) .
+ ln -svf $(realpath ../ASR/data/fbank/musan_feats) .
+ popd
+ pushd data
+ ln -svf $(realpath ../ASR/data/lang_bpe_500) .
+ popd
+ touch data/fbank/.gigaspeech.done
+ else
+ log "Gigaspeech dataset already exists, skipping."
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare open commands dataset."
+ mkdir -p data/fbank
+ if [ ! -e data/fbank/.fluent_speech_commands.done ]; then
+ pushd data
+ git clone https://github.com/pkufool/open-commands.git
+ ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt
+ ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt
+ pushd open-commands
+ ./script/prepare.sh --stage 2 --stop-stage 2
+ ./script/prepare.sh --stage 6 --stop-stage 6
+ popd
+ popd
+ pushd data/fbank
+ ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_large.jsonl.gz) .
+ ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_large) .
+ ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_small.jsonl.gz) .
+ ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_small) .
+ ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_valid.jsonl.gz) .
+ ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_valid) .
+ ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_train.jsonl.gz) .
+ ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_train) .
+ popd
+ touch data/fbank/.fluent_speech_commands.done
+ else
+ log "Fluent speech commands dataset already exists, skipping."
+ fi
+fi
diff --git a/egs/gigaspeech/KWS/run.sh b/egs/gigaspeech/KWS/run.sh
new file mode 100755
index 000000000..ea04c7c9b
--- /dev/null
+++ b/egs/gigaspeech/KWS/run.sh
@@ -0,0 +1,197 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+export PYTHONPATH=../../../:$PYTHONPATH
+
+stage=0
+stop_stage=100
+
+. shared/parse_options.sh || exit 1
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Train a model."
+ if [ ! -e data/fbank/.gigaspeech.done ]; then
+ log "You need to run the prepare.sh first."
+ exit -1
+ fi
+
+ python ./zipformer/train.py \
+ --world-size 4 \
+ --exp-dir zipformer/exp \
+ --decoder-dim 320 \
+ --joiner-dim 320 \
+ --num-encoder-layers 1,1,1,1,1,1 \
+ --feedforward-dim 192,192,192,192,192,192 \
+ --encoder-dim 128,128,128,128,128,128 \
+ --encoder-unmasked-dim 128,128,128,128,128,128 \
+ --num-epochs 12 \
+ --lr-epochs 1.5 \
+ --use-fp16 1 \
+ --start-epoch 1 \
+ --subset XL \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --causal 1 \
+ --max-duration 1000
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Decode the model."
+  for t in small large; do
+ python ./zipformer/decode.py \
+ --epoch 12 \
+ --avg 2 \
+ --exp-dir ./zipformer/exp \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 64 \
+ --decoder-dim 320 \
+ --joiner-dim 320 \
+ --num-encoder-layers 1,1,1,1,1,1 \
+ --feedforward-dim 192,192,192,192,192,192 \
+ --encoder-dim 128,128,128,128,128,128 \
+ --encoder-unmasked-dim 128,128,128,128,128,128 \
+ --test-set $t \
+ --keywords-score 1.0 \
+ --keywords-threshold 0.35 \
+ --keywords-file ./data/commands_${t}.txt \
+ --max-duration 3000
+ done
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Export the model."
+
+ python ./zipformer/export.py \
+ --epoch 12 \
+ --avg 2 \
+ --exp-dir ./zipformer/exp \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 64 \
+ --decoder-dim 320 \
+ --joiner-dim 320 \
+ --num-encoder-layers 1,1,1,1,1,1 \
+ --feedforward-dim 192,192,192,192,192,192 \
+ --encoder-dim 128,128,128,128,128,128 \
+ --encoder-unmasked-dim 128,128,128,128,128,128
+
+ python ./zipformer/export_onnx_streaming.py \
+ --exp-dir zipformer/exp \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --epoch 12 \
+ --avg 2 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --decoder-dim 320 \
+ --joiner-dim 320 \
+ --num-encoder-layers 1,1,1,1,1,1 \
+ --feedforward-dim 192,192,192,192,192,192 \
+ --encoder-dim 128,128,128,128,128,128 \
+ --encoder-unmasked-dim 128,128,128,128,128,128 \
+ --causal 1
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 2: Finetune the model"
+
+ # The following configuration of lr schedule should work well
+ # You may also tune the following parameters to adjust learning rate schedule
+ base_lr=0.0005
+ lr_epochs=100
+ lr_batches=100000
+
+ # We recommend to start from an averaged model
+ finetune_ckpt=zipformer/exp/pretrained.pt
+
+ ./zipformer/finetune.py \
+ --world-size 4 \
+ --num-epochs 10 \
+ --start-epoch 1 \
+ --exp-dir zipformer/exp_finetune \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --use-fp16 1 \
+ --decoder-dim 320 \
+ --joiner-dim 320 \
+ --num-encoder-layers 1,1,1,1,1,1 \
+ --feedforward-dim 192,192,192,192,192,192 \
+ --encoder-dim 128,128,128,128,128,128 \
+ --encoder-unmasked-dim 128,128,128,128,128,128 \
+ --causal 1 \
+ --base-lr $base_lr \
+ --lr-epochs $lr_epochs \
+ --lr-batches $lr_batches \
+ --finetune-ckpt $finetune_ckpt \
+ --max-duration 1500
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 1: Decode the finetuned model."
+ for t in small, large; do
+ python ./zipformer/decode.py \
+ --epoch 10 \
+ --avg 2 \
+ --exp-dir ./zipformer/exp_finetune \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 64 \
+ --decoder-dim 320 \
+ --joiner-dim 320 \
+ --num-encoder-layers 1,1,1,1,1,1 \
+ --feedforward-dim 192,192,192,192,192,192 \
+ --encoder-dim 128,128,128,128,128,128 \
+ --encoder-unmasked-dim 128,128,128,128,128,128 \
+ --test-set $t \
+ --keywords-score 1.0 \
+ --keywords-threshold 0.35 \
+ --keywords-file ./data/commands_${t}.txt \
+ --max-duration 3000
+ done
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 2: Export the finetuned model."
+
+ python ./zipformer/export.py \
+ --epoch 10 \
+ --avg 2 \
+ --exp-dir ./zipformer/exp_finetune \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 64 \
+ --decoder-dim 320 \
+ --joiner-dim 320 \
+ --num-encoder-layers 1,1,1,1,1,1 \
+ --feedforward-dim 192,192,192,192,192,192 \
+ --encoder-dim 128,128,128,128,128,128 \
+ --encoder-unmasked-dim 128,128,128,128,128,128
+
+ python ./zipformer/export_onnx_streaming.py \
+ --exp-dir zipformer/exp_finetune \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --epoch 10 \
+ --avg 2 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --decoder-dim 320 \
+ --joiner-dim 320 \
+ --num-encoder-layers 1,1,1,1,1,1 \
+ --feedforward-dim 192,192,192,192,192,192 \
+ --encoder-dim 128,128,128,128,128,128 \
+ --encoder-unmasked-dim 128,128,128,128,128,128 \
+ --causal 1
+fi
diff --git a/egs/gigaspeech/KWS/shared b/egs/gigaspeech/KWS/shared
new file mode 120000
index 000000000..4cbd91a7e
--- /dev/null
+++ b/egs/gigaspeech/KWS/shared
@@ -0,0 +1 @@
+../../../icefall/shared
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/asr_datamodule.py b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py
new file mode 100644
index 000000000..ccc602404
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py
@@ -0,0 +1,477 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2024 Xiaomi Corporation (Author: Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import glob
+import inspect
+import logging
+import re
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import lhotse
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import (
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SimpleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class GigaSpeechAsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+
+ This class should be derived for specific corpora used in ASR tasks.
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+ help="When enabled, select noise from MUSAN and mix it"
+ "with training dataset. ",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ # GigaSpeech specific arguments
+ group.add_argument(
+ "--subset",
+ type=str,
+ default="XL",
+ help="Select the GigaSpeech subset (XS|S|M|L|XL)",
+ )
+ group.add_argument(
+ "--small-dev",
+ type=str2bool,
+ default=False,
+ help="Should we use only 1000 utterances for dev (speeds up training)",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ logging.info("About to get Musan cuts")
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+ transforms.append(
+ CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # In different Lhotse's versions, the default of num_frame_masks is
+ # different.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=num_frame_masks,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ train = K2SpeechRecognitionDataset(
+ input_strategy=eval(self.args.input_strategy)(),
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info(f"About to get train {self.args.subset} cuts")
+ if self.args.subset == "XL":
+ filenames = glob.glob(
+ f"{self.args.manifest_dir}/XL_split/gigaspeech_cuts_XL.*.jsonl.gz"
+ )
+ pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz")
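+            # Sort the split manifests numerically by the index captured by
+            # the pattern (e.g. "gigaspeech_cuts_XL.12.jsonl.gz" -> 12), since
+            # glob() returns the files in arbitrary order.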
+ idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames)
+ idx_filenames = sorted(idx_filenames, key=lambda x: x[0])
+ sorted_filenames = [f[1] for f in idx_filenames]
+ logging.info(
+ f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode"
+ )
+
+ cuts_train = lhotse.combine(
+ lhotse.load_manifest_lazy(p) for p in sorted_filenames
+ )
+ else:
+ path = (
+ self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz"
+ )
+ cuts_train = CutSet.from_jsonl_lazy(path)
+ return cuts_train
+
+ @lru_cache()
+ def dev_cuts(self) -> CutSet:
+ logging.info("About to get dev cuts")
+ cuts_valid = load_manifest_lazy(
+ self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
+ )
+ if self.args.small_dev:
+ return cuts_valid.subset(first=1000)
+ else:
+ return cuts_valid
+
+ @lru_cache()
+ def test_cuts(self) -> CutSet:
+ logging.info("About to get test cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
+ )
+
+ @lru_cache()
+ def fsc_train_cuts(self) -> CutSet:
+ logging.info("About to get fluent speech commands train cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "fluent_speech_commands_cuts_train.jsonl.gz"
+ )
+
+ @lru_cache()
+ def fsc_valid_cuts(self) -> CutSet:
+ logging.info("About to get fluent speech commands valid cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "fluent_speech_commands_cuts_valid.jsonl.gz"
+ )
+
+ @lru_cache()
+ def fsc_test_small_cuts(self) -> CutSet:
+ logging.info("About to get fluent speech commands small test cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "fluent_speech_commands_cuts_small.jsonl.gz"
+ )
+
+ @lru_cache()
+ def fsc_test_large_cuts(self) -> CutSet:
+ logging.info("About to get fluent speech commands large test cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "fluent_speech_commands_cuts_large.jsonl.gz"
+ )
diff --git a/egs/gigaspeech/KWS/zipformer/beam_search.py b/egs/gigaspeech/KWS/zipformer/beam_search.py
new file mode 120000
index 000000000..e24eca39f
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/decode-asr.py b/egs/gigaspeech/KWS/zipformer/decode-asr.py
new file mode 100755
index 000000000..149b8bed0
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/decode-asr.py
@@ -0,0 +1,1066 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import GigaSpeechAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+ modified_beam_search_lm_rescore,
+ modified_beam_search_lm_rescore_LODR,
+ modified_beam_search_lm_shallow_fusion,
+ modified_beam_search_LODR,
+)
+from gigaspeech_scoring import asr_text_post_processing
+from train import add_model_arguments, get_model, get_params
+
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - modified_beam_search_LODR
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding-method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding-method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--use-shallow-fusion",
+ type=str2bool,
+ default=False,
+ help="""Use neural network LM for shallow fusion.
+ If you want to use LODR, you will also need to set this to true
+ """,
+ )
+
+ parser.add_argument(
+ "--lm-type",
+ type=str,
+ default="rnn",
+ help="Type of NN lm",
+ choices=["rnn", "transformer"],
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.3,
+ help="""The scale of the neural network LM
+ Used only when `--use-shallow-fusion` is set to True.
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens-ngram",
+ type=int,
+ default=2,
+ help="""The order of the ngram lm.
+ """,
+ )
+
+ parser.add_argument(
+ "--backoff-id",
+ type=int,
+ default=500,
+ help="ID of the backoff symbol in the ngram LM",
+ )
+
+ parser.add_argument(
+ "--context-score",
+ type=float,
+ default=2,
+ help="""
+ The bonus score of each token for the context biasing words/phrases.
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-file",
+ type=str,
+ default="",
+ help="""
+        The path of the context biasing lists, one word/phrase per line.
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+ add_model_arguments(parser)
+
+ return parser
+
+
+def post_processing(
+ results: List[Tuple[str, List[str], List[str]]],
+) -> List[Tuple[str, List[str], List[str]]]:
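+    # Normalize both the reference and the hypothesis with
+    # asr_text_post_processing() from gigaspeech_scoring.py before scoring, so
+    # that WER is computed on comparable text (the exact normalization is
+    # defined in that helper).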
+ new_results = []
+ for key, ref, hyp in results:
+ new_ref = asr_text_post_processing(" ".join(ref)).split()
+ new_hyp = asr_text_post_processing(" ".join(hyp)).split()
+ new_results.append((key, new_ref, new_hyp))
+ return new_results
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+    - key: It indicates the setting used for decoding. For example,
+      if greedy_search is used, it would be "greedy_search".
+      If beam search with a beam size of 7 is used, it would be
+      "beam_7".
+    - value: It contains the decoding result. `len(value)` equals the
+      batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ LM:
+ A neural network language model.
+ ngram_lm:
+        An n-gram language model.
+      ngram_lm_scale:
+        The scale for the n-gram language model.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
+ hyp_tokens = modified_beam_search_lm_shallow_fusion(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_LODR":
+ hyp_tokens = modified_beam_search_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LODR_lm=ngram_lm,
+ LODR_lm_scale=ngram_lm_scale,
+ LM=LM,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_rescore":
+ lm_scale_list = [0.01 * i for i in range(10, 50)]
+ ans_dict = modified_beam_search_lm_rescore(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ lm_scale_list=lm_scale_list,
+ )
+ elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ lm_scale_list = [0.02 * i for i in range(2, 30)]
+ ans_dict = modified_beam_search_lm_rescore_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ LODR_lm=ngram_lm,
+ sp=sp,
+ lm_scale_list=lm_scale_list,
+ )
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
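+    # The key of the returned dict encodes the decoding configuration, e.g.
+    # (hypothetical) "beam_20.0_max_contexts_8_max_states_64" for plain
+    # fast_beam_search.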
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ elif "modified_beam_search" in params.decoding_method:
+ prefix = f"beam_size_{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ ):
+ ans = dict()
+ assert ans_dict is not None
+ for key, hyps in ans_dict.items():
+ hyps = [sp.decode(hyp).split() for hyp in hyps]
+ ans[f"{prefix}_{key}"] = hyps
+ return ans
+ else:
+ if params.has_contexts:
+ prefix += f"-context-score-{params.context_score}"
+ return {prefix: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut id, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ word_table=word_table,
+ batch=batch,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = post_processing(results)
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ GigaSpeechAsrDataModule.add_arguments(parser)
+ LmScorer.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if os.path.exists(params.context_file):
+ params.has_contexts = True
+ else:
+ params.has_contexts = False
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ ):
+ if params.has_contexts:
+ params.suffix += f"-context-score-{params.context_score}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_shallow_fusion:
+ params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+ if "LODR" in params.decoding_method:
+ params.suffix += (
+ f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+ )
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ # only load the neural network LM if required
+ if params.use_shallow_fusion or params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_LODR",
+ ):
+ LM = LmScorer(
+ lm_type=params.lm_type,
+ params=params,
+ device=device,
+ lm_scale=params.lm_scale,
+ )
+ LM.to(device)
+ LM.eval()
+ else:
+ LM = None
+
+ # only load N-gram LM when needed
+ if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ try:
+ import kenlm
+ except ImportError:
+ print("Please install kenlm first. You can use")
+ print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
+ print("to install it")
+ import sys
+
+ sys.exit(-1)
+ ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
+ logging.info(f"lm filename: {ngram_file_name}")
+ ngram_lm = kenlm.Model(ngram_file_name)
+ ngram_lm_scale = None # use a list to search
+
+ elif params.decoding_method == "modified_beam_search_LODR":
+ lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+ logging.info(f"Loading token level lm: {lm_filename}")
+ ngram_lm = NgramLm(
+ str(params.lang_dir / lm_filename),
+ backoff_id=params.backoff_id,
+ is_binary=False,
+ )
+ logging.info(f"num states: {ngram_lm.lm.num_states}")
+ ngram_lm_scale = params.ngram_lm_scale
+ else:
+ ngram_lm = None
+ ngram_lm_scale = None
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
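+    # Each line of --context-file is one biasing word/phrase; the phrases are
+    # BPE-encoded and compiled into a ContextGraph that boosts matching paths
+    # (by --context-score per token) during modified beam search.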
+ if "modified_beam_search" in params.decoding_method:
+ if os.path.exists(params.context_file):
+ contexts = []
+ for line in open(params.context_file).readlines():
+ contexts.append(line.strip())
+ context_graph = ContextGraph(params.context_score)
+ context_graph.build(sp.encode(contexts))
+ else:
+ context_graph = None
+ else:
+ context_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ gigaspeech = GigaSpeechAsrDataModule(args)
+
+ test_cuts = gigaspeech.test_cuts()
+ test_dl = gigaspeech.test_dataloaders(test_cuts)
+
+ test_fsc_cuts = gigaspeech.fsc_test_large_cuts()
+ test_fsc_dl = gigaspeech.test_dataloaders(test_fsc_cuts)
+
+ test_sets = ["test", "fsc_test"]
+ test_dls = [test_dl, test_fsc_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/gigaspeech/KWS/zipformer/decode.py b/egs/gigaspeech/KWS/zipformer/decode.py
new file mode 100755
index 000000000..98b003937
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/decode.py
@@ -0,0 +1,689 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --keywords-file keywords.txt \
+ --beam-size 4
+"""
+
+import argparse
+import logging
+import math
+import os
+from collections import defaultdict
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Dict, List, Optional, Set, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import GigaSpeechAsrDataModule
+from beam_search import (
+ keywords_search,
+)
+from train import add_model_arguments, get_model, get_params
+
+from lhotse.cut import Cut
+from icefall import ContextGraph
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+@dataclass
+class KwMetric:
+ TP: int = 0 # True positive
+ FN: int = 0 # False negative
+ FP: int = 0 # False positive
+ TN: int = 0 # True negative
+ FN_list: List[str] = field(default_factory=list)
+ FP_list: List[str] = field(default_factory=list)
+ TP_list: List[str] = field(default_factory=list)
+
+ def __str__(self) -> str:
+ return f"(TP:{self.TP}, FN:{self.FN}, FP:{self.FP}, TN:{self.TN})"
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--keywords-file",
+ type=str,
+ help="File contains keywords.",
+ )
+
+ parser.add_argument(
+ "--test-set",
+ type=str,
+ default="small",
+ help="small or large",
+ )
+
+ parser.add_argument(
+ "--keywords-score",
+ type=float,
+ default=1.5,
+ help="""
+        The default boosting score (token level) for keywords. It boosts the
+        paths that match keywords so that they survive beam search.
+ """,
+ )
+
+ parser.add_argument(
+ "--keywords-threshold",
+ type=float,
+ default=0.35,
+ help="The default threshold (probability) to trigger the keyword.",
+ )
+
+ parser.add_argument(
+ "--num-tailing-blanks",
+ type=int,
+ default=1,
+ help="The number of tailing blanks should have after hitting one keyword.",
+ )
+
+ parser.add_argument(
+ "--blank-penalty",
+ type=float,
+ default=0.0,
+ help="""
+ The penalty applied on blank symbol during decoding.
+ Note: It is a positive value that would be applied to logits like
+ this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
+ [batch_size, vocab] and blank id is 0).
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ keywords_graph: Optional[ContextGraph] = None,
+) -> List[List[Tuple[str, Tuple[int, int]]]]:
+ """Decode one batch and return the result in a list.
+
+    The length of the list equals the batch size; the i-th element contains the
+    triggered keywords for the i-th utterance in the given batch. The triggered
+    keywords are themselves a list, each item being a tuple of the hit keyword
+    phrase and its corresponding start and end timestamps.
+
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ keywords_graph:
+ The graph containing keywords.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned list.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+
+ ans_dict = keywords_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ keywords_graph=keywords_graph,
+ beam=params.beam,
+ num_tailing_blanks=params.num_tailing_blanks,
+ blank_penalty=params.blank_penalty,
+ )
+
+ hyps = []
+ for ans in ans_dict:
+ hyp = []
+ for hit in ans:
+ hyp.append((hit.phrase, (hit.timestamps[0], hit.timestamps[-1])))
+ hyps.append(hyp)
+
+ return hyps
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ keywords_graph: ContextGraph,
+ keywords: Set[str],
+ test_only_keywords: bool,
+) -> Tuple[List[Tuple[str, List[str], List[str]]], Dict[str, KwMetric]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ keywords_graph:
+ The graph containing keywords.
+ Returns:
+      Return a tuple of two elements. The first is a list of tuples, each
+      containing the cut id, the reference transcript, and the predicted
+      result; the second is a dict mapping each keyword (plus the key "all")
+      to its KwMetric counts.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ log_interval = 50
+
+ results = []
+ metric = {"all": KwMetric()}
+ for k in keywords:
+ metric[k] = KwMetric()
+
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ keywords_graph=keywords_graph,
+ batch=batch,
+ )
+
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_text = ref_text.upper()
+ ref_words = ref_text.split()
+ hyp_words = [x[0] for x in hyp_words]
+ # for computing WER
+ this_batch.append((cut_id, ref_words, " ".join(hyp_words).split()))
+ hyp_set = set(hyp_words) # each item is a keyword phrase
+ if len(hyp_words) > 1:
+ logging.warning(
+ f"Cut {cut_id} triggers more than one keywords : {hyp_words},"
+ f"please check the transcript to see if it really has more "
+ f"than one keywords, if so consider splitting this audio and"
+ f"keep only one keyword for each audio."
+ )
+ hyp_str = " | ".join(
+ hyp_words
+ ) # The triggered keywords for this utterance.
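+            # Per-utterance bookkeeping for the aggregate "all" entry: a true
+            # positive if at least one triggered keyword really occurs in the
+            # reference, a false positive if some triggered keyword does not,
+            # and, further below, a false negative if a reference keyword was
+            # missed and a true negative if no keyword occurs in either the
+            # reference or the hypothesis.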
+ TP = False
+ FP = False
+ for x in hyp_set:
+ assert x in keywords, x # can only trigger keywords
+ if (test_only_keywords and x == ref_text) or (
+ not test_only_keywords and x in ref_text
+ ):
+ TP = True
+ metric[x].TP += 1
+ metric[x].TP_list.append(f"({ref_text} -> {x})")
+ if (test_only_keywords and x != ref_text) or (
+ not test_only_keywords and x not in ref_text
+ ):
+ FP = True
+ metric[x].FP += 1
+ metric[x].FP_list.append(f"({ref_text} -> {x})")
+ if TP:
+ metric["all"].TP += 1
+ if FP:
+ metric["all"].FP += 1
+            TN = True  # if every keyword is a true negative, the overall summary is a true negative.
+ FN = False
+ for x in keywords:
+ if x not in ref_text and x not in hyp_set:
+ metric[x].TN += 1
+ continue
+
+ TN = False
+ if (test_only_keywords and x == ref_text) or (
+ not test_only_keywords and x in ref_text
+ ):
+ fn = True
+ for y in hyp_set:
+ if (test_only_keywords and y == ref_text) or (
+ not test_only_keywords and y in ref_text
+ ):
+ fn = False
+ break
+ if fn:
+ FN = True
+ metric[x].FN += 1
+ metric[x].FN_list.append(f"({ref_text} -> {hyp_str})")
+ if TN:
+ metric["all"].TN += 1
+ if FN:
+ metric["all"].FN += 1
+
+ results.extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results, metric
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results: List[Tuple[str, List[str], List[str]]],
+    metric: Dict[str, KwMetric],
+):
+ recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ metric_filename = params.res_dir / f"metric-{test_set_name}-{params.suffix}.txt"
+
+ with open(metric_filename, "w") as of:
+ width = 10
+ for key, item in sorted(
+ metric.items(), key=lambda x: (x[1].FP, x[1].FN), reverse=True
+ ):
+ acc = (item.TP + item.TN) / (item.TP + item.TN + item.FP + item.FN)
+ precision = (
+ 0.0 if (item.TP + item.FP) == 0 else item.TP / (item.TP + item.FP)
+ )
+ recall = 0.0 if (item.TP + item.FN) == 0 else item.TP / (item.TP + item.FN)
+ fpr = 0.0 if (item.FP + item.TN) == 0 else item.FP / (item.FP + item.TN)
+ s = f"{key}:\n"
+ s += f"\t{'TP':{width}}{'FP':{width}}{'FN':{width}}{'TN':{width}}\n"
+ s += f"\t{str(item.TP):{width}}{str(item.FP):{width}}{str(item.FN):{width}}{str(item.TN):{width}}\n"
+ s += f"\tAccuracy: {acc:.3f}\n"
+ s += f"\tPrecision: {precision:.3f}\n"
+ s += f"\tRecall(PPR): {recall:.3f}\n"
+ s += f"\tFPR: {fpr:.3f}\n"
+ s += f"\tF1: {0.0 if precision * recall == 0 else 2 * precision * recall / (precision + recall):.3f}\n"
+ if key != "all":
+ s += f"\tTP list: {' # '.join(item.TP_list)}\n"
+ s += f"\tFP list: {' # '.join(item.FP_list)}\n"
+ s += f"\tFN list: {' # '.join(item.FN_list)}\n"
+ of.write(s + "\n")
+ if key == "all":
+ logging.info(s)
+ of.write(f"\n\n{params.keywords_config}")
+
+ logging.info("Wrote metric stats to {}".format(metric_filename))
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ GigaSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ params.res_dir = params.exp_dir / "kws"
+
+ params.suffix = params.test_set
+ if params.iter > 0:
+ params.suffix += f"-iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix += f"-epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ params.suffix += f"-score-{params.keywords_score}"
+ params.suffix += f"-threshold-{params.keywords_threshold}"
+ params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}"
+ if params.blank_penalty != 0:
+ params.suffix += f"-blank-penalty-{params.blank_penalty}"
+ params.suffix += f"-keywords-{params.keywords_file.split('/')[-1]}"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
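+    # Each line of --keywords-file holds one keyword phrase, optionally followed
+    # by a per-keyword boosting score token ":<float>" and a trigger-threshold
+    # token "#<float>", e.g. (hypothetical):
+    #   TURN ON THE LIGHT :3.0 #0.4
+    # Tokens starting with ":" or "#" are stripped from the phrase below.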
+ phrases = []
+ token_ids = []
+ keywords_scores = []
+ keywords_thresholds = []
+ keywords_config = []
+ with open(params.keywords_file, "r") as f:
+ for line in f.readlines():
+ keywords_config.append(line)
+ score = 0
+ threshold = 0
+ keyword = []
+ words = line.strip().upper().split()
+ for word in words:
+ word = word.strip()
+ if word[0] == ":":
+ score = float(word[1:])
+ continue
+ if word[0] == "#":
+ threshold = float(word[1:])
+ continue
+ keyword.append(word)
+ keyword = " ".join(keyword)
+ phrases.append(keyword)
+ token_ids.append(sp.encode(keyword))
+ keywords_scores.append(score)
+ keywords_thresholds.append(threshold)
+
+ params.keywords_config = "".join(keywords_config)
+
+ keywords_graph = ContextGraph(
+ context_score=params.keywords_score, ac_threshold=params.keywords_threshold
+ )
+ keywords_graph.build(
+ token_ids=token_ids,
+ phrases=phrases,
+ scores=keywords_scores,
+ ac_thresholds=keywords_thresholds,
+ )
+ keywords = set(phrases)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ gigaspeech = GigaSpeechAsrDataModule(args)
+
+ test_cuts = gigaspeech.test_cuts()
+ test_dl = gigaspeech.test_dataloaders(test_cuts)
+
+ if params.test_set == "small":
+ test_fsc_small_cuts = gigaspeech.fsc_test_small_cuts()
+ test_fsc_small_dl = gigaspeech.test_dataloaders(test_fsc_small_cuts)
+ test_sets = ["small-fsc", "test"]
+ test_dls = [test_fsc_small_dl, test_dl]
+ else:
+ assert params.test_set == "large", params.test_set
+ test_fsc_large_cuts = gigaspeech.fsc_test_large_cuts()
+ test_fsc_large_dl = gigaspeech.test_dataloaders(test_fsc_large_cuts)
+ test_sets = ["large-fsc", "test"]
+ test_dls = [test_fsc_large_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results, metric = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ keywords_graph=keywords_graph,
+ keywords=keywords,
+ test_only_keywords="fsc" in test_set,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results=results,
+ metric=metric,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/gigaspeech/KWS/zipformer/decoder.py b/egs/gigaspeech/KWS/zipformer/decoder.py
new file mode 120000
index 000000000..5a8018680
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/decoder.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/encoder_interface.py b/egs/gigaspeech/KWS/zipformer/encoder_interface.py
new file mode 120000
index 000000000..653c5b09a
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/transducer_stateless/encoder_interface.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/export-onnx-streaming.py b/egs/gigaspeech/KWS/zipformer/export-onnx-streaming.py
new file mode 120000
index 000000000..2962eb784
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/export-onnx-streaming.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx-streaming.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/export.py b/egs/gigaspeech/KWS/zipformer/export.py
new file mode 120000
index 000000000..dfc1bec08
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/finetune.py b/egs/gigaspeech/KWS/zipformer/finetune.py
new file mode 100755
index 000000000..b8e8802cb
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/finetune.py
@@ -0,0 +1,644 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Yifan Yang,
+# Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+# For non-streaming model training:
+./zipformer/finetune.py \
+ --world-size 8 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --max-duration 1000
+
+# For streaming model training:
+./zipformer/finetune.py \
+ --world-size 8 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --causal 1 \
+ --max-duration 1000
+
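+# An illustrative fine-tuning invocation (the checkpoint path and the module
+# prefixes below are placeholders, not part of the recipe):
+./zipformer/finetune.py \
+  --world-size 8 \
+  --num-epochs 10 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir zipformer/exp_finetune \
+  --finetune-ckpt path/to/pretrained/epoch-30.pt \
+  --init-modules "encoder" \
+  --use-mux 0 \
+  --max-duration 1000
+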
+It supports training with:
+ - transducer loss (default), with `--use-transducer True --use-ctc False`
+ - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
+ - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import GigaSpeechAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut, CutSet
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+from train import (
+ add_model_arguments,
+ add_training_arguments,
+ compute_loss,
+ compute_validation_loss,
+ display_and_save_batch,
+ get_adjusted_batch_count,
+ get_model,
+ get_params,
+ load_checkpoint_if_available,
+ save_checkpoint,
+ scan_pessimistic_batches_for_oom,
+ set_batch_count,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def add_finetune_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--use-mux",
+ type=str2bool,
+ default=False,
+ help="""
+ Whether to adapt. If true, we will mix 5% of the new data
+ with 95% of the original data to fine-tune.
+ """,
+ )
+
+ parser.add_argument(
+ "--init-modules",
+ type=str,
+ default=None,
+ help="""
+ Modules to be initialized. It matches all parameters starting with
+ a specific key. The keys are given with Comma seperated. If None,
+ all modules will be initialised. For example, if you only want to
+ initialise all parameters staring with "encoder", use "encoder";
+ if you want to initialise parameters starting with encoder or decoder,
+ use "encoder,joiner".
+ """,
+ )
+
+ parser.add_argument(
+ "--finetune-ckpt",
+ type=str,
+ default=None,
+ help="Fine-tuning from which checkpoint (a path to a .pt file)",
+ )
+
+ parser.add_argument(
+ "--continue-finetune",
+ type=str2bool,
+ default=False,
+ help="Continue finetuning or finetune from pre-trained model",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ add_training_arguments(parser)
+ add_model_arguments(parser)
+ add_finetune_arguments(parser)
+
+ return parser
+
+
+def load_model_params(
+    ckpt: str,
+    model: nn.Module,
+    init_modules: Optional[List[str]] = None,
+    strict: bool = True,
+):
+ """Load model params from checkpoint
+
+ Args:
+ ckpt (str): Path to the checkpoint
+ model (nn.Module): model to be loaded
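+      init_modules (List[str], optional): Prefixes of the sub-modules whose
+        parameters are initialised from the checkpoint; if None, the whole
+        model is loaded.
+      strict (bool): Passed to ``model.load_state_dict``.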
+
+ """
+ logging.info(f"Loading checkpoint from {ckpt}")
+ checkpoint = torch.load(ckpt, map_location="cpu")
+
+ # if module list is empty, load the whole model from ckpt
+ if not init_modules:
+ if next(iter(checkpoint["model"])).startswith("module."):
+ logging.info("Loading checkpoint saved by DDP")
+
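+            # DDP checkpoints store parameters as "module.<name>"; remap them
+            # to the bare names expected by the unwrapped model.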
+ dst_state_dict = model.state_dict()
+ src_state_dict = checkpoint["model"]
+ for key in dst_state_dict.keys():
+ src_key = "{}.{}".format("module", key)
+ dst_state_dict[key] = src_state_dict.pop(src_key)
+ assert len(src_state_dict) == 0
+ model.load_state_dict(dst_state_dict, strict=strict)
+ else:
+ model.load_state_dict(checkpoint["model"], strict=strict)
+ else:
+ src_state_dict = checkpoint["model"]
+ dst_state_dict = model.state_dict()
+ for module in init_modules:
+ logging.info(f"Loading parameters starting with prefix {module}")
+ src_keys = [
+ k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
+ ]
+ dst_keys = [
+ k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
+ ]
+ assert set(src_keys) == set(dst_keys) # two sets should match exactly
+ for key in src_keys:
+ dst_state_dict[key] = src_state_dict.pop(key)
+
+ model.load_state_dict(dst_state_dict, strict=strict)
+
+ return None
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+        The learning rate scheduler; we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
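+            # The +100000 offset pushes batch-count-dependent schedules inside
+            # the model (e.g. ScheduledFloat values) into their fully
+            # warmed-up regime, which is presumably what we want when
+            # fine-tuning from an already-converged model.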
+ set_batch_count(model, get_adjusted_batch_count(params) + 100000)
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+
+ # if params.continue_finetune:
+ # set_batch_count(model, params.batch_idx_train)
+ # else:
+ # set_batch_count(model, params.batch_idx_train + 100000)
+
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+
+ if params.continue_finetune:
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+ else:
+ modules = params.init_modules.split(",") if params.init_modules else None
+ checkpoints = load_model_params(
+ ckpt=params.finetune_ckpt, model=model, init_modules=modules
+ )
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=1.0)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ def remove_short_utt(c: Cut):
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
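+        # e.g. num_frames = 7 gives T = 0 (the cut is dropped), while
+        # num_frames = 9 gives T = 1 (the cut is kept).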
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ return T > 0
+
+ gigaspeech = GigaSpeechAsrDataModule(args)
+
+ if params.use_mux:
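+        # CutSet.mux interleaves cuts from both sources; with
+        # weights=[0.9, 0.1] roughly 90% of the sampled cuts come from the
+        # original GigaSpeech training set and 10% from the FSC set.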
+ train_cuts = CutSet.mux(
+ gigaspeech.train_cuts(),
+ gigaspeech.fsc_train_cuts(),
+ weights=[0.9, 0.1],
+ )
+ else:
+ train_cuts = gigaspeech.fsc_train_cuts()
+
+ train_cuts = train_cuts.filter(remove_short_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = gigaspeech.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = gigaspeech.fsc_valid_cuts()
+ valid_cuts = valid_cuts.filter(remove_short_utt)
+ valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics and params.scan_for_oom_batches:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def main():
+ parser = get_parser()
+ GigaSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+if __name__ == "__main__":
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ main()
diff --git a/egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py b/egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py
new file mode 120000
index 000000000..4ee54fff5
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py
@@ -0,0 +1 @@
+../../ASR/zipformer/gigaspeech_scoring.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/joiner.py b/egs/gigaspeech/KWS/zipformer/joiner.py
new file mode 120000
index 000000000..5b8a36332
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/joiner.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/model.py b/egs/gigaspeech/KWS/zipformer/model.py
new file mode 120000
index 000000000..cd7e07d72
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/model.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/optim.py b/egs/gigaspeech/KWS/zipformer/optim.py
new file mode 120000
index 000000000..5eaa3cffd
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/scaling.py b/egs/gigaspeech/KWS/zipformer/scaling.py
new file mode 120000
index 000000000..6f398f431
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/subsampling.py b/egs/gigaspeech/KWS/zipformer/subsampling.py
new file mode 120000
index 000000000..01ae9002c
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/subsampling.py
\ No newline at end of file
diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py
new file mode 100755
index 000000000..e7387dd39
--- /dev/null
+++ b/egs/gigaspeech/KWS/zipformer/train.py
@@ -0,0 +1,1367 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Yifan Yang,
+# Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+# For non-streaming model training:
+./zipformer/train.py \
+ --world-size 8 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --max-duration 1000
+
+# For streaming model training:
+./zipformer/train.py \
+ --world-size 8 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --causal 1 \
+ --max-duration 1000
+
+It supports training with:
+ - transducer loss (default), with `--use-transducer True --use-ctc False`
+ - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
+ - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import GigaSpeechAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
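+    # Example with hypothetical settings: max_duration=1000, world_size=8 and
+    # ref_duration=600 give 300 * (1000 * 8) / 600 = 4000 adjusted batches
+    # after 300 real batches.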
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="1,1,1,1,1,1",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="192,192,192,192,192,192",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="128,128,128,128,128,128",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="128,128,128,128,128,128",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=320,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=320,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--causal",
+ type=str2bool,
+ default=True,
+ help="If True, use causal version of model.",
+ )
+
+ parser.add_argument(
+ "--chunk-size",
+ type=str,
+ default="16,32,64,-1",
+ help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+ " Must be just -1 if --causal=False",
+ )
+
+ parser.add_argument(
+ "--left-context-frames",
+ type=str,
+ default="64,128,256,-1",
+ help="Maximum left-contexts for causal training, measured in frames which will "
+ "be converted to a number of chunks. If splitting into chunks, "
+ "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+ )
+
+ parser.add_argument(
+ "--use-transducer",
+ type=str2bool,
+ default=True,
+ help="If True, use Transducer head.",
+ )
+
+ parser.add_argument(
+ "--use-ctc",
+ type=str2bool,
+ default=False,
+ help="If True, use CTC head.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ add_training_arguments(parser)
+ add_model_arguments(parser)
+
+ return parser
+
+
+def add_training_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.045, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=1,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)" "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--scan-for-oom-batches",
+ type=str2bool,
+ default=False,
+ help="""
+        Whether to scan for OOM batches before training. This is helpful for
+        finding a suitable max_duration; you only need to run it once.
+        Caution: it is somewhat time consuming.
+ """,
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=8000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=True,
+ help="Whether to use half precision training.",
+ )
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used to write statistics to tensorboard. It
+                       contains the number of batches trained so far across
+ epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 500,
+ "reset_interval": 2000,
+ "valid_interval": 20000,
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
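+    # e.g. "16,32,64,-1" -> (16, 32, 64, -1)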
+ return tuple(map(int, s.split(",")))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+ # encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, (T - 7) // 2, encoder_dims).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> (T - 7) // 2
+ # (2) embedding: num_features -> encoder_dims
+ # In the normal configuration, we will downsample once more at the end
+ # by a factor of 2, and most of the encoder stacks will run at a lower
+ # sampling rate.
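+    # The ScheduledFloat below decays the dropout rate from 0.3 at batch 0 to
+    # 0.1 by batch 20000 (piecewise-linear in the module's batch_count).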
+ encoder_embed = Conv2dSubsampling(
+ in_channels=params.feature_dim,
+ out_channels=_to_int_tuple(params.encoder_dim)[0],
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ )
+ return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Zipformer2(
+ output_downsampling_factor=2,
+ downsampling_factor=_to_int_tuple(params.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+ encoder_dim=_to_int_tuple(params.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(params.query_head_dim),
+ pos_head_dim=_to_int_tuple(params.pos_head_dim),
+ value_head_dim=_to_int_tuple(params.value_head_dim),
+ pos_dim=params.pos_dim,
+ num_heads=_to_int_tuple(params.num_heads),
+ feedforward_dim=_to_int_tuple(params.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ causal=params.causal,
+ chunk_size=_to_int_tuple(params.chunk_size),
+ left_context_frames=_to_int_tuple(params.left_context_frames),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
+ f"but got params.use_transducer={params.use_transducer}, "
+ f"params.use_ctc={params.use_ctc}"
+ )
+
+ encoder_embed = get_encoder_embed(params)
+ encoder = get_encoder_model(params)
+
+ if params.use_transducer:
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+ else:
+ decoder = None
+ joiner = None
+
+ model = AsrModel(
+ encoder_embed=encoder_embed,
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ use_transducer=params.use_transducer,
+ use_ctc=params.use_ctc,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+    model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ warmup: a floating point value which increases throughout training;
+ values >= 1.0 are fully warmed up and have all modules present.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, ctc_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+            # to params.simple_loss_scale by warm_step.
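+            # e.g. with warm_step=2000 the scales start at 1.0 (simple) and
+            # 0.1 (pruned) at batch 0 and reach s and 1.0 at batch 2000.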
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+        The learning rate scheduler; we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ #