From 238b45bea85deee1a07cfd0f55b485cc92f67135 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Thu, 23 Nov 2023 01:22:57 +0800 Subject: [PATCH 01/46] Libriheavy recipe (zipformer) (#1261) * initial commit for libriheavy * Data prepare pipeline * Fix train.py * Fix decode.py * Add results * minor fixes * black * black * Incorporate PR https://github.com/k2-fsa/icefall/pull/1269 --------- Co-authored-by: zr_jin --- egs/libriheavy/ASR/README.md | 6 + egs/libriheavy/ASR/RESULTS.md | 114 +- .../ASR/local/compute_fbank_libriheavy.py | 242 +++ .../ASR/local/compute_fbank_musan.py | 1 + egs/libriheavy/ASR/local/norm_text.py | 58 + egs/libriheavy/ASR/local/prepare_manifest.py | 47 + egs/libriheavy/ASR/local/train_bpe_model.py | 113 ++ egs/libriheavy/ASR/prepare.sh | 314 ++++ .../ASR/zipformer/asr_datamodule.py | 443 ++++++ egs/libriheavy/ASR/zipformer/beam_search.py | 1 + egs/libriheavy/ASR/zipformer/decode.py | 794 +++++++++ egs/libriheavy/ASR/zipformer/decoder.py | 1 + .../ASR/zipformer/encoder_interface.py | 1 + egs/libriheavy/ASR/zipformer/export-onnx.py | 1 + egs/libriheavy/ASR/zipformer/export.py | 1 + .../ASR/zipformer/jit_pretrained.py | 1 + egs/libriheavy/ASR/zipformer/joiner.py | 1 + egs/libriheavy/ASR/zipformer/model.py | 1 + egs/libriheavy/ASR/zipformer/onnx_decode.py | 1 + .../ASR/zipformer/onnx_pretrained.py | 1 + egs/libriheavy/ASR/zipformer/optim.py | 1 + egs/libriheavy/ASR/zipformer/pretrained.py | 1 + egs/libriheavy/ASR/zipformer/scaling.py | 1 + .../ASR/zipformer/scaling_coverter.py | 1 + egs/libriheavy/ASR/zipformer/subsampling.py | 1 + .../ASR/zipformer/text_normalization.py | 50 + egs/libriheavy/ASR/zipformer/train.py | 1415 +++++++++++++++++ egs/libriheavy/ASR/zipformer/zipformer.py | 1 + requirements-ci.txt | 1 + requirements.txt | 1 + 30 files changed, 3613 insertions(+), 2 deletions(-) create mode 100644 egs/libriheavy/ASR/README.md create mode 100755 egs/libriheavy/ASR/local/compute_fbank_libriheavy.py create mode 120000 egs/libriheavy/ASR/local/compute_fbank_musan.py create mode 100755 egs/libriheavy/ASR/local/norm_text.py create mode 100755 egs/libriheavy/ASR/local/prepare_manifest.py create mode 100755 egs/libriheavy/ASR/local/train_bpe_model.py create mode 100755 egs/libriheavy/ASR/prepare.sh create mode 100644 egs/libriheavy/ASR/zipformer/asr_datamodule.py create mode 120000 egs/libriheavy/ASR/zipformer/beam_search.py create mode 100644 egs/libriheavy/ASR/zipformer/decode.py create mode 120000 egs/libriheavy/ASR/zipformer/decoder.py create mode 120000 egs/libriheavy/ASR/zipformer/encoder_interface.py create mode 120000 egs/libriheavy/ASR/zipformer/export-onnx.py create mode 120000 egs/libriheavy/ASR/zipformer/export.py create mode 120000 egs/libriheavy/ASR/zipformer/jit_pretrained.py create mode 120000 egs/libriheavy/ASR/zipformer/joiner.py create mode 120000 egs/libriheavy/ASR/zipformer/model.py create mode 120000 egs/libriheavy/ASR/zipformer/onnx_decode.py create mode 120000 egs/libriheavy/ASR/zipformer/onnx_pretrained.py create mode 120000 egs/libriheavy/ASR/zipformer/optim.py create mode 120000 egs/libriheavy/ASR/zipformer/pretrained.py create mode 120000 egs/libriheavy/ASR/zipformer/scaling.py create mode 120000 egs/libriheavy/ASR/zipformer/scaling_coverter.py create mode 120000 egs/libriheavy/ASR/zipformer/subsampling.py create mode 100644 egs/libriheavy/ASR/zipformer/text_normalization.py create mode 100644 egs/libriheavy/ASR/zipformer/train.py create mode 120000 egs/libriheavy/ASR/zipformer/zipformer.py diff --git a/egs/libriheavy/ASR/README.md 
b/egs/libriheavy/ASR/README.md new file mode 100644 index 000000000..2498d017f --- /dev/null +++ b/egs/libriheavy/ASR/README.md @@ -0,0 +1,6 @@ +# Libriheavy: a 50,000 hours ASR corpus with punctuation casing and context + +Libriheavy is a labeled version of [Librilight](https://arxiv.org/pdf/1912.07875.pdf). Please refer to our repository [k2-fsa/libriheavy](https://github.com/k2-fsa/libriheavy) for more details. We also have a paper: *Libriheavy: a 50,000 hours ASR corpus with punctuation casing and context*, [Preprint available on arxiv](https://arxiv.org/abs/2309.08105). + + +See [RESULTS](./RESULTS.md) for the results for icefall recipes. diff --git a/egs/libriheavy/ASR/RESULTS.md b/egs/libriheavy/ASR/RESULTS.md index 4fbedad98..513bbf72e 100644 --- a/egs/libriheavy/ASR/RESULTS.md +++ b/egs/libriheavy/ASR/RESULTS.md @@ -1,6 +1,116 @@ -## Results +# Results -### Zipformer PromptASR (zipformer + PromptASR + BERT text encoder) +## zipformer (zipformer + pruned stateless transducer) + +See for more details. + +[zipformer](./zipformer) + +### Non-streaming + +#### Training on normalized text, i.e. Upper case without punctuation + +##### normal-scaled model, number of model parameters: 65805511, i.e., 65.81 M + +You can find a pretrained model, training logs at: + + +Note: The repository above contains three models trained on different subset of libriheavy exp(large set), exp_medium_subset(medium set), +exp_small_subset(small set). + +Results of models: + +| training set | decoding method | librispeech clean | librispeech other | libriheavy clean | libriheavy other | comment | +|---------------|---------------------|-------------------|-------------------|------------------|------------------|--------------------| +| small | greedy search | 4.19 | 9.99 | 4.75 | 10.25 |--epoch 90 --avg 20 | +| small | modified beam search| 4.05 | 9.89 | 4.68 | 10.01 |--epoch 90 --avg 20 | +| medium | greedy search | 2.39 | 4.85 | 2.90 | 6.6 |--epoch 60 --avg 20 | +| medium | modified beam search| 2.35 | 4.82 | 2.90 | 6.57 |--epoch 60 --avg 20 | +| large | greedy search | 1.67 | 3.32 | 2.24 | 5.61 |--epoch 16 --avg 3 | +| large | modified beam search| 1.62 | 3.36 | 2.20 | 5.57 |--epoch 16 --avg 3 | + +The training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +python ./zipformer/train.py \ + --world-size 4 \ + --master-port 12365 \ + --exp-dir zipformer/exp \ + --num-epochs 60 \ # 16 for large; 90 for small + --lr-hours 15000 \ # 20000 for large; 5000 for small + --use-fp16 1 \ + --start-epoch 1 \ + --bpe-model data/lang_bpe_500/bpe.model \ + --max-duration 1000 \ + --subset medium +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in greedy_search modified_beam_search; do + ./zipformer/decode.py \ + --epoch 16 \ + --avg 3 \ + --exp-dir zipformer/exp \ + --max-duration 1000 \ + --causal 0 \ + --decoding-method $m +done +``` + +#### Training on full formatted text, i.e. with casing and punctuation + +##### normal-scaled model, number of model parameters: 66074067 , i.e., 66M + +You can find a pretrained model, training logs at: + + +Note: The repository above contains three models trained on different subset of libriheavy exp(large set), exp_medium_subset(medium set), +exp_small_subset(small set). 
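+
+This model produces cased and punctuated output, so both WER and CER are
+reported below. You can additionally score it against normalized (upper-case,
+punctuation-free) references by passing `--post-normalization 1
+--train-with-punctuation 1` to `./zipformer/decode.py`, which re-normalizes
+both the references and the hypotheses with the `remove_punc_to_upper` helper
+from `zipformer/text_normalization.py`. A minimal sketch of what that
+normalization does (the example string is made up; the expected output assumes
+the helper matches the implementation shown in `local/norm_text.py`):
+
+```python
+# Run from egs/libriheavy/ASR/zipformer so the local module is importable.
+from text_normalization import remove_punc_to_upper
+
+print(remove_punc_to_upper("“Hello, world!” he said."))
+# -> HELLO WORLD HE SAID
+```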
+ +Results of models: + +| training set | decoding method | libriheavy clean (WER) | libriheavy other (WER) | libriheavy clean (CER) | libriheavy other (CER) | comment | +|---------------|---------------------|-------------------|-------------------|------------------|------------------|--------------------| +| small | modified beam search| 13.04 | 19.54 | 4.51 | 7.90 |--epoch 88 --avg 41 | +| medium | modified beam search| 9.84 | 13.39 | 3.02 | 5.10 |--epoch 50 --avg 15 | +| large | modified beam search| 7.76 | 11.32 | 2.41 | 4.22 |--epoch 16 --avg 2 | + +The training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +python ./zipformer/train.py \ + --world-size 4 \ + --master-port 12365 \ + --exp-dir zipformer/exp \ + --num-epochs 60 \ # 16 for large; 90 for small + --lr-hours 15000 \ # 20000 for large; 10000 for small + --use-fp16 1 \ + --train-with-punctuation 1 \ + --start-epoch 1 \ + --bpe-model data/lang_punc_bpe_756/bpe.model \ + --max-duration 1000 \ + --subset medium +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in greedy_search modified_beam_search; do + ./zipformer/decode.py \ + --epoch 16 \ + --avg 3 \ + --exp-dir zipformer/exp \ + --max-duration 1000 \ + --causal 0 \ + --decoding-method $m +done +``` + +## Zipformer PromptASR (zipformer + PromptASR + BERT text encoder) #### [zipformer_prompt_asr](./zipformer_prompt_asr) diff --git a/egs/libriheavy/ASR/local/compute_fbank_libriheavy.py b/egs/libriheavy/ASR/local/compute_fbank_libriheavy.py new file mode 100755 index 000000000..010531db2 --- /dev/null +++ b/egs/libriheavy/ASR/local/compute_fbank_libriheavy.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the Libriheavy dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import torch +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + KaldifeatFbank, + KaldifeatFbankConfig, + LilcomChunkyWriter, +) + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--manifest-dir", + type=str, + help="""The source directory that contains raw manifests. 
+ """, + default="data/manifests", + ) + + parser.add_argument( + "--fbank-dir", + type=str, + help="""Fbank output dir + """, + default="data/fbank", + ) + + parser.add_argument( + "--subset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--num-workers", + type=int, + default=20, + help="Number of dataloading workers used for reading the audio.", + ) + + parser.add_argument( + "--batch-duration", + type=float, + default=600.0, + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Whether to use speed perturbation.", + ) + + parser.add_argument( + "--use-splits", + type=str2bool, + default=False, + help="Whether to compute fbank on splits.", + ) + + parser.add_argument( + "--num-splits", + type=int, + help="""The number of splits of the medium and large subset. + Only needed when --use-splits is true.""", + ) + + parser.add_argument( + "--start", + type=int, + default=0, + help="""Process pieces starting from this number (inclusive). + Only needed when --use-splits is true.""", + ) + + parser.add_argument( + "--stop", + type=int, + default=-1, + help="""Stop processing pieces until this number (exclusive). + Only needed when --use-splits is true.""", + ) + + return parser.parse_args() + + +def compute_fbank_libriheavy(args): + src_dir = Path(args.manifest_dir) + output_dir = Path(args.fbank_dir) + num_jobs = min(15, os.cpu_count()) + num_mel_bins = 80 + subset = args.subset + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + output_cuts_path = output_dir / f"libriheavy_cuts_{subset}.jsonl.gz" + if output_cuts_path.exists(): + logging.info(f"{output_cuts_path} exists - skipping") + return + + input_cuts_path = src_dir / f"libriheavy_cuts_{subset}.jsonl.gz" + assert input_cuts_path.exists(), f"{input_cuts_path} does not exist!" 
+ logging.info(f"Loading {input_cuts_path}") + cut_set = CutSet.from_file(input_cuts_path) + + logging.info("Computing features") + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/libriheavy_feats_{subset}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + + logging.info(f"Saving to {output_cuts_path}") + cut_set.to_file(output_cuts_path) + + +def compute_fbank_libriheavy_splits(args): + num_splits = args.num_splits + subset = args.subset + src_dir = f"{args.manifest_dir}/libriheavy_{subset}_split" + src_dir = Path(src_dir) + output_dir = f"{args.fbank_dir}/libriheavy_{subset}_split" + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + start = args.start + stop = args.stop + if stop < start: + stop = num_splits + + stop = min(stop, num_splits) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + logging.info(f"device: {device}") + + num_digits = 8 # num_digits is fixed by lhotse split-lazy + for i in range(start, stop): + idx = f"{i + 1}".zfill(num_digits) + logging.info(f"Processing {idx}/{num_splits}") + + cuts_path = output_dir / f"libriheavy_cuts_{subset}.{idx}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{cuts_path} exists - skipping") + continue + + raw_cuts_path = src_dir / f"libriheavy_cuts_{subset}.{idx}.jsonl.gz" + if not raw_cuts_path.is_file(): + logging.info(f"{raw_cuts_path} does not exist - skipping it") + continue + + logging.info(f"Loading {raw_cuts_path}") + cut_set = CutSet.from_file(raw_cuts_path) + + logging.info("Computing features") + if (output_dir / f"libriheavy_feats_{subset}_{idx}.lca").exists(): + logging.info(f"Removing {output_dir}/libriheavy_feats_{subset}_{idx}.lca") + os.remove(output_dir / f"libriheavy_feats_{subset}_{idx}.lca") + + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{output_dir}/libriheavy_feats_{subset}_{idx}", + num_workers=args.num_workers, + batch_duration=args.batch_duration, + overwrite=True, + ) + + logging.info("About to split cuts into smaller chunks.") + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, min_duration=None + ) + + logging.info(f"Saving to {cuts_path}") + cut_set.to_file(cuts_path) + logging.info(f"Saved to {cuts_path}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + + if args.use_splits: + assert args.num_splits is not None, "Please provide num_splits" + compute_fbank_libriheavy_splits(args) + else: + compute_fbank_libriheavy(args) diff --git a/egs/libriheavy/ASR/local/compute_fbank_musan.py b/egs/libriheavy/ASR/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/libriheavy/ASR/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/local/norm_text.py b/egs/libriheavy/ASR/local/norm_text.py new file mode 100755 index 000000000..c2fc0d92d --- /dev/null +++ b/egs/libriheavy/ASR/local/norm_text.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. 
(authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import codecs +import sys + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--text", + type=str, + help="""Path to the input text. + """, + ) + return parser.parse_args() + + +def remove_punc_to_upper(text: str) -> str: + text = text.replace("‘", "'") + text = text.replace("’", "'") + tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") + s_list = [x.upper() if x in tokens else " " for x in text] + s = " ".join("".join(s_list).split()).strip() + return s + + +def main(): + args = get_args() + if args.text: + f = codecs.open(args.text, encoding="utf-8") + else: + f = codecs.getreader("utf-8")(sys.stdin.buffer) + + sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer) + line = f.readline() + while line: + print(remove_punc_to_upper(line)) + line = f.readline() + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/local/prepare_manifest.py b/egs/libriheavy/ASR/local/prepare_manifest.py new file mode 100755 index 000000000..42f392cae --- /dev/null +++ b/egs/libriheavy/ASR/local/prepare_manifest.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gzip +import json +import sys +from pathlib import Path + + +def simple_cleanup(text: str) -> str: + table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]") + text = text.translate(table) + return text.strip() + + +# Assign text of the supervisions and remove unnecessary entries. 
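+# An illustrative (made-up) example of what simple_cleanup() above does:
+#   simple_cleanup("“Hello，” he said。")  ->  '"Hello," he said.'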
+def main(): + assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR" + fname = Path(sys.argv[1]).name + oname = Path(sys.argv[2]) / fname + with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout: + for line in fin: + cut = json.loads(line) + cut["supervisions"][0]["text"] = simple_cleanup( + cut["supervisions"][0]["custom"]["texts"][0] + ) + del cut["supervisions"][0]["custom"] + del cut["custom"] + fout.write((json.dumps(cut) + "\n").encode()) + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/local/train_bpe_model.py b/egs/libriheavy/ASR/local/train_bpe_model.py new file mode 100755 index 000000000..19caf43ab --- /dev/null +++ b/egs/libriheavy/ASR/local/train_bpe_model.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# You can install sentencepiece via: +# +# pip install sentencepiece +# +# Due to an issue reported in +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 +# +# Please install a version >=0.1.96 + +import argparse +import shutil +from pathlib import Path + +import sentencepiece as spm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + The generated bpe.model is saved to this directory. + """, + ) + + parser.add_argument( + "--byte-fallback", + action="store_true", + help="""Whether to enable byte_fallback when training bpe.""", + ) + + parser.add_argument( + "--character-coverage", + type=float, + default=1.0, + help="Character coverage in vocabulary.", + ) + + parser.add_argument( + "--transcript", + type=str, + help="Training transcript.", + ) + + parser.add_argument( + "--vocab-size", + type=int, + help="Vocabulary size for BPE training", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + vocab_size = args.vocab_size + lang_dir = Path(args.lang_dir) + + model_type = "unigram" + + model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" + train_text = args.transcript + input_sentence_size = 100000000 + + user_defined_symbols = ["", ""] + unk_id = len(user_defined_symbols) + # Note: unk_id is fixed to 2. + # If you change it, you should also change other + # places that are using it. 
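+    # The two user-defined symbols occupy ids 0 and 1 (in icefall's matching
+    # librispeech recipe they are <blk> and <sos/eos>), so <unk> ends up
+    # with id 2.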
+ + model_file = Path(model_prefix + ".model") + if not model_file.is_file(): + spm.SentencePieceTrainer.train( + input=train_text, + vocab_size=vocab_size, + model_type=model_type, + model_prefix=model_prefix, + input_sentence_size=input_sentence_size, + character_coverage=args.character_coverage, + user_defined_symbols=user_defined_symbols, + byte_fallback=args.byte_fallback, + unk_id=unk_id, + bos_id=-1, + eos_id=-1, + ) + else: + print(f"{model_file} exists - skipping") + return + + shutil.copyfile(model_file, f"{lang_dir}/bpe.model") + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/prepare.sh b/egs/libriheavy/ASR/prepare.sh new file mode 100755 index 000000000..af7e3c5b0 --- /dev/null +++ b/egs/libriheavy/ASR/prepare.sh @@ -0,0 +1,314 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=15 +stage=-1 +stop_stage=100 +export CUDA_VISIBLE_DEVICES="" + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/librilight +# You can find small, medium, large, etc. inside it. +# +# - $dl_dir/libriheavy +# You can find libriheavy_cuts_small.jsonl.gz, libriheavy_cuts_medium.jsonl.gz, etc. inside it. +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/lang_bpe_xxx, +# data/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + # 5000 + # 2000 + # 1000 + 500 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data +fbank_dir=data/fbank +manifests_dir=data/manifests + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: Download audio data." + # If you have pre-downloaded it to /path/to/librilight, + # you can create a symlink + # + # ln -sfv /path/to/librilight $dl_dir/librilight + # + mkdir -p $dl_dir/librilight + for subset in small medium large; do + log "Downloading ${subset} subset." + if [ ! -d $dl_dir/librilight/${subset} ]; then + wget -P $dl_dir/librilight -c https://dl.fbaipublicfiles.com/librilight/data/${subset}.tar + tar xf $dl_dir/librilight/${subset}.tar -C $dl_dir/librilight + else + log "Skipping download, ${subset} subset exists." + fi + done +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download manifests from huggingface." + + # If you have pre-downloaded it to /path/to/libriheavy, + # you can create a symlink + # + # ln -sfv /path/to/libriheavy $dl_dir/libriheavy + # + mkdir -p $dl_dir/libriheavy + for subset in small medium large dev test_clean test_other; do + if [ ! -e $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz ]; then + log "Downloading ${subset} subset." + wget -P $dl_dir/libriheavy -c https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/libriheavy_cuts_${subset}.jsonl.gz + else + log "Skipping download, ${subset} subset exists." 
+ fi + done + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/ + # + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Download manifests from modelscope" + mkdir -p $dl_dir/libriheavy + if [ ! -e $dl_dir/libriheavy/libriheavy_cuts_small.jsonl.gz ]; then + cd $dl_dir/libriheavy + GIT_LFS_SKIP_SMUDGE=1 git clone https://www.modelscope.cn/datasets/pkufool/Libriheavy.git + cd Libriheavy + git lfs pull --exclude "raw/*" + mv *.jsonl.gz ../ + cd .. + rm -rf Libriheavy + cd ../../ + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to $dl_dir/musan + mkdir -p $manifests_dir + if [ ! -e $manifests_dir/.musan.done ]; then + lhotse prepare musan $dl_dir/musan $manifests_dir + touch $manifests_dir/.musan.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Prepare Libriheavy manifests" + mkdir -p $manifests_dir + for subset in small medium large dev test_clean test_other; do + if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then + log "Prepare manifest for subset : ${subset}" + ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir + fi + done +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for musan" + mkdir -p $fbank_dir + if [ ! -e $fbank_dir/.musan.done ]; then + ./local/compute_fbank_musan.py + touch $fbank_dir/.musan.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Compute fbank for small subset and validation subsets" + for subset in test_clean test_other dev small; do + log "Computing $subset subset." + if [ ! -e $fbank_dir/.libriheavy.${subset}.done ]; then + ./local/compute_fbank_libriheavy.py \ + --manifest-dir ${manifests_dir} \ + --subset ${subset} \ + --fbank-dir $fbank_dir \ + --num-workers $nj + fi + done +fi + +num_per_split=8000 +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Split medium and large subsets." + for subset in medium large; do + log "Spliting subset : $subset" + split_dir=$manifests_dir/libriheavy_${subset}_split + mkdir -p $split_dir + if [ ! -e $split_dir/.split_completed ]; then + lhotse split-lazy $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz $split_dir $num_per_split + touch $split_dir/.split_completed + fi + done +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Compute fbank for medium and large subsets" + mkdir -p $fbank_dir + chunk_size=20 + for subset in medium large; do + if [ $subset == "large" ]; then + chunk_size=200 + fi + num_splits=$(find $manifests_dir/libriheavy_${subset}_split -name "libriheavy_cuts_${subset}.*.jsonl.gz" | wc -l) + if [ ! -e $fbank_dir/.libriheavy.${subset}.done ]; then + for i in $(seq 0 1 6); do + start=$(( i * $chunk_size )) + end=$(( (i+1) * $chunk_size )) + ./local/compute_fbank_libriheavy.py \ + --manifest-dir ${manifests_dir} \ + --use-splits 1 \ + --subset ${subset} \ + --fbank-dir $fbank_dir \ + --num-splits $num_splits \ + --num-workers $nj \ + --start $start \ + --stop $end & + done + wait + touch $fbank_dir/.libriheavy.${subset}.done + fi + done +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Combine features for medium and large subsets." + for subset in medium large; do + log "Combining $subset subset." + if [ ! 
-f $fbank_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
+      pieces=$(find $fbank_dir/libriheavy_${subset}_split -name "libriheavy_cuts_${subset}.*.jsonl.gz")
+      lhotse combine $pieces $fbank_dir/libriheavy_cuts_${subset}.jsonl.gz
+    fi
+  done
+fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Train BPE model for normalized text"
+
+  if [ ! -f data/texts ]; then
+    gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
+      | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
+      | ./local/norm_text.py > data/texts
+  fi
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p $lang_dir
+
+    cp data/texts $lang_dir/text
+
+    if [ ! -f $lang_dir/bpe.model ]; then
+      ./local/train_bpe_model.py \
+        --lang-dir $lang_dir \
+        --vocab-size $vocab_size \
+        --transcript $lang_dir/text
+    fi
+  done
+fi
+
+
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+  log "Stage 10: Train BPE model for unnormalized text"
+  if [ ! -f data/punc_texts ]; then
+    gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
+      | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts
+  fi
+  for vocab_size in ${vocab_sizes[@]}; do
+    new_vocab_size=$(($vocab_size + 256))
+    lang_dir=data/lang_punc_bpe_${new_vocab_size}
+    mkdir -p $lang_dir
+
+    cp data/punc_texts $lang_dir/text
+
+    if [ ! -f $lang_dir/bpe.model ]; then
+      ./local/train_bpe_model.py \
+        --lang-dir $lang_dir \
+        --byte-fallback \
+        --vocab-size ${new_vocab_size} \
+        --character-coverage 0.99 \
+        --transcript $lang_dir/text
+    fi
+  done
+fi
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+  log "Stage 11: Prepare language model for normalized text"
+
+  for subset in small medium large; do
+    if [ ! -f $manifests_dir/texts_${subset} ]; then
+      gunzip -c $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz \
+        | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
+        | ./local/norm_text.py > $manifests_dir/texts_${subset}
+    fi
+  done
+
+  mkdir -p data/lm
+  if [ ! -f data/lm/text ]; then
+    cat $manifests_dir/texts_small $manifests_dir/texts_medium $manifests_dir/texts_large > data/lm/text
+  fi
+
+  (echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \
+    > data/lm/words.txt
+
+  cat data/lm/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
+    | awk '{print $1" "NR+3}' >> data/lm/words.txt
+
+  num_lines=$(< data/lm/words.txt wc -l)
+  (echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \
+    >> data/lm/words.txt
+
+  # Train LM on transcripts
+  if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
+    python3 ./shared/make_kn_lm.py \
+      -ngram-order 3 \
+      -text data/lm/text \
+      -lm data/lm/3-gram.unpruned.arpa
+  fi
+
+  # We assume you have installed kaldilm; if not, please install
+  # it using: pip install kaldilm
+  if [ !
-f data/lm/G_3_gram_char.fst.txt ]; then + # It is used in building HLG + python3 -m kaldilm \ + --read-symbol-table=data/lm/words.txt \ + --disambig-symbol='#0' \ + --max-order=3 \ + data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt + fi +fi + diff --git a/egs/libriheavy/ASR/zipformer/asr_datamodule.py b/egs/libriheavy/ASR/zipformer/asr_datamodule.py new file mode 100644 index 000000000..df761c1b8 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/asr_datamodule.py @@ -0,0 +1,443 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriHeavyAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--subset", + type=str, + default="S", + help="""The subset to be used. Should be S, M or L. Note: S subset + includes libriheavy_cuts_small.jsonl.gz, M subset includes + libriheavy_cuts_small.jsonl.gz and libriheavy_cuts_medium.jsonl.gz, + L subset includes libriheavy_cuts_small.jsonl.gz, + libriheavy_cuts_medium.jsonl.gz and libriheavy_cuts_large.jsonl.gz. 
+ """, + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
+ """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
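+        # Each worker is then re-seeded with (seed + worker_id) by the
+        # _SeedWorkers callable defined above, so workers draw different but
+        # reproducible augmentation randomness.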
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_small_cuts(self) -> CutSet: + logging.info("About to get small subset cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz" + ) + + @lru_cache() + def train_medium_cuts(self) -> CutSet: + logging.info("About to get medium subset cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz" + ) + + @lru_cache() + def train_large_cuts(self) -> CutSet: + logging.info("About to get large subset cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_large.jsonl.gz" + ) + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get the test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_test_clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get the test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_test_other.jsonl.gz" + ) diff --git a/egs/libriheavy/ASR/zipformer/beam_search.py b/egs/libriheavy/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..e24eca39f --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/decode.py b/egs/libriheavy/ASR/zipformer/decode.py new file mode 100644 index 000000000..1928e2635 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/decode.py @@ 
-0,0 +1,794 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 +""" + + +import argparse +import logging +import math +import warnings +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriHeavyAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from text_normalization import remove_punc_to_upper +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--train-with-punctuation", + type=str2bool, + default=False, + help="""Set to True, if the model was trained on texts with casing + and punctuation.""", + ) + + parser.add_argument( + "--post-normalization", + type=str2bool, + default=False, + help="""Upper case and remove all chars except ' and - + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
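+        # The padding gives the causal (streaming) encoder extra frames so the
+        # tail of the utterance is fully decoded; the LOG_EPS fill value keeps
+        # the padded frames close to silence.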
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. 
+ decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + this_batch = [] + if params.post_normalization and params.train_with_punctuation: + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = remove_punc_to_upper(ref_text).split() + hyp_words = remove_punc_to_upper(" ".join(hyp_words)).split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[f"{name}_norm"].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
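        # Roughly, three files are written for each test set: recogs-*.txt (the
        # transcripts stored above), errs-*.txt (detailed per-word error statistics
        # produced here by write_error_stats), and wer-summary-*.txt (one WER figure
        # per decoding setting, written after this loop).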
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriHeavyAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") 
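            # Roughly speaking, average_checkpoints below loads each of the selected
            # checkpoint-*.pt files and averages their parameter tensors element-wise
            # before the result is assigned to `model`.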
+ model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
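    # return_cuts=True makes each batch carry the lhotse Cut objects in
    # batch["supervisions"]["cut"], which decode_dataset() above uses to attach
    # a cut id to every (ref, hyp) pair it stores.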
+ args.return_cuts = True + libriheavy = LibriHeavyAsrDataModule(args) + + def normalize_text(c: Cut): + text = remove_punc_to_upper(c.supervisions[0].text) + c.supervisions[0].text = text + return c + + test_clean_cuts = libriheavy.test_clean_cuts() + test_other_cuts = libriheavy.test_other_cuts() + + if not params.train_with_punctuation: + test_clean_cuts = test_clean_cuts.map(normalize_text) + test_other_cuts = test_other_cuts.map(normalize_text) + + test_clean_dl = libriheavy.test_dataloaders(test_clean_cuts) + test_other_dl = libriheavy.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer/decoder.py b/egs/libriheavy/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/encoder_interface.py b/egs/libriheavy/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..c2eaca671 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/export-onnx.py b/egs/libriheavy/ASR/zipformer/export-onnx.py new file mode 120000 index 000000000..70a15683c --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/export.py b/egs/libriheavy/ASR/zipformer/export.py new file mode 120000 index 000000000..dfc1bec08 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/jit_pretrained.py b/egs/libriheavy/ASR/zipformer/jit_pretrained.py new file mode 120000 index 000000000..25108391f --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/joiner.py b/egs/libriheavy/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/model.py b/egs/libriheavy/ASR/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/onnx_decode.py b/egs/libriheavy/ASR/zipformer/onnx_decode.py new file mode 120000 index 000000000..0573b88c5 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/onnx_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/onnx_pretrained.py b/egs/libriheavy/ASR/zipformer/onnx_pretrained.py new file mode 
120000 index 000000000..8f32f4ee7 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/optim.py b/egs/libriheavy/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/pretrained.py b/egs/libriheavy/ASR/zipformer/pretrained.py new file mode 120000 index 000000000..0bd71dde4 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/scaling.py b/egs/libriheavy/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/scaling_coverter.py b/egs/libriheavy/ASR/zipformer/scaling_coverter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/scaling_coverter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/subsampling.py b/egs/libriheavy/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/libriheavy/ASR/zipformer/text_normalization.py b/egs/libriheavy/ASR/zipformer/text_normalization.py new file mode 100644 index 000000000..92590769c --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/text_normalization.py @@ -0,0 +1,50 @@ +from num2words import num2words + + +def remove_punc_to_upper(text: str) -> str: + text = text.replace("‘", "'") + text = text.replace("’", "'") + tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") + s_list = [x.upper() if x in tokens else " " for x in text] + s = " ".join("".join(s_list).split()).strip() + return s + + +def word_normalization(word: str) -> str: + # 1. Use full word for some abbreviation + # 2. Convert digits to english words + # 3. 
Convert ordinal number to english words + if word == "MRS": + return "MISSUS" + if word == "MR": + return "MISTER" + if word == "ST": + return "SAINT" + if word == "ECT": + return "ET CETERA" + + if word[-2:] in ("ST", "ND", "RD", "TH") and word[:-2].isnumeric(): # e.g 9TH, 6TH + word = num2words(word[:-2], to="ordinal") + word = word.replace("-", " ") + + if word.isnumeric(): + num = int(word) + if num > 1500 and num < 2030: + word = num2words(word, to="year") + else: + word = num2words(word) + word = word.replace("-", " ") + return word.upper() + + +def text_normalization(text: str) -> str: + text = text.upper() + return " ".join([word_normalization(x) for x in text.split()]) + + +if __name__ == "__main__": + assert remove_punc_to_upper("I like this 《book>") == "I LIKE THIS BOOK" + assert ( + text_normalization("Hello Mrs st 21st world 3rd she 99th MR") + == "HELLO MISSUS SAINT TWENTY FIRST WORLD THIRD SHE NINETY NINTH MISTER" + ) diff --git a/egs/libriheavy/ASR/zipformer/train.py b/egs/libriheavy/ASR/zipformer/train.py new file mode 100644 index 000000000..c97da4a11 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/train.py @@ -0,0 +1,1415 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriHeavyAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from text_normalization import remove_punc_to_upper +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. 
If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-hours", + type=float, + default=30000, + help="""Number of hours that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--train-with-punctuation", + type=str2bool, + default=False, + help="If True, the training text will include casing and punctuation.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. 
+ + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + 
joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. 
+ + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
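    # These are unnormalized sums (the losses use reduction="sum"); callers such as
    # compute_validation_loss() and train_one_epoch() below divide by info["frames"]
    # to report a per-frame loss.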
+ info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
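            # Mixed-precision bookkeeping: the loss is scaled before backward(),
            # then GradScaler steps the optimizer and updates its scale factor.
            # Note also that step_epoch() below is driven by the cumulative hours of
            # audio seen so far (batch_idx_train * max_duration * world_size / 3600),
            # which is the quantity the --lr-hours option refers to.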
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + # Use the number of hours of speech to adjust the learning rate + scheduler.step_epoch( + params.batch_idx_train * params.max_duration * params.world_size / 3600 + ) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and 
`world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_hours) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def normalize_text(c: Cut): + text = remove_punc_to_upper(c.supervisions[0].text) + c.supervisions[0].text = text + return c + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 2.0 or c.duration > 30.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + libriheavy = LibriHeavyAsrDataModule(args) + + train_cuts = libriheavy.train_small_cuts() + if params.subset == "M" or params.subset == "L": + train_cuts += libriheavy.train_medium_cuts() + if params.subset == "L": + train_cuts += libriheavy.train_large_cuts() + + if not params.train_with_punctuation: + train_cuts = train_cuts.map(normalize_text) + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = libriheavy.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = libriheavy.dev_cuts() + + if not params.train_with_punctuation: + valid_cuts = valid_cuts.map(normalize_text) + + valid_dl = libriheavy.valid_dataloaders(valid_cuts) + + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # params=params, + # ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriHeavyAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/libriheavy/ASR/zipformer/zipformer.py b/egs/libriheavy/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/libriheavy/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/requirements-ci.txt b/requirements-ci.txt index e1232a768..6c74f688c 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -17,6 +17,7 @@ six git+https://github.com/lhotse-speech/lhotse kaldilm==1.11 kaldialign==0.7.1 +num2words sentencepiece==0.1.96 tensorboard==2.8.0 typeguard==2.13.3 diff --git a/requirements.txt b/requirements.txt index 5a8326619..9502fcbd2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ kaldifst kaldilm kaldialign +num2words kaldi-decoder sentencepiece>=0.1.96 tensorboard From ae67f75e9c429d35e8a84d6d70cc8050eae37c86 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sun, 26 Nov 2023 10:04:15 +0800 Subject: [PATCH 02/46] a bilingual recipe similar to the `multi-zh_hans` (#1265) --- ...rmer.sh => run-multi-corpora-zipformer.sh} | 38 + ...er.yml => run-multi-corpora-zipformer.yml} | 10 +- egs/multi_zh_en/ASR/README.md | 19 + egs/multi_zh_en/ASR/RESULTS.md | 44 + egs/multi_zh_en/ASR/local/compile_lg.py | 1 + egs/multi_zh_en/ASR/local/prepare_char.py | 1 + .../ASR/local/prepare_for_bpe_model.py | 65 + egs/multi_zh_en/ASR/local/prepare_lang.py | 1 + .../ASR/local/prepare_lang_bbpe.py | 1 + 
egs/multi_zh_en/ASR/local/prepare_lang_bpe.py | 1 + egs/multi_zh_en/ASR/local/prepare_words.py | 1 + egs/multi_zh_en/ASR/local/text2segments.py | 1 + egs/multi_zh_en/ASR/local/text2token.py | 1 + egs/multi_zh_en/ASR/local/train_bbpe_model.py | 1 + .../ASR/local/validate_bpe_lexicon.py | 1 + egs/multi_zh_en/ASR/prepare.sh | 149 ++ egs/multi_zh_en/ASR/shared | 1 + .../ASR/zipformer/asr_datamodule.py | 385 +++++ egs/multi_zh_en/ASR/zipformer/beam_search.py | 1 + egs/multi_zh_en/ASR/zipformer/decode.py | 851 ++++++++++ egs/multi_zh_en/ASR/zipformer/decoder.py | 1 + .../ASR/zipformer/encoder_interface.py | 1 + .../ASR/zipformer/export-onnx-streaming.py | 1 + egs/multi_zh_en/ASR/zipformer/export-onnx.py | 1 + egs/multi_zh_en/ASR/zipformer/export.py | 541 +++++++ .../ASR/zipformer/generate_averaged_model.py | 193 +++ .../ASR/zipformer/jit_pretrained.py | 1 + .../ASR/zipformer/jit_pretrained_ctc.py | 1 + .../ASR/zipformer/jit_pretrained_streaming.py | 1 + egs/multi_zh_en/ASR/zipformer/joiner.py | 1 + egs/multi_zh_en/ASR/zipformer/model.py | 1 + .../ASR/zipformer/multi_dataset.py | 247 +++ egs/multi_zh_en/ASR/zipformer/onnx_check.py | 1 + egs/multi_zh_en/ASR/zipformer/onnx_decode.py | 1 + .../zipformer/onnx_pretrained-streaming.py | 1 + .../ASR/zipformer/onnx_pretrained.py | 1 + egs/multi_zh_en/ASR/zipformer/optim.py | 1 + egs/multi_zh_en/ASR/zipformer/pretrained.py | 378 +++++ egs/multi_zh_en/ASR/zipformer/scaling.py | 1 + .../ASR/zipformer/scaling_converter.py | 1 + .../ASR/zipformer/streaming_beam_search.py | 1 + .../ASR/zipformer/streaming_decode.py | 1 + egs/multi_zh_en/ASR/zipformer/subsampling.py | 1 + egs/multi_zh_en/ASR/zipformer/train.py | 1416 +++++++++++++++++ egs/multi_zh_en/ASR/zipformer/zipformer.py | 1 + 45 files changed, 4363 insertions(+), 5 deletions(-) rename .github/scripts/{run-multi-zh_hans-zipformer.sh => run-multi-corpora-zipformer.sh} (66%) rename .github/workflows/{run-multi-zh_hans-zipformer.yml => run-multi-corpora-zipformer.yml} (91%) create mode 100644 egs/multi_zh_en/ASR/README.md create mode 100644 egs/multi_zh_en/ASR/RESULTS.md create mode 120000 egs/multi_zh_en/ASR/local/compile_lg.py create mode 120000 egs/multi_zh_en/ASR/local/prepare_char.py create mode 100755 egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py create mode 120000 egs/multi_zh_en/ASR/local/prepare_lang.py create mode 120000 egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py create mode 120000 egs/multi_zh_en/ASR/local/prepare_lang_bpe.py create mode 120000 egs/multi_zh_en/ASR/local/prepare_words.py create mode 120000 egs/multi_zh_en/ASR/local/text2segments.py create mode 120000 egs/multi_zh_en/ASR/local/text2token.py create mode 120000 egs/multi_zh_en/ASR/local/train_bbpe_model.py create mode 120000 egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py create mode 100755 egs/multi_zh_en/ASR/prepare.sh create mode 120000 egs/multi_zh_en/ASR/shared create mode 100644 egs/multi_zh_en/ASR/zipformer/asr_datamodule.py create mode 120000 egs/multi_zh_en/ASR/zipformer/beam_search.py create mode 100755 egs/multi_zh_en/ASR/zipformer/decode.py create mode 120000 egs/multi_zh_en/ASR/zipformer/decoder.py create mode 120000 egs/multi_zh_en/ASR/zipformer/encoder_interface.py create mode 120000 egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py create mode 120000 egs/multi_zh_en/ASR/zipformer/export-onnx.py create mode 100755 egs/multi_zh_en/ASR/zipformer/export.py create mode 100755 egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py create mode 120000 egs/multi_zh_en/ASR/zipformer/jit_pretrained.py create 
mode 120000 egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py create mode 120000 egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py create mode 120000 egs/multi_zh_en/ASR/zipformer/joiner.py create mode 120000 egs/multi_zh_en/ASR/zipformer/model.py create mode 100644 egs/multi_zh_en/ASR/zipformer/multi_dataset.py create mode 120000 egs/multi_zh_en/ASR/zipformer/onnx_check.py create mode 120000 egs/multi_zh_en/ASR/zipformer/onnx_decode.py create mode 120000 egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py create mode 120000 egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py create mode 120000 egs/multi_zh_en/ASR/zipformer/optim.py create mode 100755 egs/multi_zh_en/ASR/zipformer/pretrained.py create mode 120000 egs/multi_zh_en/ASR/zipformer/scaling.py create mode 120000 egs/multi_zh_en/ASR/zipformer/scaling_converter.py create mode 120000 egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py create mode 120000 egs/multi_zh_en/ASR/zipformer/streaming_decode.py create mode 120000 egs/multi_zh_en/ASR/zipformer/subsampling.py create mode 100755 egs/multi_zh_en/ASR/zipformer/train.py create mode 120000 egs/multi_zh_en/ASR/zipformer/zipformer.py diff --git a/.github/scripts/run-multi-zh_hans-zipformer.sh b/.github/scripts/run-multi-corpora-zipformer.sh similarity index 66% rename from .github/scripts/run-multi-zh_hans-zipformer.sh rename to .github/scripts/run-multi-corpora-zipformer.sh index cbd86a4d3..90f859f43 100755 --- a/.github/scripts/run-multi-zh_hans-zipformer.sh +++ b/.github/scripts/run-multi-corpora-zipformer.sh @@ -95,3 +95,41 @@ for method in modified_beam_search fast_beam_search; do $repo/test_wavs/DEV_T0000000001.wav \ $repo/test_wavs/DEV_T0000000002.wav done + +rm -rf $repo + +cd ../../../egs/multi_zh_en/ASR +log "==== Test icefall-asr-zipformer-multi-zh-en-2023-11-22 ====" +repo_url=https://huggingface.co/zrjin/icefall-asr-zipformer-multi-zh-en-2023-11-22/ + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +./zipformer/pretrained.py \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bbpe_2000/bbpe.model \ + --method greedy_search \ +$repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_29.wav \ +$repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_55.wav \ +$repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_75.wav + +for method in modified_beam_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bbpe_2000/bbpe.model \ + $repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_29.wav \ + $repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_55.wav \ + $repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_75.wav +done + +rm -rf $repo diff --git a/.github/workflows/run-multi-zh_hans-zipformer.yml b/.github/workflows/run-multi-corpora-zipformer.yml similarity index 91% rename from .github/workflows/run-multi-zh_hans-zipformer.yml rename to .github/workflows/run-multi-corpora-zipformer.yml index 72c0775a7..38f7eb908 100644 --- a/.github/workflows/run-multi-zh_hans-zipformer.yml +++ b/.github/workflows/run-multi-corpora-zipformer.yml @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-name: run-multi-zh_hans-zipformer
+name: run-multi-corpora-zipformer
 
 on:
   push:
@@ -24,12 +24,12 @@ on:
     types: [labeled]
 
 concurrency:
-  group: run_multi-zh_hans_zipformer-${{ github.ref }}
+  group: run_multi-corpora_zipformer-${{ github.ref }}
   cancel-in-progress: true
 
 jobs:
-  run_multi-zh_hans_zipformer:
-    if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans' || github.event.label.name == 'zipformer'
+  run_multi-corpora_zipformer:
+    if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans' || github.event.label.name == 'zipformer' || github.event.label.name == 'multi-corpora'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
@@ -81,4 +81,4 @@ jobs:
         export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
         export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
 
-        .github/scripts/run-multi-zh_hans-zipformer.sh
+        .github/scripts/run-multi-corpora-zipformer.sh
diff --git a/egs/multi_zh_en/ASR/README.md b/egs/multi_zh_en/ASR/README.md
new file mode 100644
index 000000000..29341571d
--- /dev/null
+++ b/egs/multi_zh_en/ASR/README.md
@@ -0,0 +1,19 @@
+# Introduction
+
+This recipe includes scripts for training a Zipformer model on both English and Chinese datasets.
+
+# Included Training Sets
+
+1. LibriSpeech (English)
+2. AiShell-2 (Chinese)
+3. TAL-CSASR (Code-Switching, Chinese and English)
+
+|Dataset| Number of hours| URL|
+|---|---:|---|
+|**TOTAL**|2,547|---|
+|LibriSpeech|960|https://www.openslr.org/12/|
+|AiShell-2|1,000|http://www.aishelltech.com/aishell_2|
+|TAL-CSASR|587|https://ai.100tal.com/openData/voice|
+
+
+
diff --git a/egs/multi_zh_en/ASR/RESULTS.md b/egs/multi_zh_en/ASR/RESULTS.md
new file mode 100644
index 000000000..3562d6ac3
--- /dev/null
+++ b/egs/multi_zh_en/ASR/RESULTS.md
@@ -0,0 +1,44 @@
+## Results
+
+### Zh-En datasets BPE-based training results (non-streaming) on Zipformer model
+
+This is [pull request #1265](https://github.com/k2-fsa/icefall/pull/1265) in icefall.
+
+#### Non-streaming (Byte-Level BPE vocab_size=2000)
+
+Best results (number of model parameters: ~69M):
+
+The training command:
+
+```
+./zipformer/train.py \
+  --world-size 4 \
+  --num-epochs 35 \
+  --use-fp16 1 \
+  --max-duration 1000 \
+  --num-workers 8
+```
+
+The decoding command:
+
+```
+for method in greedy_search modified_beam_search fast_beam_search; do
+  ./zipformer/decode.py \
+    --epoch 34 \
+    --avg 19 \
+    --decoding-method $method
+done
+```
+
+The Word Error Rates (WERs) listed below are produced by the checkpoint of the 20th epoch using the byte-level BPE model (vocab size: 2000).
+
+| Datasets             | TAL-CSASR | TAL-CSASR | AiShell-2 | AiShell-2 | LibriSpeech | LibriSpeech |
+|----------------------|-----------|-----------|-----------|-----------|-------------|-------------|
+| Zipformer WER (%)    | dev       | test      | dev       | test      | test-clean  | test-other  |
+| greedy_search        | 6.65      | 6.69      | 6.57      | 7.03      | 2.43        | 5.70        |
+| modified_beam_search | 6.46      | 6.51      | 6.18      | 6.60      | 2.41        | 5.57        |
+| fast_beam_search     | 6.57      | 6.68      | 6.40      | 6.74      | 2.40        | 5.56        |
+
+The pre-trained model can be found at https://huggingface.co/zrjin/icefall-asr-zipformer-multi-zh-en-2023-11-22. It was trained on the LibriSpeech 960-hour training set (with speed perturbation), the TAL-CSASR training set (with speed perturbation) and AiShell-2 (without speed perturbation).
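+
+The snippet below is a minimal sketch (not part of the recipe) of how the byte-level
+BPE (BBPE) tokens used above behave. It relies on the same `byte_encode`,
+`smart_byte_decode` and `tokenize_by_CJK_char` helpers that `zipformer/decode.py`
+imports from icefall; the model path and the example sentence are placeholders.
+
+```python
+import sentencepiece as spm
+
+from icefall import byte_encode, smart_byte_decode, tokenize_by_CJK_char
+
+sp = spm.SentencePieceProcessor()
+sp.load("data/lang_bbpe_2000/bbpe.model")
+
+text = "你好 hello world"
+# Split CJK characters into single tokens, map the text to printable byte
+# symbols, then encode with the BBPE model.
+ids = sp.encode(byte_encode(tokenize_by_CJK_char(text)), out_type=int)
+# Decoding reverses the byte mapping and recovers the mixed zh/en text.
+print(smart_byte_decode(sp.decode(ids)))
+```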
+ + diff --git a/egs/multi_zh_en/ASR/local/compile_lg.py b/egs/multi_zh_en/ASR/local/compile_lg.py new file mode 120000 index 000000000..462d6d3fb --- /dev/null +++ b/egs/multi_zh_en/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_char.py b/egs/multi_zh_en/ASR/local/prepare_char.py new file mode 120000 index 000000000..42743b544 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/prepare_char.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/prepare_char.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py b/egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py new file mode 100755 index 000000000..00514e6bb --- /dev/null +++ b/egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script tokenizes the training transcript by CJK characters +# and saves the result to transcript_chars.txt, which is used +# to train the BPE model later. + +import argparse +from pathlib import Path + +from tqdm.auto import tqdm + +from icefall.utils import tokenize_by_CJK_char + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Output directory. + The generated transcript_chars.txt is saved to this directory. + """, + ) + + parser.add_argument( + "--text", + type=str, + help="Training transcript.", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + text = Path(args.text) + + assert lang_dir.exists() and text.exists(), f"{lang_dir} or {text} does not exist!" 
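+    # tokenize_by_CJK_char separates every CJK character with spaces while
+    # leaving non-CJK words (e.g. English) intact, so transcript_chars.txt
+    # mixes single Chinese characters with whole English words; this is the
+    # text later fed to the BPE trainer.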
+ + transcript_path = lang_dir / "transcript_chars.txt" + + with open(text, "r", encoding="utf-8") as fin: + with open(transcript_path, "w+", encoding="utf-8") as fout: + for line in tqdm(fin): + fout.write(tokenize_by_CJK_char(line) + "\n") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh_en/ASR/local/prepare_lang.py b/egs/multi_zh_en/ASR/local/prepare_lang.py new file mode 120000 index 000000000..747f2ab39 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py b/egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py new file mode 120000 index 000000000..9a0b44642 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/prepare_lang_bbpe.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_lang_bpe.py b/egs/multi_zh_en/ASR/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/multi_zh_en/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/prepare_words.py b/egs/multi_zh_en/ASR/local/prepare_words.py new file mode 120000 index 000000000..ef2b4eaf3 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/prepare_words.py @@ -0,0 +1 @@ +../../../aishell2/ASR/local/prepare_words.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/text2segments.py b/egs/multi_zh_en/ASR/local/text2segments.py new file mode 120000 index 000000000..7d68a39c3 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/text2segments.py @@ -0,0 +1 @@ +../../../wenetspeech/ASR/local/text2segments.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/text2token.py b/egs/multi_zh_en/ASR/local/text2token.py new file mode 120000 index 000000000..ce5cfd537 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/text2token.py @@ -0,0 +1 @@ +../../../wenetspeech/ASR/local/text2token.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/train_bbpe_model.py b/egs/multi_zh_en/ASR/local/train_bbpe_model.py new file mode 120000 index 000000000..7fb4a9f9d --- /dev/null +++ b/egs/multi_zh_en/ASR/local/train_bbpe_model.py @@ -0,0 +1 @@ +../../../aishell/ASR/local/train_bbpe_model.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py b/egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 000000000..721bb48e7 --- /dev/null +++ b/egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/prepare.sh b/egs/multi_zh_en/ASR/prepare.sh new file mode 100755 index 000000000..9f2be5a5c --- /dev/null +++ b/egs/multi_zh_en/ASR/prepare.sh @@ -0,0 +1,149 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=-1 +stop_stage=100 + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +vocab_sizes=( + 2000 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. 
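+# Example: with the LibriSpeech/AiShell-2 prerequisites below already prepared,
+# run only the byte-level BPE lang stage via:
+#   ./prepare.sh --stage 4 --stop-stage 4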
+mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +log "Dataset: musan" +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Soft link fbank of musan" + mkdir -p data/fbank + if [ -e ../../librispeech/ASR/data/fbank/.musan.done ]; then + cd data/fbank + ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_feats) . + ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_cuts.jsonl.gz) . + cd ../.. + else + log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 4 --stop-stage 4" + exit 1 + fi +fi + +log "Dataset: LibriSpeech" +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Soft link fbank of LibriSpeech" + mkdir -p data/fbank + if [ -e ../../librispeech/ASR/data/fbank/.librispeech.done ]; then + cd data/fbank + ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts*) . + ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats*) . + cd ../.. + else + log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 3 --stop-stage 3" + exit 1 + fi +fi + +log "Dataset: AiShell-2" +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Soft link fbank of AiShell-2" + mkdir -p data/fbank + if [ -e ../../aishell2/ASR/data/fbank/.aishell2.done ]; then + cd data/fbank + ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts*) . + ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_feats*) . + cd ../.. + else + log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3" + exit 1 + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Prepare Byte BPE based lang" + mkdir -p data/fbank + if [ ! -d ../../aishell2/ASR/data/lang_char ] && [ ! -d ./data/lang_char ]; then + log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3" + exit 1 + fi + + if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ] && [ ! -d ./data/lang_bpe_500 ]; then + log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 6 --stop-stage 6" + exit 1 + fi + + cd data/ + if [ ! -d ./lang_char ]; then + ln -svf $(realpath ../../../aishell2/ASR/data/lang_char) . + fi + if [ ! -d ./lang_bpe_500 ]; then + ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) . + fi + cd ../ + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bbpe_${vocab_size} + mkdir -p $lang_dir + + cat data/lang_char/text data/lang_bpe_500/transcript_words.txt \ + > $lang_dir/text + + if [ ! -f $lang_dir/transcript_chars.txt ]; then + ./local/prepare_for_bpe_model.py \ + --lang-dir ./$lang_dir \ + --text $lang_dir/text + fi + + if [ ! -f $lang_dir/text_words_segmentation ]; then + python3 ./local/text2segments.py \ + --input-file ./data/lang_char/text \ + --output-file $lang_dir/text_words_segmentation + + cat ./data/lang_bpe_500/transcript_words.txt \ + >> $lang_dir/text_words_segmentation + + cat ./data/lang_char/text \ + >> $lang_dir/text + fi + + cat $lang_dir/text_words_segmentation | sed 's/ /\n/g' \ + | sort -u | sed '/^$/d' | uniq > $lang_dir/words_no_ids.txt + + if [ ! -f $lang_dir/words.txt ]; then + python3 ./local/prepare_words.py \ + --input-file $lang_dir/words_no_ids.txt \ + --output-file $lang_dir/words.txt + fi + + if [ ! 
-f $lang_dir/bbpe.model ]; then + ./local/train_bbpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/text + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bbpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bbpe.model + fi + done +fi + diff --git a/egs/multi_zh_en/ASR/shared b/egs/multi_zh_en/ASR/shared new file mode 120000 index 000000000..4cbd91a7e --- /dev/null +++ b/egs/multi_zh_en/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py new file mode 100644 index 000000000..be6e94472 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py @@ -0,0 +1,385 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class AsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
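+
+    In this recipe the corpus-specific CutSets come from MultiDataset (see
+    multi_dataset.py in this directory) and are passed to the dataloader
+    methods below, e.g. test_dataloaders() in decode.py.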
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=300.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. 
", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=True, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl diff --git a/egs/multi_zh_en/ASR/zipformer/beam_search.py b/egs/multi_zh_en/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..8e2c0a65c --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/beam_search.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/decode.py b/egs/multi_zh_en/ASR/zipformer/decode.py new file mode 100755 index 000000000..e21e8f052 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/decode.py @@ -0,0 +1,851 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the 
Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from multi_dataset import MultiDataset +from train import add_model_arguments, get_model, get_params + +from icefall import byte_encode, smart_byte_decode, tokenize_by_CJK_char +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_2000/bbpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bbpe_2000", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-tal-csasr", + type=str2bool, + default=False, + help="Whether to use TAL-CSASR training data.", + ) + + parser.add_argument( + "--use-librispeech", + type=str2bool, + default=False, + help="Whether to use LibriSpeech training data.", + ) + + parser.add_argument( + "--use-aishell2", + type=str2bool, + default=False, + help="Whether to use Aishell-2 training data.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
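+        # Append pad_len frames filled with LOG_EPS (very low log-energy) and
+        # extend feature_lens to match, so the padded tail is seen by the
+        # causal model as near-silence.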
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode( + byte_encode(tokenize_by_CJK_char(supervisions["text"])) + ), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(smart_byte_decode(sp.decode(hyp)).split()) + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += 
f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [tokenize_by_CJK_char(str(text)).split() for text in texts] + # print(texts) + # exit() + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints 
({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + data_module = AsrDataModule(args) + multi_dataset = MultiDataset(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}" + ) + return T > 0 + + test_sets_cuts = multi_dataset.test_cuts() + + test_sets = test_sets_cuts.keys() + test_dl = [ + data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt)) + for cuts_name in test_sets + ] + + for test_set, test_dl in zip(test_sets, test_dl): + logging.info(f"Start decoding test set: {test_set}") + + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh_en/ASR/zipformer/decoder.py b/egs/multi_zh_en/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/encoder_interface.py b/egs/multi_zh_en/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..c2eaca671 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/encoder_interface.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py b/egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py new file mode 120000 index 000000000..2962eb784 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/export-onnx.py b/egs/multi_zh_en/ASR/zipformer/export-onnx.py new file mode 120000 index 000000000..70a15683c --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/export.py b/egs/multi_zh_en/ASR/zipformer/export.py new file mode 100755 index 000000000..fbd9ce0dd --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/export.py @@ -0,0 +1,541 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. 
+ +(1) Export to torchscript model using torch.jit.script() + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --epoch 20 \ + --avg 1 \ + --jit 1 + +It will generate a file `jit_script.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("jit_script.pt")`. + +Check ./jit_pretrained.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --epoch 20 \ + --avg 1 \ + --jit 1 + +It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. +You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. + +Check ./jit_pretrained_streaming.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --epoch 20 \ + --avg 1 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --epoch 20 \ + --avg 1 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +- For non-streaming model: + +To use the generated file with `zipformer/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bbpe_2000/bpe.model + +- For streaming model: + +To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + + # simulated streaming decoding + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bbpe_2000/bpe.model + + # chunk-wise streaming decoding + ./zipformer/streaming_decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bbpe_2000/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. 
You can get it at + +- non-streaming model: +https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2/ + # You will find the pre-trained models in exp dir +""" + +import argparse +import logging +import re +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor, nn +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, str2bool + + +def num_tokens( + token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$") +) -> int: + """Return the number of tokens excluding those from + disambiguation symbols. + + Caution: + 0 is not a token ID so it is excluded from the return value. + """ + symbols = token_table.symbols + ans = [] + for s in symbols: + if not disambig_pattern.match(s): + ans.append(token_table[s]) + num_tokens = len(ans) + if 0 in ans: + num_tokens -= 1 + return num_tokens + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=20, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bbpe_2000/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named jit_script.pt. + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +class EncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Args: + features: (N, T, C) + feature_lengths: (N,) + """ + x, x_lens = self.encoder_embed(features, feature_lengths) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return encoder_out, encoder_out_lens + + +class StreamingEncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + assert len(encoder.chunk_size) == 1, encoder.chunk_size + assert len(encoder.left_context_frames) == 1, encoder.left_context_frames + self.chunk_size = encoder.chunk_size[0] + self.left_context_len = encoder.left_context_frames[0] + + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + self.pad_length = 7 + 2 * 3 + + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """Streaming forward for encoder_embed and encoder. + + Args: + features: (N, T, C) + feature_lengths: (N,) + states: a list of Tensors + + Returns encoder outputs, output lengths, and updated states. + """ + chunk_size = self.chunk_size + left_context_len = self.left_context_len + + cached_embed_left_pad = states[-2] + x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lengths, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. 
For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = self.encoder.get_init_states(batch_size, device) + + embed_states = self.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + # if torch.cuda.is_available(): + # device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, 
inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + + # Wrap encoder and encoder_embed as a module + if params.causal: + model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) + chunk_size = model.encoder.chunk_size + left_context_len = model.encoder.left_context_len + filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" + else: + model.encoder = EncoderModel(model.encoder, model.encoder_embed) + filename = "jit_script.pt" + + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + model.save(str(params.exp_dir / filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py b/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py new file mode 100755 index 000000000..68111fad7 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) use the checkpoint exp_dir/epoch-xxx.pt +./zipformer/generate_averaged_model.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp + +It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`. +You can later load it by `torch.load("epoch-28-avg-15.pt")`. + +(2) use the checkpoint exp_dir/checkpoint-iter.pt +./zipformer/generate_averaged_model.py \ + --iter 22000 \ + --avg 5 \ + --exp-dir ./zipformer/exp + +It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`. +You can later load it by `torch.load("iter-22000-avg-5.pt")`. +""" + + +import argparse +from pathlib import Path + +import k2 +import torch +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + print("Script started") + + device = torch.device("cpu") + print(f"Device: {device}") + + symbol_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = symbol_table[""] + params.unk_id = symbol_table[""] + params.vocab_size = len(symbol_table) + + print("About to create model") + model = get_model(params) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + print( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + print( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh_en/ASR/zipformer/jit_pretrained.py b/egs/multi_zh_en/ASR/zipformer/jit_pretrained.py new file mode 120000 index 000000000..25108391f --- /dev/null +++ 
b/egs/multi_zh_en/ASR/zipformer/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py new file mode 120000 index 000000000..9a8da5844 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py new file mode 120000 index 000000000..1962351e9 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/joiner.py b/egs/multi_zh_en/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/model.py b/egs/multi_zh_en/ASR/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/multi_dataset.py b/egs/multi_zh_en/ASR/zipformer/multi_dataset.py new file mode 100644 index 000000000..1155a3dcc --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/multi_dataset.py @@ -0,0 +1,247 @@ +# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
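+
+# Rough usage sketch (mirroring what train.py and decode.py in this recipe do;
+# the surrounding argument parsing is omitted here):
+#
+#     multi_dataset = MultiDataset(args)       # args carries manifest_dir and the use_* flags
+#     train_cuts = multi_dataset.train_cuts()  # lazily muxed CutSet over the enabled corpora
+#     test_sets = multi_dataset.test_cuts()    # dict mapping test-set name -> CutSet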
+ + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Dict + +from lhotse import CutSet, load_manifest_lazy + + +class MultiDataset: + def __init__(self, args: argparse.Namespace): + """ + Args: + manifest_dir: + It is expected to contain the following files: + - aishell2_cuts_train.jsonl.gz + """ + self.fbank_dir = Path(args.manifest_dir) + self.use_tal_csasr = args.use_tal_csasr + self.use_librispeech = args.use_librispeech + self.use_aishell2 = args.use_aishell2 + + def train_cuts(self) -> CutSet: + logging.info("About to get multidataset train cuts") + + # AISHELL-2 + if self.use_aishell2: + logging.info("Loading Aishell-2 in lazy mode") + aishell_2_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_train.jsonl.gz" + ) + + # TAL-CSASR + if self.use_tal_csasr: + logging.info("Loading TAL-CSASR in lazy mode") + tal_csasr_cuts = load_manifest_lazy( + self.fbank_dir / "tal_csasr_cuts_train_set.jsonl.gz" + ) + + # LibriSpeech + if self.use_librispeech: + logging.info("Loading LibriSpeech in lazy mode") + train_clean_100_cuts = self.train_clean_100_cuts() + train_clean_360_cuts = self.train_clean_360_cuts() + train_other_500_cuts = self.train_other_500_cuts() + + if self.use_tal_csasr and self.use_librispeech and self.use_aishell2: + return CutSet.mux( + aishell_2_cuts, + train_clean_100_cuts, + train_clean_360_cuts, + train_other_500_cuts, + tal_csasr_cuts, + weights=[ + len(aishell_2_cuts), + len(train_clean_100_cuts), + len(train_clean_360_cuts), + len(train_other_500_cuts), + len(tal_csasr_cuts), + ], + ) + elif not self.use_tal_csasr and self.use_librispeech and self.use_aishell2: + return CutSet.mux( + aishell_2_cuts, + train_clean_100_cuts, + train_clean_360_cuts, + train_other_500_cuts, + weights=[ + len(aishell_2_cuts), + len(train_clean_100_cuts), + len(train_clean_360_cuts), + len(train_other_500_cuts), + ], + ) + elif self.use_tal_csasr and not self.use_librispeech and self.use_aishell2: + return CutSet.mux( + aishell_2_cuts, + tal_csasr_cuts, + weights=[ + len(aishell_2_cuts), + len(tal_csasr_cuts), + ], + ) + elif self.use_tal_csasr and self.use_librispeech and not self.use_aishell2: + return CutSet.mux( + train_clean_100_cuts, + train_clean_360_cuts, + train_other_500_cuts, + tal_csasr_cuts, + weights=[ + len(train_clean_100_cuts), + len(train_clean_360_cuts), + len(train_other_500_cuts), + len(tal_csasr_cuts), + ], + ) + else: + raise NotImplementedError( + f"""Not implemented for + use_aishell2: {self.use_aishell2} + use_librispeech: {self.use_librispeech} + use_tal_csasr: {self.use_tal_csasr}""" + ) + + def dev_cuts(self) -> CutSet: + logging.info("About to get multidataset dev cuts") + + # AISHELL-2 + logging.info("Loading Aishell-2 DEV set in lazy mode") + aishell2_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" + ) + + # LibriSpeech + dev_clean_cuts = self.dev_clean_cuts() + dev_other_cuts = self.dev_other_cuts() + + logging.info("Loading TAL-CSASR set in lazy mode") + tal_csasr_dev_cuts = load_manifest_lazy( + self.fbank_dir / "tal_csasr_cuts_dev_set.jsonl.gz" + ) + + return CutSet.mux( + aishell2_dev_cuts, + dev_clean_cuts, + dev_other_cuts, + tal_csasr_dev_cuts, + weights=[ + len(aishell2_dev_cuts), + len(dev_clean_cuts), + len(dev_other_cuts), + len(tal_csasr_dev_cuts), + ], + ) + + def test_cuts(self) -> Dict[str, CutSet]: + logging.info("About to get multidataset test cuts") + + # AISHELL-2 + if self.use_aishell2: + logging.info("Loading Aishell-2 set in lazy 
mode") + aishell2_test_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_test.jsonl.gz" + ) + aishell2_dev_cuts = load_manifest_lazy( + self.fbank_dir / "aishell2_cuts_dev.jsonl.gz" + ) + + # LibriSpeech + if self.use_librispeech: + test_clean_cuts = self.test_clean_cuts() + test_other_cuts = self.test_other_cuts() + + logging.info("Loading TAL-CSASR set in lazy mode") + tal_csasr_test_cuts = load_manifest_lazy( + self.fbank_dir / "tal_csasr_cuts_test_set.jsonl.gz" + ) + tal_csasr_dev_cuts = load_manifest_lazy( + self.fbank_dir / "tal_csasr_cuts_dev_set.jsonl.gz" + ) + + test_cuts = { + "tal_csasr_test": tal_csasr_test_cuts, + "tal_csasr_dev": tal_csasr_dev_cuts, + } + + if self.use_aishell2: + test_cuts.update( + { + "aishell-2_test": aishell2_test_cuts, + "aishell-2_dev": aishell2_dev_cuts, + } + ) + if self.use_librispeech: + test_cuts.update( + { + "librispeech_test_clean": test_clean_cuts, + "librispeech_test_other": test_other_cuts, + } + ) + return test_cuts + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.fbank_dir / "librispeech_cuts_test-other.jsonl.gz" + ) diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_check.py b/egs/multi_zh_en/ASR/zipformer/onnx_check.py new file mode 120000 index 000000000..f3dd42004 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_decode.py b/egs/multi_zh_en/ASR/zipformer/onnx_decode.py new file mode 120000 index 000000000..0573b88c5 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/onnx_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py new file mode 120000 index 000000000..cfea104c2 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py new file mode 120000 index 000000000..8f32f4ee7 --- /dev/null +++ 
b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/optim.py b/egs/multi_zh_en/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/pretrained.py b/egs/multi_zh_en/ASR/zipformer/pretrained.py new file mode 100755 index 000000000..676272e1f --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/pretrained.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --epoch 23 \ + --avg 1 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --epoch 23 \ + --avg 1 + +Usage of this script: + +- For non-streaming model: + +(1) greedy search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens data/lang_bbpe_2000/tokens.txt \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens ./data/lang_bbpe_2000/tokens.txt \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --tokens ./data/lang_bbpe_2000/tokens.txt \ + --method fast_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +- For streaming model: + +(1) greedy search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bbpe_2000/tokens.txt \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bbpe_2000/tokens.txt \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zipformer/pretrained.py \ + --checkpoint ./zipformer/exp/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens ./data/lang_bbpe_2000/tokens.txt \ + --method fast_beam_search \ + /path/to/foo.wav \ 
+ /path/to/bar.wav + + +You can also use `./zipformer/exp/epoch-xx.pt`. + +Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + fast_beam_search_one_best, + greedy_search_batch, + modified_beam_search, +) +from export import num_tokens +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall import smart_byte_decode + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to byte-level bpe model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + + logging.info("Creating model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + # model forward + encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) + + hyps = [] + msg = f"Using {params.method}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + else: + raise ValueError(f"Unsupported method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + s += f"{filename}:\n{hyp}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + 
main() diff --git a/egs/multi_zh_en/ASR/zipformer/scaling.py b/egs/multi_zh_en/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/scaling_converter.py b/egs/multi_zh_en/ASR/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py b/egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py new file mode 120000 index 000000000..b1ed54557 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/streaming_decode.py b/egs/multi_zh_en/ASR/zipformer/streaming_decode.py new file mode 120000 index 000000000..13fd02a78 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/streaming_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_decode.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/subsampling.py b/egs/multi_zh_en/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/multi_zh_en/ASR/zipformer/train.py b/egs/multi_zh_en/ASR/zipformer/train.py new file mode 100755 index 000000000..310c8fe59 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/train.py @@ -0,0 +1,1416 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
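+
+# Note: this recipe selects its training corpora via the --use-aishell2,
+# --use-librispeech and --use-tal-csasr flags defined below; see
+# multi_dataset.py for the combinations that are currently supported.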
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from multi_dataset import MultiDataset +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import byte_encode, diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, + tokenize_by_CJK_char, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
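+ # For example (hypothetical settings): with --max-duration 1000, --world-size 4
+ # and the default --ref-duration 600, after 300 real batches this returns
+ # 300 * (1000 * 4) / 600 = 2000 "reference" batches.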
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. 
If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_2000/bbpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-tal-csasr", + type=str2bool, + default=False, + help="Whether to use TAL-CSASR training data.", + ) + + parser.add_argument( + "--use-librispeech", + type=str2bool, + default=False, + help="Whether to use LibriSpeech training data.", + ) + + parser.add_argument( + "--use-aishell2", + type=str2bool, + default=False, + help="Whether to use Aishell-2 training data.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. 
+ + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + 
encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
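+ scheduler:
+ The learning rate scheduler used in the training.
+ rank:
+ The rank of the node in DDP training; only rank 0 actually writes the
+ checkpoint (see the early return below).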
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
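The scaling above phases the pruned loss in and the simple loss out over the first `warm_step` batches. A small sketch of that schedule, assuming `warm_step = 2000` from `get_params()` and an illustrative `simple_loss_scale` of 0.5:

```python
# Sketch only; reproduces the scaling formulas from compute_loss() above.
warm_step = 2000   # from get_params()
s = 0.5            # illustrative simple_loss_scale value

def loss_scales(batch_idx_train: int):
    simple_loss_scale = (
        s
        if batch_idx_train >= warm_step
        else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    )
    pruned_loss_scale = (
        1.0
        if batch_idx_train >= warm_step
        else 0.1 + 0.9 * (batch_idx_train / warm_step)
    )
    return simple_loss_scale, pruned_loss_scale

for step in (0, 500, 1000, 2000, 10000):
    print(step, loss_scales(step))
# 0     -> (1.0, 0.1)   early on, the cheap "simple" loss dominates
# 2000+ -> (0.5, 1.0)   fully warmed up: the pruned loss carries most of the weight
```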
+ info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
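Note that `tot_loss` is not a plain running sum: multiplying by `(1 - 1/reset_interval)` on every batch makes it an exponentially weighted total, so the `tot_loss[...]` numbers in the training log mostly reflect the most recent couple of hundred batches. A toy illustration, using `reset_interval = 200` from `get_params()`:

```python
# How much weight a batch retains after later updates of
#   tot_loss = tot_loss * (1 - 1 / reset_interval) + loss_info
reset_interval = 200
decay = 1.0 - 1.0 / reset_interval

for batches_later in (0, 200, 600):
    print(batches_later, round(decay ** batches_later, 3))
# 0    1.0     the just-seen batch counts fully
# 200  0.367   ~e^-1 of its weight after reset_interval batches
# 600  0.049   effectively forgotten
```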
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. 
+ world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + data_module = AsrDataModule(args) + multi_dataset = MultiDataset(args) + + train_cuts = multi_dataset.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 12 seconds + # + # Caution: There is a reason to select 12.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. 
" + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + def tokenize_and_encode_text(c: Cut): + # Text normalize for each sample + text = c.supervisions[0].text + text = byte_encode(tokenize_by_CJK_char(text)) + c.supervisions[0].text = text + return c + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_cuts = train_cuts.map(tokenize_and_encode_text) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = data_module.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = multi_dataset.dev_cuts() + valid_dl = data_module.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/multi_zh_en/ASR/zipformer/zipformer.py b/egs/multi_zh_en/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/multi_zh_en/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file From 0622dea30deacf2680dcca0549f7a05c0b965066 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Wed, 29 Nov 2023 21:28:38 +0800 Subject: [PATCH 03/46] Add a TTS recipe VITS on LJSpeech dataset (#1372) * first commit * replace phonimizer with g2p * use Conformer as text encoder * modify training script, clean codes * rename directory * convert text to tokens in data preparation stage * fix tts_datamodule.py * support onnx export and testing the exported onnx model * add doc * add README.md * fix style --- .flake8 | 2 +- docs/source/recipes/TTS/index.rst | 7 + docs/source/recipes/TTS/ljspeech/vits.rst | 113 +++ docs/source/recipes/index.rst | 3 +- .../TTS/local/compute_spectrogram_ljspeech.py | 106 ++ .../TTS/local/display_manifest_statistics.py | 73 ++ egs/ljspeech/TTS/local/prepare_token_file.py | 104 ++ .../TTS/local/prepare_tokens_ljspeech.py | 59 ++ egs/ljspeech/TTS/local/validate_manifest.py | 70 ++ egs/ljspeech/TTS/prepare.sh | 117 +++ egs/ljspeech/TTS/shared/parse_options.sh | 1 + egs/ljspeech/TTS/vits/README.md | 3 + egs/ljspeech/TTS/vits/duration_predictor.py | 194 ++++ egs/ljspeech/TTS/vits/export-onnx.py | 261 +++++ egs/ljspeech/TTS/vits/flow.py | 312 ++++++ egs/ljspeech/TTS/vits/generator.py | 531 ++++++++++ egs/ljspeech/TTS/vits/hifigan.py | 933 ++++++++++++++++++ egs/ljspeech/TTS/vits/infer.py | 233 +++++ egs/ljspeech/TTS/vits/loss.py | 336 +++++++ .../TTS/vits/monotonic_align/__init__.py | 81 ++ .../TTS/vits/monotonic_align/core.pyx | 51 + .../TTS/vits/monotonic_align/setup.py | 31 + egs/ljspeech/TTS/vits/posterior_encoder.py | 117 +++ egs/ljspeech/TTS/vits/residual_coupling.py | 229 +++++ egs/ljspeech/TTS/vits/test_onnx.py | 123 +++ egs/ljspeech/TTS/vits/text_encoder.py | 662 +++++++++++++ egs/ljspeech/TTS/vits/tokenizer.py | 106 ++ egs/ljspeech/TTS/vits/train.py | 893 +++++++++++++++++ egs/ljspeech/TTS/vits/transform.py | 218 ++++ egs/ljspeech/TTS/vits/tts_datamodule.py | 325 ++++++ egs/ljspeech/TTS/vits/utils.py | 265 +++++ 
egs/ljspeech/TTS/vits/vits.py | 610 ++++++++++++ egs/ljspeech/TTS/vits/wavenet.py | 349 +++++++ pyproject.toml | 1 + 34 files changed, 7517 insertions(+), 2 deletions(-) create mode 100644 docs/source/recipes/TTS/index.rst create mode 100644 docs/source/recipes/TTS/ljspeech/vits.rst create mode 100755 egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py create mode 100755 egs/ljspeech/TTS/local/display_manifest_statistics.py create mode 100755 egs/ljspeech/TTS/local/prepare_token_file.py create mode 100755 egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py create mode 100755 egs/ljspeech/TTS/local/validate_manifest.py create mode 100755 egs/ljspeech/TTS/prepare.sh create mode 120000 egs/ljspeech/TTS/shared/parse_options.sh create mode 100644 egs/ljspeech/TTS/vits/README.md create mode 100644 egs/ljspeech/TTS/vits/duration_predictor.py create mode 100755 egs/ljspeech/TTS/vits/export-onnx.py create mode 100644 egs/ljspeech/TTS/vits/flow.py create mode 100644 egs/ljspeech/TTS/vits/generator.py create mode 100644 egs/ljspeech/TTS/vits/hifigan.py create mode 100755 egs/ljspeech/TTS/vits/infer.py create mode 100644 egs/ljspeech/TTS/vits/loss.py create mode 100644 egs/ljspeech/TTS/vits/monotonic_align/__init__.py create mode 100644 egs/ljspeech/TTS/vits/monotonic_align/core.pyx create mode 100644 egs/ljspeech/TTS/vits/monotonic_align/setup.py create mode 100644 egs/ljspeech/TTS/vits/posterior_encoder.py create mode 100644 egs/ljspeech/TTS/vits/residual_coupling.py create mode 100755 egs/ljspeech/TTS/vits/test_onnx.py create mode 100644 egs/ljspeech/TTS/vits/text_encoder.py create mode 100644 egs/ljspeech/TTS/vits/tokenizer.py create mode 100755 egs/ljspeech/TTS/vits/train.py create mode 100644 egs/ljspeech/TTS/vits/transform.py create mode 100644 egs/ljspeech/TTS/vits/tts_datamodule.py create mode 100644 egs/ljspeech/TTS/vits/utils.py create mode 100644 egs/ljspeech/TTS/vits/vits.py create mode 100644 egs/ljspeech/TTS/vits/wavenet.py diff --git a/.flake8 b/.flake8 index 410cb5482..cf276d0ba 100644 --- a/.flake8 +++ b/.flake8 @@ -15,7 +15,7 @@ per-file-ignores = egs/librispeech/ASR/zipformer_mmi/*.py: E501, E203 egs/librispeech/ASR/zipformer/*.py: E501, E203 egs/librispeech/ASR/RESULTS.md: E999, - + egs/ljspeech/TTS/vits/*.py: E501, E203 # invalid escape sequence (cause by tex formular), W605 icefall/utils.py: E501, W605 diff --git a/docs/source/recipes/TTS/index.rst b/docs/source/recipes/TTS/index.rst new file mode 100644 index 000000000..aa891c072 --- /dev/null +++ b/docs/source/recipes/TTS/index.rst @@ -0,0 +1,7 @@ +TTS +====== + +.. toctree:: + :maxdepth: 2 + + ljspeech/vits diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst new file mode 100644 index 000000000..385fd3c70 --- /dev/null +++ b/docs/source/recipes/TTS/ljspeech/vits.rst @@ -0,0 +1,113 @@ +VITS +=============== + +This tutorial shows you how to train an VITS model +with the `LJSpeech `_ dataset. + +.. note:: + + The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_ + + +Data preparation +---------------- + +.. code-block:: bash + + $ cd egs/ljspeech/TTS + $ ./prepare.sh + +To run stage 1 to stage 5, use + +.. code-block:: bash + + $ ./prepare.sh --stage 1 --stop_stage 5 + + +Build Monotonic Alignment Search +-------------------------------- + +.. code-block:: bash + + $ cd vits/monotonic_align + $ python setup.py build_ext --inplace + $ cd ../../ + + +Training +-------- + +.. 
code-block:: bash + + $ export CUDA_VISIBLE_DEVICES="0,1,2,3" + $ ./vits/train.py \ + --world-size 4 \ + --num-epochs 1000 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + --max-duration 500 + +.. note:: + + You can adjust the hyper-parameters to control the size of the VITS model and + the training configurations. For more details, please run ``./vits/train.py --help``. + +.. note:: + + The training can take a long time (usually a couple of days). + +Training logs, checkpoints and tensorboard logs are saved in ``vits/exp``. + + +Inference +--------- + +The inference part uses checkpoints saved by the training part, so you have to run the +training part first. It will save the ground-truth and generated wavs to the directory +``vits/exp/infer/epoch-*/wav``, e.g., ``vits/exp/infer/epoch-1000/wav``. + +.. code-block:: bash + + $ export CUDA_VISIBLE_DEVICES="0" + $ ./vits/infer.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + --max-duration 500 + +.. note:: + + For more details, please run ``./vits/infer.py --help``. + + +Export models +------------- + +Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``: +``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``. + +.. code-block:: bash + + $ ./vits/export-onnx.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + +You can test the exported ONNX model with: + +.. code-block:: bash + + $ ./vits/test_onnx.py \ + --model-filename vits/exp/vits-epoch-1000.onnx \ + --tokens data/tokens.txt + + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following link: + + - ``_ diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst index 7265e1cf6..8df61f0d0 100644 --- a/docs/source/recipes/index.rst +++ b/docs/source/recipes/index.rst @@ -2,7 +2,7 @@ Recipes ======= This page contains various recipes in ``icefall``. -Currently, only speech recognition recipes are provided. +Currently, we provide recipes for speech recognition, language model, and speech synthesis. We may add recipes for other tasks as well in the future. @@ -16,3 +16,4 @@ We may add recipes for other tasks as well in the future. Non-streaming-ASR/index Streaming-ASR/index RNN-LM/index + TTS/index diff --git a/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py new file mode 100755 index 000000000..97c9008fc --- /dev/null +++ b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the LJSpeech dataset. +It looks for manifests in the directory data/manifests. 
+ +The generated spectrogram features are saved in data/spectrogram. +""" + +import logging +import os +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + LilcomChunkyWriter, + Spectrogram, + SpectrogramConfig, + load_manifest, +) +from lhotse.audio import RecordingSet +from lhotse.supervision import SupervisionSet + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_spectrogram_ljspeech(): + src_dir = Path("data/manifests") + output_dir = Path("data/spectrogram") + num_jobs = min(4, os.cpu_count()) + + sampling_rate = 22050 + frame_length = 1024 / sampling_rate # (in second) + frame_shift = 256 / sampling_rate # (in second) + use_fft_mag = True + + prefix = "ljspeech" + suffix = "jsonl.gz" + partition = "all" + + recordings = load_manifest( + src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet + ) + supervisions = load_manifest( + src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet + ) + + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=frame_length, + frame_shift=frame_shift, + use_fft_mag=use_fft_mag, + ) + extractor = Spectrogram(config) + + with get_executor() as ex: # Initialize the executor only once. + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{cuts_filename} already exists - skipping.") + return + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=recordings, supervisions=supervisions + ) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_spectrogram_ljspeech() diff --git a/egs/ljspeech/TTS/local/display_manifest_statistics.py b/egs/ljspeech/TTS/local/display_manifest_statistics.py new file mode 100755 index 000000000..93f0044f0 --- /dev/null +++ b/egs/ljspeech/TTS/local/display_manifest_statistics.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. 
+ +See the function `remove_short_and_long_utt()` in vits/train.py +for usage. +""" + + +from lhotse import load_manifest_lazy + + +def main(): + path = "./data/spectrogram/ljspeech_cuts_all.jsonl.gz" + cuts = load_manifest_lazy(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +Cut statistics: + ╒═══════════════════════════╤══════════╕ + │ Cuts count: │ 13100 │ + ├───────────────────────────┼──────────┤ + │ Total duration (hh:mm:ss) │ 23:55:18 │ + ├───────────────────────────┼──────────┤ + │ mean │ 6.6 │ + ├───────────────────────────┼──────────┤ + │ std │ 2.2 │ + ├───────────────────────────┼──────────┤ + │ min │ 1.1 │ + ├───────────────────────────┼──────────┤ + │ 25% │ 5.0 │ + ├───────────────────────────┼──────────┤ + │ 50% │ 6.8 │ + ├───────────────────────────┼──────────┤ + │ 75% │ 8.4 │ + ├───────────────────────────┼──────────┤ + │ 99% │ 10.0 │ + ├───────────────────────────┼──────────┤ + │ 99.5% │ 10.1 │ + ├───────────────────────────┼──────────┤ + │ 99.9% │ 10.1 │ + ├───────────────────────────┼──────────┤ + │ max │ 10.1 │ + ├───────────────────────────┼──────────┤ + │ Recordings available: │ 13100 │ + ├───────────────────────────┼──────────┤ + │ Features available: │ 13100 │ + ├───────────────────────────┼──────────┤ + │ Supervisions available: │ 13100 │ + ╘═══════════════════════════╧══════════╛ +""" diff --git a/egs/ljspeech/TTS/local/prepare_token_file.py b/egs/ljspeech/TTS/local/prepare_token_file.py new file mode 100755 index 000000000..df976804a --- /dev/null +++ b/egs/ljspeech/TTS/local/prepare_token_file.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and generates the file that maps tokens to IDs. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict + +from lhotse import load_manifest + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--manifest-file", + type=Path, + default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"), + help="Path to the manifest file", + ) + + parser.add_argument( + "--tokens", + type=Path, + default=Path("data/tokens.txt"), + help="Path to the tokens", + ) + + return parser.parse_args() + + +def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: + """Write a symbol to ID mapping to a file. + + Note: + No need to implement `read_mapping` as it can be done + through :func:`k2.SymbolTable.from_file`. + + Args: + filename: + Filename to save the mapping. + sym2id: + A dict mapping symbols to IDs. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf-8") as f: + for sym, i in sym2id.items(): + f.write(f"{sym} {i}\n") + + +def get_token2id(manifest_file: Path) -> Dict[str, int]: + """Return a dict that maps token to IDs.""" + extra_tokens = [ + "", # 0 for blank + "", # 1 for sos and eos symbols. 
+ "", # 2 for OOV + ] + all_tokens = set() + + cut_set = load_manifest(manifest_file) + + for cut in cut_set: + # Each cut only contain one supervision + assert len(cut.supervisions) == 1, len(cut.supervisions) + for t in cut.tokens: + all_tokens.add(t) + + all_tokens = extra_tokens + list(all_tokens) + + token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)} + return token2id + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + manifest_file = Path(args.manifest_file) + out_file = Path(args.tokens) + + token2id = get_token2id(manifest_file) + write_mapping(out_file, token2id) diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py new file mode 100755 index 000000000..fcd0137a0 --- /dev/null +++ b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and save the new cuts with phoneme tokens. +""" + +import logging +from pathlib import Path + +import g2p_en +import tacotron_cleaner.cleaners +from lhotse import CutSet, load_manifest + + +def prepare_tokens_ljspeech(): + output_dir = Path("data/spectrogram") + prefix = "ljspeech" + suffix = "jsonl.gz" + partition = "all" + + cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + g2p = g2p_en.G2p() + + new_cuts = [] + for cut in cut_set: + # Each cut only contains one supervision + assert len(cut.supervisions) == 1, len(cut.supervisions) + text = cut.supervisions[0].normalized_text + # Text normalization + text = tacotron_cleaner.cleaners.custom_english_cleaners(text) + # Convert to phonemes + cut.tokens = g2p(text) + new_cuts.append(cut) + + new_cut_set = CutSet.from_cuts(new_cuts) + new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + prepare_tokens_ljspeech() diff --git a/egs/ljspeech/TTS/local/validate_manifest.py b/egs/ljspeech/TTS/local/validate_manifest.py new file mode 100755 index 000000000..68159ae03 --- /dev/null +++ b/egs/ljspeech/TTS/local/validate_manifest.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut + +We will add more checks later if needed. + +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/spectrogram/ljspeech_cuts_all.jsonl.gz + +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset.speech_synthesis import validate_for_tts + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + manifest = args.manifest + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest_lazy(manifest) + assert isinstance(cut_set, CutSet), type(cut_set) + + validate_for_tts(cut_set) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh new file mode 100755 index 000000000..8ee40896e --- /dev/null +++ b/egs/ljspeech/TTS/prepare.sh @@ -0,0 +1,117 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=1 +stage=-1 +stop_stage=100 + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # The directory $dl_dir/LJSpeech-1.1 will contain: + # - wavs, which contains the audio files + # - metadata.csv, which provides the transcript text for each audio clip + + # If you have pre-downloaded it to /path/to/LJSpeech-1.1, you can create a symlink + # + # ln -sfv /path/to/LJSpeech-1.1 $dl_dir/LJSpeech-1.1 + # + if [ ! -d $dl_dir/LJSpeech-1.1 ]; then + lhotse download ljspeech $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare LJSpeech manifest" + # We assume that you have downloaded the LJSpeech corpus + # to $dl_dir/LJSpeech + mkdir -p data/manifests + if [ ! -e data/manifests/.ljspeech.done ]; then + lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests + touch data/manifests/.ljspeech.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compute spectrogram for LJSpeech" + mkdir -p data/spectrogram + if [ ! -e data/spectrogram/.ljspeech.done ]; then + ./local/compute_spectrogram_ljspeech.py + touch data/spectrogram/.ljspeech.done + fi + + if [ ! 
-e data/spectrogram/.ljspeech-validated.done ]; then + log "Validating data/spectrogram for LJSpeech" + python3 ./local/validate_manifest.py \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz + touch data/spectrogram/.ljspeech-validated.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Prepare phoneme tokens for LJSpeech" + if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then + ./local/prepare_tokens_ljspeech.py + mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz + touch data/spectrogram/.ljspeech_with_token.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Split the LJSpeech cuts into train, valid and test sets" + if [ ! -e data/spectrogram/.ljspeech_split.done ]; then + lhotse subset --last 600 \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz \ + data/spectrogram/ljspeech_cuts_validtest.jsonl.gz + lhotse subset --first 100 \ + data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \ + data/spectrogram/ljspeech_cuts_valid.jsonl.gz + lhotse subset --last 500 \ + data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \ + data/spectrogram/ljspeech_cuts_test.jsonl.gz + + rm data/spectrogram/ljspeech_cuts_validtest.jsonl.gz + + n=$(( $(gunzip -c data/spectrogram/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 )) + lhotse subset --first $n \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz \ + data/spectrogram/ljspeech_cuts_train.jsonl.gz + touch data/spectrogram/.ljspeech_split.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Generate token file" + # We assume you have installed g2p_en and espnet_tts_frontend. + # If not, please install them with: + # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p + # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ + if [ ! -e data/tokens.txt ]; then + ./local/prepare_token_file.py \ + --manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \ + --tokens data/tokens.txt + fi +fi + + diff --git a/egs/ljspeech/TTS/shared/parse_options.sh b/egs/ljspeech/TTS/shared/parse_options.sh new file mode 120000 index 000000000..e4665e7de --- /dev/null +++ b/egs/ljspeech/TTS/shared/parse_options.sh @@ -0,0 +1 @@ +../../../librispeech/ASR/shared/parse_options.sh \ No newline at end of file diff --git a/egs/ljspeech/TTS/vits/README.md b/egs/ljspeech/TTS/vits/README.md new file mode 100644 index 000000000..1141326b9 --- /dev/null +++ b/egs/ljspeech/TTS/vits/README.md @@ -0,0 +1,3 @@ +See https://k2-fsa.github.io/icefall/recipes/TTS/ljspeech/vits.html for detailed tutorials. + +Training logs, Tensorboard logs, and checkpoints are uploaded to https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29. diff --git a/egs/ljspeech/TTS/vits/duration_predictor.py b/egs/ljspeech/TTS/vits/duration_predictor.py new file mode 100644 index 000000000..c29a28479 --- /dev/null +++ b/egs/ljspeech/TTS/vits/duration_predictor.py @@ -0,0 +1,194 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Stochastic duration predictor modules in VITS. + +This code is based on https://github.com/jaywalnut310/vits. 
+ +""" + +import math +from typing import Optional + +import torch +import torch.nn.functional as F + +from flow import ( + ConvFlow, + DilatedDepthSeparableConv, + ElementwiseAffineFlow, + FlipFlow, + LogFlow, +) + + +class StochasticDurationPredictor(torch.nn.Module): + """Stochastic duration predictor module. + + This is a module of stochastic duration predictor described in `Conditional + Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_. + + .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`: https://arxiv.org/abs/2006.04558 + + """ + + def __init__( + self, + channels: int = 192, + kernel_size: int = 3, + dropout_rate: float = 0.5, + flows: int = 4, + dds_conv_layers: int = 3, + global_channels: int = -1, + ): + """Initialize StochasticDurationPredictor module. + + Args: + channels (int): Number of channels. + kernel_size (int): Kernel size. + dropout_rate (float): Dropout rate. + flows (int): Number of flows. + dds_conv_layers (int): Number of conv layers in DDS conv. + global_channels (int): Number of global conditioning channels. + + """ + super().__init__() + + self.pre = torch.nn.Conv1d(channels, channels, 1) + self.dds = DilatedDepthSeparableConv( + channels, + kernel_size, + layers=dds_conv_layers, + dropout_rate=dropout_rate, + ) + self.proj = torch.nn.Conv1d(channels, channels, 1) + + self.log_flow = LogFlow() + self.flows = torch.nn.ModuleList() + self.flows += [ElementwiseAffineFlow(2)] + for i in range(flows): + self.flows += [ + ConvFlow( + 2, + channels, + kernel_size, + layers=dds_conv_layers, + ) + ] + self.flows += [FlipFlow()] + + self.post_pre = torch.nn.Conv1d(1, channels, 1) + self.post_dds = DilatedDepthSeparableConv( + channels, + kernel_size, + layers=dds_conv_layers, + dropout_rate=dropout_rate, + ) + self.post_proj = torch.nn.Conv1d(channels, channels, 1) + self.post_flows = torch.nn.ModuleList() + self.post_flows += [ElementwiseAffineFlow(2)] + for i in range(flows): + self.post_flows += [ + ConvFlow( + 2, + channels, + kernel_size, + layers=dds_conv_layers, + ) + ] + self.post_flows += [FlipFlow()] + + if global_channels > 0: + self.global_conv = torch.nn.Conv1d(global_channels, channels, 1) + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + w: Optional[torch.Tensor] = None, + g: Optional[torch.Tensor] = None, + inverse: bool = False, + noise_scale: float = 1.0, + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T_text). + x_mask (Tensor): Mask tensor (B, 1, T_text). + w (Optional[Tensor]): Duration tensor (B, 1, T_text). + g (Optional[Tensor]): Global conditioning tensor (B, channels, 1) + inverse (bool): Whether to inverse the flow. + noise_scale (float): Noise scale value. + + Returns: + Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,). + If inverse, log-duration tensor (B, 1, T_text). + + """ + x = x.detach() # stop gradient + x = self.pre(x) + if g is not None: + x = x + self.global_conv(g.detach()) # stop gradient + x = self.dds(x, x_mask) + x = self.proj(x) * x_mask + + if not inverse: + assert w is not None, "w must be provided." 
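For orientation, the duration predictor can be exercised on its own. A minimal usage sketch, assuming it is run from `egs/ljspeech/TTS/vits` so that `duration_predictor.py` and its `flow` dependencies are importable, and using the module's default hyper-parameters; the tensor contents are toy data, not taken from the recipe:

```python
import torch
from duration_predictor import StochasticDurationPredictor

dp = StochasticDurationPredictor()           # channels=192, 4 flows by default

B, C, T_text = 2, 192, 11
x = torch.randn(B, C, T_text)                # text-encoder hidden states
x_mask = torch.ones(B, 1, T_text)            # no padding in this toy batch
w = torch.rand(B, 1, T_text) + 1.0           # toy "ground-truth" durations

nll = dp(x, x_mask, w=w)                     # training branch: NLL, shape (B,)
log_dur = dp(x, x_mask, inverse=True, noise_scale=0.8)  # inference: (B, 1, T_text)
print(nll.shape, log_dur.shape)
```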
+ h_w = self.post_pre(w) + h_w = self.post_dds(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = ( + torch.randn( + w.size(0), + 2, + w.size(2), + ).to(device=x.device, dtype=x.dtype) + * x_mask + ) + z_q = e_q + logdet_tot_q = 0.0 + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum( + (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2] + ) + logq = ( + torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) + - logdet_tot_q + ) + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in self.flows: + z, logdet = flow(z, x_mask, g=x, inverse=inverse) + logdet_tot = logdet_tot + logdet + nll = ( + torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) + - logdet_tot + ) + return nll + logq # (B,) + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = ( + torch.randn( + x.size(0), + 2, + x.size(2), + ).to(device=x.device, dtype=x.dtype) + * noise_scale + ) + for flow in flows: + z = flow(z, x_mask, g=x, inverse=inverse) + z0, z1 = z.split(1, 1) + logw = z0 + return logw diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py new file mode 100755 index 000000000..154de4bf4 --- /dev/null +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script exports a VITS model from PyTorch to ONNX. + +Export the model to ONNX: +./vits/export-onnx.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + +It will generate two files inside vits/exp: + - vits-epoch-1000.onnx + - vits-epoch-1000.int8.onnx (quantizated model) + +See ./test_onnx.py for how to use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import torch +import torch.nn as nn +from onnxruntime.quantization import QuantType, quantize_dynamic +from tokenizer import Tokenizer +from train import get_model, get_params + +from icefall.checkpoint import load_checkpoint + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxModel(nn.Module): + """A wrapper for VITS generator.""" + + def __init__(self, model: nn.Module): + """ + Args: + model: + A VITS generator. + frame_shift: + The frame shift in samples. + """ + super().__init__() + self.model = model + + def forward( + self, + tokens: torch.Tensor, + tokens_lens: torch.Tensor, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of VITS.inference_batch + + Args: + tokens: + Input text token indexes (1, T_text) + tokens_lens: + Number of tokens of shape (1,) + noise_scale (float): + Noise scale parameter for flow. + noise_scale_dur (float): + Noise scale parameter for duration predictor. + alpha (float): + Alpha parameter to control the speed of generated speech. + + Returns: + Return a tuple containing: + - audio, generated wavform tensor, (B, T_wav) + """ + audio, _, _ = self.model.inference( + text=tokens, + text_lengths=tokens_lens, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + alpha=alpha, + ) + return audio + + +def export_model_onnx( + model: nn.Module, + model_filename: str, + opset_version: int = 11, +) -> None: + """Export the given generator model to ONNX format. + The exported model has one input: + + - tokens, a tensor of shape (1, T_text); dtype is torch.int64 + + and it has one output: + + - audio, a tensor of shape (1, T'); dtype is torch.float32 + + Args: + model: + The VITS generator. + model_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) + noise_scale = torch.tensor([1], dtype=torch.float32) + noise_scale_dur = torch.tensor([1], dtype=torch.float32) + alpha = torch.tensor([1], dtype=torch.float32) + + torch.onnx.export( + model, + (tokens, tokens_lens, noise_scale, noise_scale_dur, alpha), + model_filename, + verbose=False, + opset_version=opset_version, + input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"], + output_names=["audio"], + dynamic_axes={ + "tokens": {0: "N", 1: "T"}, + "tokens_lens": {0: "N"}, + "audio": {0: "N", 1: "T"}, + }, + ) + + meta_data = { + "model_type": "VITS", + "version": "1", + "model_author": "k2-fsa", + "comment": "VITS generator", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=model_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model = model.generator + model.to("cpu") + model.eval() + + model = OnnxModel(model=model) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"generator parameters: {num_param}") + + suffix = f"epoch-{params.epoch}" + + opset_version = 13 + + logging.info("Exporting encoder") + model_filename = params.exp_dir / f"vits-{suffix}.onnx" + export_model_onnx( + model, + model_filename, + opset_version=opset_version, + ) + logging.info(f"Exported generator to {model_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx" + quantize_dynamic( + model_input=model_filename, + model_output=model_filename_int8, + weight_type=QuantType.QUInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/ljspeech/TTS/vits/flow.py b/egs/ljspeech/TTS/vits/flow.py new file mode 100644 index 000000000..206bd5e3e --- /dev/null +++ b/egs/ljspeech/TTS/vits/flow.py @@ -0,0 +1,312 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Basic Flow modules used in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + +import math +from typing import Optional, Tuple, Union + +import torch + +from transform import piecewise_rational_quadratic_transform + + +class FlipFlow(torch.nn.Module): + """Flip flow module.""" + + def forward( + self, x: torch.Tensor, *args, inverse: bool = False, **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Flipped tensor (B, channels, T). 
+ Tensor: Log-determinant tensor for NLL (B,) if not inverse. + + """ + x = torch.flip(x, [1]) + if not inverse: + logdet = x.new_zeros(x.size(0)) + return x, logdet + else: + return x + + +class LogFlow(torch.nn.Module): + """Log flow module.""" + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + inverse: bool = False, + eps: float = 1e-5, + **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + x_mask (Tensor): Mask tensor (B, 1, T). + inverse (bool): Whether to inverse the flow. + eps (float): Epsilon for log. + + Returns: + Tensor: Output tensor (B, channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. + + """ + if not inverse: + y = torch.log(torch.clamp_min(x, eps)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class ElementwiseAffineFlow(torch.nn.Module): + """Elementwise affine flow module.""" + + def __init__(self, channels: int): + """Initialize ElementwiseAffineFlow module. + + Args: + channels (int): Number of channels. + + """ + super().__init__() + self.channels = channels + self.register_parameter("m", torch.nn.Parameter(torch.zeros(channels, 1))) + self.register_parameter("logs", torch.nn.Parameter(torch.zeros(channels, 1))) + + def forward( + self, x: torch.Tensor, x_mask: torch.Tensor, inverse: bool = False, **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + x_lengths (Tensor): Length tensor (B,). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Output tensor (B, channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. + + """ + if not inverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class Transpose(torch.nn.Module): + """Transpose module for torch.nn.Sequential().""" + + def __init__(self, dim1: int, dim2: int): + """Initialize Transpose module.""" + super().__init__() + self.dim1 = dim1 + self.dim2 = dim2 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Transpose.""" + return x.transpose(self.dim1, self.dim2) + + +class DilatedDepthSeparableConv(torch.nn.Module): + """Dilated depth-separable conv module.""" + + def __init__( + self, + channels: int, + kernel_size: int, + layers: int, + dropout_rate: float = 0.0, + eps: float = 1e-5, + ): + """Initialize DilatedDepthSeparableConv module. + + Args: + channels (int): Number of channels. + kernel_size (int): Kernel size. + layers (int): Number of layers. + dropout_rate (float): Dropout rate. + eps (float): Epsilon for layer norm. 
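The elementwise affine flow above is the simplest place to see the log-determinant bookkeeping shared by these modules: the Jacobian of `y = m + exp(logs) * x` is diagonal, so its log-determinant is just `logs` summed over the masked positions. A small standalone sanity check (illustrative only, independent of the classes above):

```python
# Change-of-variables bookkeeping for the affine flow, checked numerically.
import torch

B, C, T = 2, 4, 6
x = torch.randn(B, C, T)
x_mask = torch.ones(B, 1, T)
m, logs = torch.randn(C, 1), 0.1 * torch.randn(C, 1)

y = (m + torch.exp(logs) * x) * x_mask        # forward direction
logdet = torch.sum(logs * x_mask, [1, 2])     # log|det dy/dx|, shape (B,)
x_rec = (y - m) * torch.exp(-logs) * x_mask   # inverse direction recovers x

assert torch.allclose(x_rec, x, atol=1e-4)
assert logdet.shape == (B,)
```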
+ + """ + super().__init__() + + self.convs = torch.nn.ModuleList() + for i in range(layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs += [ + torch.nn.Sequential( + torch.nn.Conv1d( + channels, + channels, + kernel_size, + groups=channels, + dilation=dilation, + padding=padding, + ), + Transpose(1, 2), + torch.nn.LayerNorm( + channels, + eps=eps, + elementwise_affine=True, + ), + Transpose(1, 2), + torch.nn.GELU(), + torch.nn.Conv1d( + channels, + channels, + 1, + ), + Transpose(1, 2), + torch.nn.LayerNorm( + channels, + eps=eps, + elementwise_affine=True, + ), + Transpose(1, 2), + torch.nn.GELU(), + torch.nn.Dropout(dropout_rate), + ) + ] + + def forward( + self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T). + x_mask (Tensor): Mask tensor (B, 1, T). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Output tensor (B, channels, T). + + """ + if g is not None: + x = x + g + for f in self.convs: + y = f(x * x_mask) + x = x + y + return x * x_mask + + +class ConvFlow(torch.nn.Module): + """Convolutional flow module.""" + + def __init__( + self, + in_channels: int, + hidden_channels: int, + kernel_size: int, + layers: int, + bins: int = 10, + tail_bound: float = 5.0, + ): + """Initialize ConvFlow module. + + Args: + in_channels (int): Number of input channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size. + layers (int): Number of layers. + bins (int): Number of bins. + tail_bound (float): Tail bound value. + + """ + super().__init__() + self.half_channels = in_channels // 2 + self.hidden_channels = hidden_channels + self.bins = bins + self.tail_bound = tail_bound + + self.input_conv = torch.nn.Conv1d( + self.half_channels, + hidden_channels, + 1, + ) + self.dds_conv = DilatedDepthSeparableConv( + hidden_channels, + kernel_size, + layers, + dropout_rate=0.0, + ) + self.proj = torch.nn.Conv1d( + hidden_channels, + self.half_channels * (bins * 3 - 1), + 1, + ) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + inverse: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + x_mask (Tensor): Mask tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, channels, 1). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Output tensor (B, channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. 
+ + """ + xa, xb = x.split(x.size(1) // 2, 1) + h = self.input_conv(xa) + h = self.dds_conv(h, x_mask, g=g) + h = self.proj(h) * x_mask # (B, half_channels * (bins * 3 - 1), T) + + b, c, t = xa.shape + # (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1) + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) + + # TODO(kan-bayashi): Understand this calculation + denom = math.sqrt(self.hidden_channels) + unnorm_widths = h[..., : self.bins] / denom + unnorm_heights = h[..., self.bins : 2 * self.bins] / denom + unnorm_derivatives = h[..., 2 * self.bins :] + xb, logdet_abs = piecewise_rational_quadratic_transform( + xb, + unnorm_widths, + unnorm_heights, + unnorm_derivatives, + inverse=inverse, + tails="linear", + tail_bound=self.tail_bound, + ) + x = torch.cat([xa, xb], 1) * x_mask + logdet = torch.sum(logdet_abs * x_mask, [1, 2]) + if not inverse: + return x, logdet + else: + return x diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py new file mode 100644 index 000000000..efb0e254c --- /dev/null +++ b/egs/ljspeech/TTS/vits/generator.py @@ -0,0 +1,531 @@ +# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Generator module in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + + +import math +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F + +from icefall.utils import make_pad_mask + +from duration_predictor import StochasticDurationPredictor +from hifigan import HiFiGANGenerator +from posterior_encoder import PosteriorEncoder +from residual_coupling import ResidualAffineCouplingBlock +from text_encoder import TextEncoder +from utils import get_random_segments + + +class VITSGenerator(torch.nn.Module): + """Generator module in VITS, `Conditional Variational Autoencoder + with Adversarial Learning for End-to-End Text-to-Speech`. 
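`ConvFlow` is a coupling-style layer: `xa` passes through unchanged and parameterizes a rational-quadratic spline that warps `xb`, which is why the inverse pass can simply re-run the spline with `inverse=True`. A rough invertibility check, assuming the recipe's `vits` directory (with `flow.py` and `transform.py`) is importable:

```python
import torch
from flow import ConvFlow  # this file; relies on the recipe's transform.py

flow = ConvFlow(in_channels=4, hidden_channels=16, kernel_size=3, layers=2)
x = torch.randn(2, 4, 10)         # (B, channels, T)
x_mask = torch.ones(2, 1, 10)     # (B, 1, T)

y, logdet = flow(x, x_mask)               # forward returns (output, logdet)
x_rec = flow(y, x_mask, inverse=True)     # inverse returns only the input
assert torch.allclose(x_rec, x, atol=1e-4)
```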
+ """ + + def __init__( + self, + vocabs: int, + aux_channels: int = 513, + hidden_channels: int = 192, + spks: Optional[int] = None, + langs: Optional[int] = None, + spk_embed_dim: Optional[int] = None, + global_channels: int = -1, + segment_size: int = 32, + text_encoder_attention_heads: int = 2, + text_encoder_ffn_expand: int = 4, + text_encoder_cnn_module_kernel: int = 5, + text_encoder_blocks: int = 6, + text_encoder_dropout_rate: float = 0.1, + decoder_kernel_size: int = 7, + decoder_channels: int = 512, + decoder_upsample_scales: List[int] = [8, 8, 2, 2], + decoder_upsample_kernel_sizes: List[int] = [16, 16, 4, 4], + decoder_resblock_kernel_sizes: List[int] = [3, 7, 11], + decoder_resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + use_weight_norm_in_decoder: bool = True, + posterior_encoder_kernel_size: int = 5, + posterior_encoder_layers: int = 16, + posterior_encoder_stacks: int = 1, + posterior_encoder_base_dilation: int = 1, + posterior_encoder_dropout_rate: float = 0.0, + use_weight_norm_in_posterior_encoder: bool = True, + flow_flows: int = 4, + flow_kernel_size: int = 5, + flow_base_dilation: int = 1, + flow_layers: int = 4, + flow_dropout_rate: float = 0.0, + use_weight_norm_in_flow: bool = True, + use_only_mean_in_flow: bool = True, + stochastic_duration_predictor_kernel_size: int = 3, + stochastic_duration_predictor_dropout_rate: float = 0.5, + stochastic_duration_predictor_flows: int = 4, + stochastic_duration_predictor_dds_conv_layers: int = 3, + ): + """Initialize VITS generator module. + + Args: + vocabs (int): Input vocabulary size. + aux_channels (int): Number of acoustic feature channels. + hidden_channels (int): Number of hidden channels. + spks (Optional[int]): Number of speakers. If set to > 1, assume that the + sids will be provided as the input and use sid embedding layer. + langs (Optional[int]): Number of languages. If set to > 1, assume that the + lids will be provided as the input and use sid embedding layer. + spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0, + assume that spembs will be provided as the input. + global_channels (int): Number of global conditioning channels. + segment_size (int): Segment size for decoder. + text_encoder_attention_heads (int): Number of heads in conformer block + of text encoder. + text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block + of text encoder. + text_encoder_cnn_module_kernel (int): Convolution kernel size in text encoder. + text_encoder_blocks (int): Number of conformer blocks in text encoder. + text_encoder_dropout_rate (float): Dropout rate in conformer block of + text encoder. + decoder_kernel_size (int): Decoder kernel size. + decoder_channels (int): Number of decoder initial channels. + decoder_upsample_scales (List[int]): List of upsampling scales in decoder. + decoder_upsample_kernel_sizes (List[int]): List of kernel size for + upsampling layers in decoder. + decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks + in decoder. + decoder_resblock_dilations (List[List[int]]): List of list of dilations for + resblocks in decoder. + use_weight_norm_in_decoder (bool): Whether to apply weight normalization in + decoder. + posterior_encoder_kernel_size (int): Posterior encoder kernel size. + posterior_encoder_layers (int): Number of layers of posterior encoder. + posterior_encoder_stacks (int): Number of stacks of posterior encoder. + posterior_encoder_base_dilation (int): Base dilation of posterior encoder. 
+ posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder. + use_weight_norm_in_posterior_encoder (bool): Whether to apply weight + normalization in posterior encoder. + flow_flows (int): Number of flows in flow. + flow_kernel_size (int): Kernel size in flow. + flow_base_dilation (int): Base dilation in flow. + flow_layers (int): Number of layers in flow. + flow_dropout_rate (float): Dropout rate in flow + use_weight_norm_in_flow (bool): Whether to apply weight normalization in + flow. + use_only_mean_in_flow (bool): Whether to use only mean in flow. + stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic + duration predictor. + stochastic_duration_predictor_dropout_rate (float): Dropout rate in + stochastic duration predictor. + stochastic_duration_predictor_flows (int): Number of flows in stochastic + duration predictor. + stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv + layers in stochastic duration predictor. + + """ + super().__init__() + self.segment_size = segment_size + self.text_encoder = TextEncoder( + vocabs=vocabs, + d_model=hidden_channels, + num_heads=text_encoder_attention_heads, + dim_feedforward=hidden_channels * text_encoder_ffn_expand, + cnn_module_kernel=text_encoder_cnn_module_kernel, + num_layers=text_encoder_blocks, + dropout=text_encoder_dropout_rate, + ) + self.decoder = HiFiGANGenerator( + in_channels=hidden_channels, + out_channels=1, + channels=decoder_channels, + global_channels=global_channels, + kernel_size=decoder_kernel_size, + upsample_scales=decoder_upsample_scales, + upsample_kernel_sizes=decoder_upsample_kernel_sizes, + resblock_kernel_sizes=decoder_resblock_kernel_sizes, + resblock_dilations=decoder_resblock_dilations, + use_weight_norm=use_weight_norm_in_decoder, + ) + self.posterior_encoder = PosteriorEncoder( + in_channels=aux_channels, + out_channels=hidden_channels, + hidden_channels=hidden_channels, + kernel_size=posterior_encoder_kernel_size, + layers=posterior_encoder_layers, + stacks=posterior_encoder_stacks, + base_dilation=posterior_encoder_base_dilation, + global_channels=global_channels, + dropout_rate=posterior_encoder_dropout_rate, + use_weight_norm=use_weight_norm_in_posterior_encoder, + ) + self.flow = ResidualAffineCouplingBlock( + in_channels=hidden_channels, + hidden_channels=hidden_channels, + flows=flow_flows, + kernel_size=flow_kernel_size, + base_dilation=flow_base_dilation, + layers=flow_layers, + global_channels=global_channels, + dropout_rate=flow_dropout_rate, + use_weight_norm=use_weight_norm_in_flow, + use_only_mean=use_only_mean_in_flow, + ) + # TODO(kan-bayashi): Add deterministic version as an option + self.duration_predictor = StochasticDurationPredictor( + channels=hidden_channels, + kernel_size=stochastic_duration_predictor_kernel_size, + dropout_rate=stochastic_duration_predictor_dropout_rate, + flows=stochastic_duration_predictor_flows, + dds_conv_layers=stochastic_duration_predictor_dds_conv_layers, + global_channels=global_channels, + ) + + self.upsample_factor = int(np.prod(decoder_upsample_scales)) + self.spks = None + if spks is not None and spks > 1: + assert global_channels > 0 + self.spks = spks + self.global_emb = torch.nn.Embedding(spks, global_channels) + self.spk_embed_dim = None + if spk_embed_dim is not None and spk_embed_dim > 0: + assert global_channels > 0 + self.spk_embed_dim = spk_embed_dim + self.spemb_proj = torch.nn.Linear(spk_embed_dim, global_channels) + self.langs = None + if langs is not None and langs > 1: + assert 
global_channels > 0 + self.langs = langs + self.lang_emb = torch.nn.Embedding(langs, global_channels) + + # delayed import + from monotonic_align import maximum_path + + self.maximum_path = maximum_path + + def forward( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], + ]: + """Calculate forward propagation. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, aux_channels, T_feats). + feats_lengths (Tensor): Feature length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + + Returns: + Tensor: Waveform tensor (B, 1, segment_size * upsample_factor). + Tensor: Duration negative log-likelihood (NLL) tensor (B,). + Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text). + Tensor: Segments start index tensor (B,). + Tensor: Text mask tensor (B, 1, T_text). + Tensor: Feature mask tensor (B, 1, T_feats). + tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + - Tensor: Posterior encoder hidden representation (B, H, T_feats). + - Tensor: Flow hidden representation (B, H, T_feats). + - Tensor: Expanded text encoder projected mean (B, H, T_feats). + - Tensor: Expanded text encoder projected scale (B, H, T_feats). + - Tensor: Posterior encoder projected mean (B, H, T_feats). + - Tensor: Posterior encoder projected scale (B, H, T_feats). 
+ + """ + # forward text encoder + x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths) + + # calculate global conditioning + g = None + if self.spks is not None: + # speaker one-hot vector embedding: (B, global_channels, 1) + g = self.global_emb(sids.view(-1)).unsqueeze(-1) + if self.spk_embed_dim is not None: + # pretreined speaker embedding, e.g., X-vector (B, global_channels, 1) + g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + if self.langs is not None: + # language one-hot vector embedding: (B, global_channels, 1) + g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + + # forward posterior encoder + z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g) + + # forward flow + z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats) + + # monotonic alignment search + with torch.no_grad(): + # negative cross-entropy + s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text) + # (B, 1, T_text) + neg_x_ent_1 = torch.sum( + -0.5 * math.log(2 * math.pi) - logs_p, + [1], + keepdim=True, + ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_2 = torch.matmul( + -0.5 * (z_p**2).transpose(1, 2), + s_p_sq_r, + ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_3 = torch.matmul( + z_p.transpose(1, 2), + (m_p * s_p_sq_r), + ) + # (B, 1, T_text) + neg_x_ent_4 = torch.sum( + -0.5 * (m_p**2) * s_p_sq_r, + [1], + keepdim=True, + ) + # (B, T_feats, T_text) + neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4 + # (B, 1, T_feats, T_text) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + # monotonic attention weight: (B, 1, T_feats, T_text) + attn = ( + self.maximum_path( + neg_x_ent, + attn_mask.squeeze(1), + ) + .unsqueeze(1) + .detach() + ) + + # forward duration predictor + w = attn.sum(2) # (B, 1, T_text) + dur_nll = self.duration_predictor(x, x_mask, w=w, g=g) + dur_nll = dur_nll / torch.sum(x_mask) + + # expand the length to match with the feature sequence + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) + + # get random segments + z_segments, z_start_idxs = get_random_segments( + z, + feats_lengths, + self.segment_size, + ) + + # forward decoder with random segments + wav = self.decoder(z_segments, g=g) + + return ( + wav, + dur_nll, + attn, + z_start_idxs, + x_mask, + y_mask, + (z, z_p, m_p, logs_p, m_q, logs_q), + ) + + def inference( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: Optional[torch.Tensor] = None, + feats_lengths: Optional[torch.Tensor] = None, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + dur: Optional[torch.Tensor] = None, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + max_len: Optional[int] = None, + use_teacher_forcing: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run inference. + + Args: + text (Tensor): Input text index tensor (B, T_text,). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, aux_channels, T_feats,). + feats_lengths (Tensor): Feature length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). 
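For reference, the four `neg_x_ent_*` terms in the alignment search above (and mirrored in `inference` below) are just the expansion of the diagonal-Gaussian log-density that MAS maximises over monotonic alignments, with `s_p_sq_r` playing the role of the inverse variance:

```latex
\log \mathcal{N}(z_p;\, m_p, \sigma_p)
  = \underbrace{\sum_{H}\left(-\tfrac{1}{2}\log 2\pi - \log\sigma_p\right)}_{\texttt{neg\_x\_ent\_1}}
    \underbrace{-\tfrac{1}{2}\sum_{H} z_p^{2}\, s}_{\texttt{neg\_x\_ent\_2}}
    \;+\; \underbrace{\sum_{H} z_p\, m_p\, s}_{\texttt{neg\_x\_ent\_3}}
    \underbrace{-\tfrac{1}{2}\sum_{H} m_p^{2}\, s}_{\texttt{neg\_x\_ent\_4}},
  \qquad s = e^{-2\log\sigma_p}.
```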
+ spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided, + skip the prediction of durations (i.e., teacher forcing). + noise_scale (float): Noise scale parameter for flow. + noise_scale_dur (float): Noise scale parameter for duration predictor. + alpha (float): Alpha parameter to control the speed of generated speech. + max_len (Optional[int]): Maximum length of acoustic feature sequence. + use_teacher_forcing (bool): Whether to use teacher forcing. + + Returns: + Tensor: Generated waveform tensor (B, T_wav). + Tensor: Monotonic attention weight tensor (B, T_feats, T_text). + Tensor: Duration tensor (B, T_text). + + """ + # encoder + x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths) + x_mask = x_mask.to(x.dtype) + g = None + if self.spks is not None: + # (B, global_channels, 1) + g = self.global_emb(sids.view(-1)).unsqueeze(-1) + if self.spk_embed_dim is not None: + # (B, global_channels, 1) + g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + if self.langs is not None: + # (B, global_channels, 1) + g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + + if use_teacher_forcing: + # forward posterior encoder + z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g) + + # forward flow + z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats) + + # monotonic alignment search + s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text) + # (B, 1, T_text) + neg_x_ent_1 = torch.sum( + -0.5 * math.log(2 * math.pi) - logs_p, + [1], + keepdim=True, + ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_2 = torch.matmul( + -0.5 * (z_p**2).transpose(1, 2), + s_p_sq_r, + ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_3 = torch.matmul( + z_p.transpose(1, 2), + (m_p * s_p_sq_r), + ) + # (B, 1, T_text) + neg_x_ent_4 = torch.sum( + -0.5 * (m_p**2) * s_p_sq_r, + [1], + keepdim=True, + ) + # (B, T_feats, T_text) + neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4 + # (B, 1, T_feats, T_text) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + # monotonic attention weight: (B, 1, T_feats, T_text) + attn = self.maximum_path( + neg_x_ent, + attn_mask.squeeze(1), + ).unsqueeze(1) + dur = attn.sum(2) # (B, 1, T_text) + + # forward decoder with random segments + wav = self.decoder(z * y_mask, g=g) + else: + # duration + if dur is None: + logw = self.duration_predictor( + x, + x_mask, + g=g, + inverse=True, + noise_scale=noise_scale_dur, + ) + w = torch.exp(logw) * x_mask * alpha + dur = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long() + y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device) + y_mask = y_mask.to(x.dtype) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = self._generate_path(dur, attn_mask) + + # expand the length to match with the feature sequence + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + m_p = torch.matmul( + attn.squeeze(1), + m_p.transpose(1, 2), + ).transpose(1, 2) + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + logs_p = torch.matmul( + attn.squeeze(1), + logs_p.transpose(1, 2), + ).transpose(1, 2) + + # decoder + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=g, inverse=True) + 
wav = self.decoder((z * y_mask)[:, :, :max_len], g=g) + + return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1) + + def _generate_path(self, dur: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Generate path a.k.a. monotonic attention. + + Args: + dur (Tensor): Duration tensor (B, 1, T_text). + mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text). + + Returns: + Tensor: Path tensor (B, 1, T_feats, T_text). + + """ + b, _, t_y, t_x = mask.shape + cum_dur = torch.cumsum(dur, -1) + cum_dur_flat = cum_dur.view(b * t_x) + path = torch.arange(t_y, dtype=dur.dtype, device=dur.device) + path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1) + # path = path.view(b, t_x, t_y).to(dtype=mask.dtype) + path = path.view(b, t_x, t_y).to(dtype=torch.float) + # path will be like (t_x = 3, t_y = 5): + # [[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.], + # [1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.], + # [1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]] + path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1] + # path = path.to(dtype=mask.dtype) + return path.unsqueeze(1).transpose(2, 3) * mask diff --git a/egs/ljspeech/TTS/vits/hifigan.py b/egs/ljspeech/TTS/vits/hifigan.py new file mode 100644 index 000000000..589ac30f6 --- /dev/null +++ b/egs/ljspeech/TTS/vits/hifigan.py @@ -0,0 +1,933 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""HiFi-GAN Modules. + +This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. + +""" + +import copy +import logging +from typing import Any, Dict, List, Optional + +import numpy as np +import torch +import torch.nn.functional as F + + +class HiFiGANGenerator(torch.nn.Module): + """HiFiGAN generator module.""" + + def __init__( + self, + in_channels: int = 80, + out_channels: int = 1, + channels: int = 512, + global_channels: int = -1, + kernel_size: int = 7, + upsample_scales: List[int] = [8, 8, 2, 2], + upsample_kernel_sizes: List[int] = [16, 16, 4, 4], + resblock_kernel_sizes: List[int] = [3, 7, 11], + resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + use_additional_convs: bool = True, + bias: bool = True, + nonlinear_activation: str = "LeakyReLU", + nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, + use_weight_norm: bool = True, + ): + """Initialize HiFiGANGenerator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + channels (int): Number of hidden representation channels. + global_channels (int): Number of global conditioning channels. + kernel_size (int): Kernel size of initial and final conv layer. + upsample_scales (List[int]): List of upsampling scales. + upsample_kernel_sizes (List[int]): List of kernel sizes for upsample layers. + resblock_kernel_sizes (List[int]): List of kernel sizes for residual blocks. + resblock_dilations (List[List[int]]): List of list of dilations for residual + blocks. + use_additional_convs (bool): Whether to use additional conv layers in + residual blocks. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation + function. + use_weight_norm (bool): Whether to use weight norm. If set to true, it will + be applied to all of the conv layers. 
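`_generate_path` above turns per-token durations into a hard monotonic alignment by differencing a cumulative-duration mask. A tiny standalone illustration of that arithmetic:

```python
import torch
import torch.nn.functional as F

dur = torch.tensor([2.0, 3.0, 1.0])    # frames assigned to each of 3 text tokens
t_y = int(dur.sum())                   # total number of frames (6)
cum = torch.cumsum(dur, -1)            # [2., 5., 6.]

path = (torch.arange(t_y).unsqueeze(0) < cum.unsqueeze(1)).float()
path = path - F.pad(path, [0, 0, 1, 0])[:-1]
# path[i, j] == 1 iff frame j is aligned to token i:
# [[1., 1., 0., 0., 0., 0.],
#  [0., 0., 1., 1., 1., 0.],
#  [0., 0., 0., 0., 0., 1.]]
```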
+ + """ + super().__init__() + + # check hyperparameters are valid + assert kernel_size % 2 == 1, "Kernel size must be odd number." + assert len(upsample_scales) == len(upsample_kernel_sizes) + assert len(resblock_dilations) == len(resblock_kernel_sizes) + + # define modules + self.upsample_factor = int(np.prod(upsample_scales) * out_channels) + self.num_upsamples = len(upsample_kernel_sizes) + self.num_blocks = len(resblock_kernel_sizes) + self.input_conv = torch.nn.Conv1d( + in_channels, + channels, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ) + self.upsamples = torch.nn.ModuleList() + self.blocks = torch.nn.ModuleList() + for i in range(len(upsample_kernel_sizes)): + assert upsample_kernel_sizes[i] == 2 * upsample_scales[i] + self.upsamples += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.ConvTranspose1d( + channels // (2**i), + channels // (2 ** (i + 1)), + upsample_kernel_sizes[i], + upsample_scales[i], + padding=upsample_scales[i] // 2 + upsample_scales[i] % 2, + output_padding=upsample_scales[i] % 2, + ), + ) + ] + for j in range(len(resblock_kernel_sizes)): + self.blocks += [ + ResidualBlock( + kernel_size=resblock_kernel_sizes[j], + channels=channels // (2 ** (i + 1)), + dilations=resblock_dilations[j], + bias=bias, + use_additional_convs=use_additional_convs, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + ) + ] + self.output_conv = torch.nn.Sequential( + # NOTE(kan-bayashi): follow official implementation but why + # using different slope parameter here? (0.1 vs. 0.01) + torch.nn.LeakyReLU(), + torch.nn.Conv1d( + channels // (2 ** (i + 1)), + out_channels, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ), + torch.nn.Tanh(), + ) + if global_channels > 0: + self.global_conv = torch.nn.Conv1d(global_channels, channels, 1) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # reset parameters + self.reset_parameters() + + def forward( + self, c: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + c (Tensor): Input tensor (B, in_channels, T). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Output tensor (B, out_channels, T). + + """ + c = self.input_conv(c) + if g is not None: + c = c + self.global_conv(g) + for i in range(self.num_upsamples): + c = self.upsamples[i](c) + cs = 0.0 # initialize + for j in range(self.num_blocks): + cs += self.blocks[i * self.num_blocks + j](c) + c = cs / self.num_blocks + c = self.output_conv(c) + + return c + + def reset_parameters(self): + """Reset parameters. + + This initialization follows the official implementation manner. 
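Because each `ConvTranspose1d` above multiplies the frame rate by its scale, the generator maps `T` input frames to `T * prod(upsample_scales)` samples. A quick shape check of the default configuration (a sketch that assumes `hifigan.py` is importable, e.g. from inside `./vits`):

```python
import torch
from hifigan import HiFiGANGenerator

g = HiFiGANGenerator()          # default upsample_scales [8, 8, 2, 2] -> x256
c = torch.randn(1, 80, 10)      # (B, in_channels, T_frames)
with torch.no_grad():
    wav = g(c)                  # (B, out_channels, T_frames * 256)
assert wav.shape == (1, 1, 10 * 256)
```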
+ https://github.com/jik876/hifi-gan/blob/master/models.py + + """ + + def _reset_parameters(m: torch.nn.Module): + if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)): + m.weight.data.normal_(0.0, 0.01) + logging.debug(f"Reset parameters in {m}.") + + self.apply(_reset_parameters) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m: torch.nn.Module): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv1d) or isinstance( + m, torch.nn.ConvTranspose1d + ): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def inference( + self, c: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Perform inference. + + Args: + c (torch.Tensor): Input tensor (T, in_channels). + g (Optional[Tensor]): Global conditioning tensor (global_channels, 1). + + Returns: + Tensor: Output tensor (T ** upsample_factor, out_channels). + + """ + if g is not None: + g = g.unsqueeze(0) + c = self.forward(c.transpose(1, 0).unsqueeze(0), g=g) + return c.squeeze(0).transpose(1, 0) + + +class ResidualBlock(torch.nn.Module): + """Residual block module in HiFiGAN.""" + + def __init__( + self, + kernel_size: int = 3, + channels: int = 512, + dilations: List[int] = [1, 3, 5], + bias: bool = True, + use_additional_convs: bool = True, + nonlinear_activation: str = "LeakyReLU", + nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, + ): + """Initialize ResidualBlock module. + + Args: + kernel_size (int): Kernel size of dilation convolution layer. + channels (int): Number of channels for convolution layer. + dilations (List[int]): List of dilation factors. + use_additional_convs (bool): Whether to use additional convolution layers. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation + function. + + """ + super().__init__() + self.use_additional_convs = use_additional_convs + self.convs1 = torch.nn.ModuleList() + if use_additional_convs: + self.convs2 = torch.nn.ModuleList() + assert kernel_size % 2 == 1, "Kernel size must be odd number." + for dilation in dilations: + self.convs1 += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation, + bias=bias, + padding=(kernel_size - 1) // 2 * dilation, + ), + ) + ] + if use_additional_convs: + self.convs2 += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + bias=bias, + padding=(kernel_size - 1) // 2, + ), + ) + ] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + + Returns: + Tensor: Output tensor (B, channels, T). 
+ + """ + for idx in range(len(self.convs1)): + xt = self.convs1[idx](x) + if self.use_additional_convs: + xt = self.convs2[idx](xt) + x = xt + x + return x + + +class HiFiGANPeriodDiscriminator(torch.nn.Module): + """HiFiGAN period discriminator module.""" + + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + period: int = 3, + kernel_sizes: List[int] = [5, 3], + channels: int = 32, + downsample_scales: List[int] = [3, 3, 3, 3, 1], + max_downsample_channels: int = 1024, + bias: bool = True, + nonlinear_activation: str = "LeakyReLU", + nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, + use_weight_norm: bool = True, + use_spectral_norm: bool = False, + ): + """Initialize HiFiGANPeriodDiscriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + period (int): Period. + kernel_sizes (list): Kernel sizes of initial conv layers and the final conv + layer. + channels (int): Number of initial channels. + downsample_scales (List[int]): List of downsampling scales. + max_downsample_channels (int): Number of maximum downsampling channels. + use_additional_convs (bool): Whether to use additional conv layers in + residual blocks. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation + function. + use_weight_norm (bool): Whether to use weight norm. + If set to true, it will be applied to all of the conv layers. + use_spectral_norm (bool): Whether to use spectral norm. + If set to true, it will be applied to all of the conv layers. + + """ + super().__init__() + assert len(kernel_sizes) == 2 + assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number." + assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number." + + self.period = period + self.convs = torch.nn.ModuleList() + in_chs = in_channels + out_chs = channels + for downsample_scale in downsample_scales: + self.convs += [ + torch.nn.Sequential( + torch.nn.Conv2d( + in_chs, + out_chs, + (kernel_sizes[0], 1), + (downsample_scale, 1), + padding=((kernel_sizes[0] - 1) // 2, 0), + ), + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + ) + ] + in_chs = out_chs + # NOTE(kan-bayashi): Use downsample_scale + 1? + out_chs = min(out_chs * 4, max_downsample_channels) + self.output_conv = torch.nn.Conv2d( + out_chs, + out_channels, + (kernel_sizes[1] - 1, 1), + 1, + padding=((kernel_sizes[1] - 1) // 2, 0), + ) + + if use_weight_norm and use_spectral_norm: + raise ValueError("Either use use_weight_norm or use_spectral_norm.") + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # apply spectral norm + if use_spectral_norm: + self.apply_spectral_norm() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + c (Tensor): Input tensor (B, in_channels, T). + + Returns: + list: List of each layer's tensors. 
+ + """ + # transform 1d to 2d -> (B, C, T/P, P) + b, c, t = x.shape + if t % self.period != 0: + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t += n_pad + x = x.view(b, c, t // self.period, self.period) + + # forward conv + outs = [] + for layer in self.convs: + x = layer(x) + outs += [x] + x = self.output_conv(x) + x = torch.flatten(x, 1, -1) + outs += [x] + + return outs + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def apply_spectral_norm(self): + """Apply spectral normalization module from all of the layers.""" + + def _apply_spectral_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv2d): + torch.nn.utils.spectral_norm(m) + logging.debug(f"Spectral norm is applied to {m}.") + + self.apply(_apply_spectral_norm) + + +class HiFiGANMultiPeriodDiscriminator(torch.nn.Module): + """HiFiGAN multi-period discriminator module.""" + + def __init__( + self, + periods: List[int] = [2, 3, 5, 7, 11], + discriminator_params: Dict[str, Any] = { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 32, + "downsample_scales": [3, 3, 3, 3, 1], + "max_downsample_channels": 1024, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + ): + """Initialize HiFiGANMultiPeriodDiscriminator module. + + Args: + periods (List[int]): List of periods. + discriminator_params (Dict[str, Any]): Parameters for hifi-gan period + discriminator module. The period parameter will be overwritten. + + """ + super().__init__() + self.discriminators = torch.nn.ModuleList() + for period in periods: + params = copy.deepcopy(discriminator_params) + params["period"] = period + self.discriminators += [HiFiGANPeriodDiscriminator(**params)] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List: List of list of each discriminator outputs, which consists of each + layer output tensors. + + """ + outs = [] + for f in self.discriminators: + outs += [f(x)] + + return outs + + +class HiFiGANScaleDiscriminator(torch.nn.Module): + """HiFi-GAN scale discriminator module.""" + + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + kernel_sizes: List[int] = [15, 41, 5, 3], + channels: int = 128, + max_downsample_channels: int = 1024, + max_groups: int = 16, + bias: int = True, + downsample_scales: List[int] = [2, 2, 4, 4, 1], + nonlinear_activation: str = "LeakyReLU", + nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, + use_weight_norm: bool = True, + use_spectral_norm: bool = False, + ): + """Initilize HiFiGAN scale discriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_sizes (List[int]): List of four kernel sizes. The first will be used + for the first conv layer, and the second is for downsampling part, and + the remaining two are for the last two output layers. + channels (int): Initial number of channels for conv layer. + max_downsample_channels (int): Maximum number of channels for downsampling + layers. 
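The key trick of the period discriminator above is the 1D-to-2D fold: the waveform is right-padded to a multiple of the period and reshaped so every column contains samples spaced `period` apart, letting the 2D convolutions look for periodic structure. An illustration of the reshape alone:

```python
import torch
import torch.nn.functional as F

x = torch.arange(10.0).view(1, 1, 10)      # (B, C, T) dummy "waveform"
period = 3
n_pad = period - (x.shape[-1] % period)    # 2 samples of reflect padding
x = F.pad(x, (0, n_pad), "reflect")
x = x.view(1, 1, -1, period)               # (B, C, T/P, P) == (1, 1, 4, 3)
```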
+ bias (bool): Whether to add bias parameter in convolution layers. + downsample_scales (List[int]): List of downsampling scales. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation + function. + use_weight_norm (bool): Whether to use weight norm. If set to true, it will + be applied to all of the conv layers. + use_spectral_norm (bool): Whether to use spectral norm. If set to true, it + will be applied to all of the conv layers. + + """ + super().__init__() + self.layers = torch.nn.ModuleList() + + # check kernel size is valid + assert len(kernel_sizes) == 4 + for ks in kernel_sizes: + assert ks % 2 == 1 + + # add first layer + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_channels, + channels, + # NOTE(kan-bayashi): Use always the same kernel size + kernel_sizes[0], + bias=bias, + padding=(kernel_sizes[0] - 1) // 2, + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + + # add downsample layers + in_chs = channels + out_chs = channels + # NOTE(kan-bayashi): Remove hard coding? + groups = 4 + for downsample_scale in downsample_scales: + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chs, + out_chs, + kernel_size=kernel_sizes[1], + stride=downsample_scale, + padding=(kernel_sizes[1] - 1) // 2, + groups=groups, + bias=bias, + ), + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + ) + ] + in_chs = out_chs + # NOTE(kan-bayashi): Remove hard coding? + out_chs = min(in_chs * 2, max_downsample_channels) + # NOTE(kan-bayashi): Remove hard coding? + groups = min(groups * 4, max_groups) + + # add final layers + out_chs = min(in_chs * 2, max_downsample_channels) + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chs, + out_chs, + kernel_size=kernel_sizes[2], + stride=1, + padding=(kernel_sizes[2] - 1) // 2, + bias=bias, + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + self.layers += [ + torch.nn.Conv1d( + out_chs, + out_channels, + kernel_size=kernel_sizes[3], + stride=1, + padding=(kernel_sizes[3] - 1) // 2, + bias=bias, + ), + ] + + if use_weight_norm and use_spectral_norm: + raise ValueError("Either use use_weight_norm or use_spectral_norm.") + + # apply weight norm + self.use_weight_norm = use_weight_norm + if use_weight_norm: + self.apply_weight_norm() + + # apply spectral norm + self.use_spectral_norm = use_spectral_norm + if use_spectral_norm: + self.apply_spectral_norm() + + # backward compatibility + self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List[Tensor]: List of output tensors of each layer. 
+ + """ + outs = [] + for f in self.layers: + x = f(x) + outs += [x] + + return outs + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv1d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def apply_spectral_norm(self): + """Apply spectral normalization module from all of the layers.""" + + def _apply_spectral_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv1d): + torch.nn.utils.spectral_norm(m) + logging.debug(f"Spectral norm is applied to {m}.") + + self.apply(_apply_spectral_norm) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def remove_spectral_norm(self): + """Remove spectral normalization module from all of the layers.""" + + def _remove_spectral_norm(m): + try: + logging.debug(f"Spectral norm is removed from {m}.") + torch.nn.utils.remove_spectral_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_spectral_norm) + + def _load_state_dict_pre_hook( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """Fix the compatibility of weight / spectral normalization issue. + + Some pretrained models are trained with configs that use weight / spectral + normalization, but actually, the norm is not applied. This causes the mismatch + of the parameters with configs. To solve this issue, when parameter mismatch + happens in loading pretrained model, we remove the norm from the current model. + + See also: + - https://github.com/espnet/espnet/pull/5240 + - https://github.com/espnet/espnet/pull/5249 + - https://github.com/kan-bayashi/ParallelWaveGAN/pull/409 + + """ + current_module_keys = [x for x in state_dict.keys() if x.startswith(prefix)] + if self.use_weight_norm and any( + [k.endswith("weight") for k in current_module_keys] + ): + logging.warning( + "It seems weight norm is not applied in the pretrained model but the" + " current model uses it. To keep the compatibility, we remove the norm" + " from the current model. This may cause unexpected behavior due to the" + " parameter mismatch in finetuning. To avoid this issue, please change" + " the following parameters in config to false:\n" + " - discriminator_params.follow_official_norm\n" + " - discriminator_params.scale_discriminator_params.use_weight_norm\n" + " - discriminator_params.scale_discriminator_params.use_spectral_norm\n" + "\n" + "See also:\n" + " - https://github.com/espnet/espnet/pull/5240\n" + " - https://github.com/espnet/espnet/pull/5249" + ) + self.remove_weight_norm() + self.use_weight_norm = False + for k in current_module_keys: + if k.endswith("weight_g") or k.endswith("weight_v"): + del state_dict[k] + + if self.use_spectral_norm and any( + [k.endswith("weight") for k in current_module_keys] + ): + logging.warning( + "It seems spectral norm is not applied in the pretrained model but the" + " current model uses it. To keep the compatibility, we remove the norm" + " from the current model. This may cause unexpected behavior due to the" + " parameter mismatch in finetuning. 
To avoid this issue, please change" + " the following parameters in config to false:\n" + " - discriminator_params.follow_official_norm\n" + " - discriminator_params.scale_discriminator_params.use_weight_norm\n" + " - discriminator_params.scale_discriminator_params.use_spectral_norm\n" + "\n" + "See also:\n" + " - https://github.com/espnet/espnet/pull/5240\n" + " - https://github.com/espnet/espnet/pull/5249" + ) + self.remove_spectral_norm() + self.use_spectral_norm = False + for k in current_module_keys: + if ( + k.endswith("weight_u") + or k.endswith("weight_v") + or k.endswith("weight_orig") + ): + del state_dict[k] + + +class HiFiGANMultiScaleDiscriminator(torch.nn.Module): + """HiFi-GAN multi-scale discriminator module.""" + + def __init__( + self, + scales: int = 3, + downsample_pooling: str = "AvgPool1d", + # follow the official implementation setting + downsample_pooling_params: Dict[str, Any] = { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + discriminator_params: Dict[str, Any] = { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 128, + "max_downsample_channels": 1024, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 4, 4, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + follow_official_norm: bool = False, + ): + """Initilize HiFiGAN multi-scale discriminator module. + + Args: + scales (int): Number of multi-scales. + downsample_pooling (str): Pooling module name for downsampling of the + inputs. + downsample_pooling_params (Dict[str, Any]): Parameters for the above pooling + module. + discriminator_params (Dict[str, Any]): Parameters for hifi-gan scale + discriminator module. + follow_official_norm (bool): Whether to follow the norm setting of the + official implementaion. The first discriminator uses spectral norm + and the other discriminators use weight norm. + + """ + super().__init__() + self.discriminators = torch.nn.ModuleList() + + # add discriminators + for i in range(scales): + params = copy.deepcopy(discriminator_params) + if follow_official_norm: + if i == 0: + params["use_weight_norm"] = False + params["use_spectral_norm"] = True + else: + params["use_weight_norm"] = True + params["use_spectral_norm"] = False + self.discriminators += [HiFiGANScaleDiscriminator(**params)] + self.pooling = None + if scales > 1: + self.pooling = getattr(torch.nn, downsample_pooling)( + **downsample_pooling_params + ) + + def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List[List[torch.Tensor]]: List of list of each discriminator outputs, + which consists of eachlayer output tensors. 
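The multi-scale stack above, and the multi-scale plus multi-period wrapper defined just below, simply concatenate the per-discriminator outputs, so the VITS losses receive one list per sub-discriminator whose last element is that discriminator's logits. A rough structural sketch (assumes `hifigan.py` is importable and uses the default configuration):

```python
import torch
from hifigan import HiFiGANMultiScaleMultiPeriodDiscriminator

d = HiFiGANMultiScaleMultiPeriodDiscriminator()
wav = torch.randn(2, 1, 8192)      # (B, 1, T) dummy waveform
with torch.no_grad():
    outs = d(wav)
assert len(outs) == 3 + 5          # 3 scale + 5 period discriminators
# outs[k] is the list of layer-wise feature maps of discriminator k,
# with its final score tensor as the last element.
```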
+ + """ + outs = [] + for f in self.discriminators: + outs += [f(x)] + if self.pooling is not None: + x = self.pooling(x) + + return outs + + +class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module): + """HiFi-GAN multi-scale + multi-period discriminator module.""" + + def __init__( + self, + # Multi-scale discriminator related + scales: int = 3, + scale_downsample_pooling: str = "AvgPool1d", + scale_downsample_pooling_params: Dict[str, Any] = { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + scale_discriminator_params: Dict[str, Any] = { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 128, + "max_downsample_channels": 1024, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 4, 4, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + follow_official_norm: bool = True, + # Multi-period discriminator related + periods: List[int] = [2, 3, 5, 7, 11], + period_discriminator_params: Dict[str, Any] = { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 32, + "downsample_scales": [3, 3, 3, 3, 1], + "max_downsample_channels": 1024, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + ): + """Initilize HiFiGAN multi-scale + multi-period discriminator module. + + Args: + scales (int): Number of multi-scales. + scale_downsample_pooling (str): Pooling module name for downsampling of the + inputs. + scale_downsample_pooling_params (dict): Parameters for the above pooling + module. + scale_discriminator_params (dict): Parameters for hifi-gan scale + discriminator module. + follow_official_norm (bool): Whether to follow the norm setting of the + official implementaion. The first discriminator uses spectral norm and + the other discriminators use weight norm. + periods (list): List of periods. + period_discriminator_params (dict): Parameters for hifi-gan period + discriminator module. The period parameter will be overwritten. + + """ + super().__init__() + self.msd = HiFiGANMultiScaleDiscriminator( + scales=scales, + downsample_pooling=scale_downsample_pooling, + downsample_pooling_params=scale_downsample_pooling_params, + discriminator_params=scale_discriminator_params, + follow_official_norm=follow_official_norm, + ) + self.mpd = HiFiGANMultiPeriodDiscriminator( + periods=periods, + discriminator_params=period_discriminator_params, + ) + + def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List[List[Tensor]]: List of list of each discriminator outputs, + which consists of each layer output tensors. Multi scale and + multi period ones are concatenated. + + """ + msd_outs = self.msd(x) + mpd_outs = self.mpd(x) + return msd_outs + mpd_outs diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py new file mode 100755 index 000000000..91a35e360 --- /dev/null +++ b/egs/ljspeech/TTS/vits/infer.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script performs model inference on test set. + +Usage: +./vits/infer.py \ + --epoch 1000 \ + --exp-dir ./vits/exp \ + --max-duration 500 +""" + + +import argparse +import logging +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import List + +import k2 +import torch +import torch.nn as nn +import torchaudio + +from train import get_model, get_params +from tokenizer import Tokenizer + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, setup_logger +from tts_datamodule import LJSpeechTtsDataModule + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +def infer_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + tokenizer: Tokenizer, +) -> None: + """Decode dataset. + The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + tokenizer: + Used to convert text to phonemes. + """ + # Background worker save audios to disk. + def _save_worker( + batch_size: int, + cut_ids: List[str], + audio: torch.Tensor, + audio_pred: torch.Tensor, + audio_lens: List[int], + audio_lens_pred: List[int], + ): + for i in range(batch_size): + torchaudio.save( + str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"), + audio[i:i + 1, :audio_lens[i]], + sample_rate=params.sampling_rate, + ) + torchaudio.save( + str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"), + audio_pred[i:i + 1, :audio_lens_pred[i]], + sample_rate=params.sampling_rate, + ) + + device = next(model.parameters()).device + num_cuts = 0 + log_interval = 5 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
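+    # NOTE: `durations` returned by `inference_batch` below is measured in
+    # feature frames per input token; summing over tokens and multiplying by
+    # `params.frame_shift` (samples per frame) converts the predicted length
+    # into waveform samples before the generated audio is trimmed and saved.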
+ + futures = [] + with ThreadPoolExecutor(max_workers=1) as executor: + for batch_idx, batch in enumerate(dl): + batch_size = len(batch["tokens"]) + + tokens = batch["tokens"] + tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) + + audio = batch["audio"] + audio_lens = batch["audio_lens"].tolist() + cut_ids = [cut.id for cut in batch["cut"]] + + audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens) + audio_pred = audio_pred.detach().cpu() + # convert to samples + audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() + + futures.append( + executor.submit( + _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred + ) + ) + + num_cuts += batch_size + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + # return results + for f in futures: + f.result() + + +@torch.no_grad() +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.suffix = f"epoch-{params.epoch}" + + params.res_dir = params.exp_dir / "infer" / params.suffix + params.save_wav_dir = params.res_dir / "wav" + params.save_wav_dir.mkdir(parents=True, exist_ok=True) + + setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") + logging.info("Infer started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model.to(device) + model.eval() + + num_param_g = sum([p.numel() for p in model.generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in model.discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + # we need cut ids to display recognition results. + args.return_cuts = True + ljspeech = LJSpeechTtsDataModule(args) + + test_cuts = ljspeech.test_cuts() + test_dl = ljspeech.test_dataloaders(test_cuts) + + infer_dataset( + dl=test_dl, + params=params, + model=model, + tokenizer=tokenizer, + ) + + logging.info(f"Wav files are saved to {params.save_wav_dir}") + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/TTS/vits/loss.py b/egs/ljspeech/TTS/vits/loss.py new file mode 100644 index 000000000..21aaad6e7 --- /dev/null +++ b/egs/ljspeech/TTS/vits/loss.py @@ -0,0 +1,336 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""HiFiGAN-related loss modules. 
+ +This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. + +""" + +from typing import List, Tuple, Union + +import torch +import torch.distributions as D +import torch.nn.functional as F + +from lhotse.features.kaldi import Wav2LogFilterBank + + +class GeneratorAdversarialLoss(torch.nn.Module): + """Generator adversarial loss module.""" + + def __init__( + self, + average_by_discriminators: bool = True, + loss_type: str = "mse", + ): + """Initialize GeneratorAversarialLoss module. + + Args: + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + loss_type (str): Loss type, "mse" or "hinge". + + """ + super().__init__() + self.average_by_discriminators = average_by_discriminators + assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." + if loss_type == "mse": + self.criterion = self._mse_loss + else: + self.criterion = self._hinge_loss + + def forward( + self, + outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], + ) -> torch.Tensor: + """Calcualate generator adversarial loss. + + Args: + outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator + outputs, list of discriminator outputs, or list of list of discriminator + outputs.. + + Returns: + Tensor: Generator adversarial loss value. + + """ + if isinstance(outputs, (tuple, list)): + adv_loss = 0.0 + for i, outputs_ in enumerate(outputs): + if isinstance(outputs_, (tuple, list)): + # NOTE(kan-bayashi): case including feature maps + outputs_ = outputs_[-1] + adv_loss += self.criterion(outputs_) + if self.average_by_discriminators: + adv_loss /= i + 1 + else: + adv_loss = self.criterion(outputs) + + return adv_loss + + def _mse_loss(self, x): + return F.mse_loss(x, x.new_ones(x.size())) + + def _hinge_loss(self, x): + return -x.mean() + + +class DiscriminatorAdversarialLoss(torch.nn.Module): + """Discriminator adversarial loss module.""" + + def __init__( + self, + average_by_discriminators: bool = True, + loss_type: str = "mse", + ): + """Initialize DiscriminatorAversarialLoss module. + + Args: + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + loss_type (str): Loss type, "mse" or "hinge". + + """ + super().__init__() + self.average_by_discriminators = average_by_discriminators + assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." + if loss_type == "mse": + self.fake_criterion = self._mse_fake_loss + self.real_criterion = self._mse_real_loss + else: + self.fake_criterion = self._hinge_fake_loss + self.real_criterion = self._hinge_real_loss + + def forward( + self, + outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], + outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Calcualate discriminator adversarial loss. + + Args: + outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator + outputs, list of discriminator outputs, or list of list of discriminator + outputs calculated from generator. + outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator + outputs, list of discriminator outputs, or list of list of discriminator + outputs calculated from groundtruth. + + Returns: + Tensor: Discriminator real loss value. + Tensor: Discriminator fake loss value. 
+ + """ + if isinstance(outputs, (tuple, list)): + real_loss = 0.0 + fake_loss = 0.0 + for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): + if isinstance(outputs_hat_, (tuple, list)): + # NOTE(kan-bayashi): case including feature maps + outputs_hat_ = outputs_hat_[-1] + outputs_ = outputs_[-1] + real_loss += self.real_criterion(outputs_) + fake_loss += self.fake_criterion(outputs_hat_) + if self.average_by_discriminators: + fake_loss /= i + 1 + real_loss /= i + 1 + else: + real_loss = self.real_criterion(outputs) + fake_loss = self.fake_criterion(outputs_hat) + + return real_loss, fake_loss + + def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.mse_loss(x, x.new_ones(x.size())) + + def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.mse_loss(x, x.new_zeros(x.size())) + + def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor: + return -torch.mean(torch.min(x - 1, x.new_zeros(x.size()))) + + def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor: + return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size()))) + + +class FeatureMatchLoss(torch.nn.Module): + """Feature matching loss module.""" + + def __init__( + self, + average_by_layers: bool = True, + average_by_discriminators: bool = True, + include_final_outputs: bool = False, + ): + """Initialize FeatureMatchLoss module. + + Args: + average_by_layers (bool): Whether to average the loss by the number + of layers. + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + include_final_outputs (bool): Whether to include the final output of + each discriminator for loss calculation. + + """ + super().__init__() + self.average_by_layers = average_by_layers + self.average_by_discriminators = average_by_discriminators + self.include_final_outputs = include_final_outputs + + def forward( + self, + feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]], + feats: Union[List[List[torch.Tensor]], List[torch.Tensor]], + ) -> torch.Tensor: + """Calculate feature matching loss. + + Args: + feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of + discriminator outputs or list of discriminator outputs calcuated + from generator's outputs. + feats (Union[List[List[Tensor]], List[Tensor]]): List of list of + discriminator outputs or list of discriminator outputs calcuated + from groundtruth.. + + Returns: + Tensor: Feature matching loss value. 
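As a quick sanity check of the least-squares (`"mse"`) branch above: real outputs are regressed toward 1, fake outputs toward 0, and the generator's own criterion pushes fake outputs toward 1. A toy illustration (random numbers only, no recipe code):

```python
import torch
import torch.nn.functional as F

real_out = torch.tensor([0.9, 1.1, 0.8, 1.0])   # toy discriminator scores on real audio
fake_out = torch.tensor([0.2, -0.1, 0.3, 0.0])  # toy scores on generated audio

# Discriminator side (_mse_real_loss / _mse_fake_loss): real -> 1, fake -> 0.
disc_real = F.mse_loss(real_out, torch.ones_like(real_out))
disc_fake = F.mse_loss(fake_out, torch.zeros_like(fake_out))

# Generator side (_mse_loss): its outputs are scored against 1,
# so minimising this loss means fooling the discriminator.
gen_adv = F.mse_loss(fake_out, torch.ones_like(fake_out))

print(disc_real.item(), disc_fake.item(), gen_adv.item())
```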
+ + """ + feat_match_loss = 0.0 + for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)): + feat_match_loss_ = 0.0 + if not self.include_final_outputs: + feats_hat_ = feats_hat_[:-1] + feats_ = feats_[:-1] + for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)): + feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach()) + if self.average_by_layers: + feat_match_loss_ /= j + 1 + feat_match_loss += feat_match_loss_ + if self.average_by_discriminators: + feat_match_loss /= i + 1 + + return feat_match_loss + + +class MelSpectrogramLoss(torch.nn.Module): + """Mel-spectrogram loss.""" + + def __init__( + self, + sampling_rate: int = 22050, + frame_length: int = 1024, # in samples + frame_shift: int = 256, # in samples + n_mels: int = 80, + use_fft_mag: bool = True, + ): + super().__init__() + self.wav_to_mel = Wav2LogFilterBank( + sampling_rate=sampling_rate, + frame_length=frame_length / sampling_rate, # in second + frame_shift=frame_shift / sampling_rate, # in second + use_fft_mag=use_fft_mag, + num_filters=n_mels, + ) + + def forward( + self, + y_hat: torch.Tensor, + y: torch.Tensor, + return_mel: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]: + """Calculate Mel-spectrogram loss. + + Args: + y_hat (Tensor): Generated waveform tensor (B, 1, T). + y (Tensor): Groundtruth waveform tensor (B, 1, T). + spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor + (B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth + waveform. + + Returns: + Tensor: Mel-spectrogram loss value. + + """ + mel_hat = self.wav_to_mel(y_hat.squeeze(1)) + mel = self.wav_to_mel(y.squeeze(1)) + mel_loss = F.l1_loss(mel_hat, mel) + + if return_mel: + return mel_loss, (mel_hat, mel) + + return mel_loss + + +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py + +"""VITS-related loss modules. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + + +class KLDivergenceLoss(torch.nn.Module): + """KL divergence loss.""" + + def forward( + self, + z_p: torch.Tensor, + logs_q: torch.Tensor, + m_p: torch.Tensor, + logs_p: torch.Tensor, + z_mask: torch.Tensor, + ) -> torch.Tensor: + """Calculate KL divergence loss. + + Args: + z_p (Tensor): Flow hidden representation (B, H, T_feats). + logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats). + m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats). + logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats). + z_mask (Tensor): Mask tensor (B, 1, T_feats). + + Returns: + Tensor: KL divergence loss. + + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) + kl = torch.sum(kl * z_mask) + loss = kl / torch.sum(z_mask) + + return loss + + +class KLDivergenceLossWithoutFlow(torch.nn.Module): + """KL divergence loss without flow.""" + + def forward( + self, + m_q: torch.Tensor, + logs_q: torch.Tensor, + m_p: torch.Tensor, + logs_p: torch.Tensor, + ) -> torch.Tensor: + """Calculate KL divergence loss without flow. + + Args: + m_q (Tensor): Posterior encoder projected mean (B, H, T_feats). + logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats). + m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats). + logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats). 
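`FeatureMatchLoss` above simply averages L1 distances between corresponding discriminator feature maps, detaching the ground-truth branch. A toy call (random tensors standing in for two discriminators with three feature maps each; assumes this is run from the `vits/` directory so the `loss` module defined in this file is importable):

```python
import torch
from loss import FeatureMatchLoss  # the module defined in this file

criterion = FeatureMatchLoss()  # defaults: average by layers and by discriminators

# Two discriminators, three feature maps each; with include_final_outputs=False
# the last map of each discriminator (its score) is excluded from the loss.
feats = [[torch.randn(2, 16, 50) for _ in range(3)] for _ in range(2)]
feats_hat = [[torch.randn(2, 16, 50) for _ in range(3)] for _ in range(2)]

loss = criterion(feats_hat, feats)
print(loss.item())  # a single scalar
```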
+ """ + posterior_norm = D.Normal(m_q, torch.exp(logs_q)) + prior_norm = D.Normal(m_p, torch.exp(logs_p)) + loss = D.kl_divergence(posterior_norm, prior_norm).mean() + return loss diff --git a/egs/ljspeech/TTS/vits/monotonic_align/__init__.py b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py new file mode 100644 index 000000000..2b35654f5 --- /dev/null +++ b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py @@ -0,0 +1,81 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/__init__.py + +"""Maximum path calculation module. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + +import warnings + +import numpy as np +import torch +from numba import njit, prange + +try: + from .core import maximum_path_c + + is_cython_avalable = True +except ImportError: + is_cython_avalable = False + warnings.warn( + "Cython version is not available. Fallback to 'EXPERIMETAL' numba version. " + "If you want to use the cython version, please build it as follows: " + "`cd espnet2/gan_tts/vits/monotonic_align; python setup.py build_ext --inplace`" + ) + + +def maximum_path(neg_x_ent: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: + """Calculate maximum path. + + Args: + neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text). + attn_mask (Tensor): Attention mask (B, T_feats, T_text). + + Returns: + Tensor: Maximum path tensor (B, T_feats, T_text). + + """ + device, dtype = neg_x_ent.device, neg_x_ent.dtype + neg_x_ent = neg_x_ent.cpu().numpy().astype(np.float32) + path = np.zeros(neg_x_ent.shape, dtype=np.int32) + t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32) + t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32) + if is_cython_avalable: + maximum_path_c(path, neg_x_ent, t_t_max, t_s_max) + else: + maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max) + + return torch.from_numpy(path).to(device=device, dtype=dtype) + + +@njit +def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf): + """Calculate a single maximum path with numba.""" + index = t_x - 1 + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[y - 1, x] + if x == 0: + if y == 0: + v_prev = 0.0 + else: + v_prev = max_neg_val + else: + v_prev = value[y - 1, x - 1] + value[y, x] += max(v_prev, v_cur) + + for y in range(t_y - 1, -1, -1): + path[y, index] = 1 + if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): + index = index - 1 + + +@njit(parallel=True) +def maximum_path_numba(paths, values, t_ys, t_xs): + """Calculate batch maximum path with numba.""" + for i in prange(paths.shape[0]): + maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i]) diff --git a/egs/ljspeech/TTS/vits/monotonic_align/core.pyx b/egs/ljspeech/TTS/vits/monotonic_align/core.pyx new file mode 100644 index 000000000..c02c2d02e --- /dev/null +++ b/egs/ljspeech/TTS/vits/monotonic_align/core.pyx @@ -0,0 +1,51 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/core.pyx + +"""Maximum path calculation module with cython optimization. + +This code is copied from https://github.com/jaywalnut310/vits and modifed code format. 
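Back in `loss.py`, `KLDivergenceLossWithoutFlow` relies on the analytic KL between two diagonal Gaussians parameterised by means and log standard deviations. A quick numerical check (toy tensors) that `torch.distributions.kl_divergence` agrees with the textbook closed form:

```python
import torch
import torch.distributions as D

torch.manual_seed(0)
m_q, logs_q = torch.randn(2, 4, 6), 0.1 * torch.randn(2, 4, 6)
m_p, logs_p = torch.randn(2, 4, 6), 0.1 * torch.randn(2, 4, 6)

kl_torch = D.kl_divergence(
    D.Normal(m_q, torch.exp(logs_q)),
    D.Normal(m_p, torch.exp(logs_p)),
)

# KL(N(m_q, s_q^2) || N(m_p, s_p^2))
#   = log s_p - log s_q + (s_q^2 + (m_q - m_p)^2) / (2 s_p^2) - 1/2
kl_manual = (
    logs_p
    - logs_q
    + (torch.exp(2 * logs_q) + (m_q - m_p) ** 2) / (2 * torch.exp(2 * logs_p))
    - 0.5
)

print(torch.allclose(kl_torch, kl_manual, atol=1e-5))  # True
```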
+ +""" + +cimport cython + +from cython.parallel import prange + + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil: + cdef int x + cdef int y + cdef float v_prev + cdef float v_cur + cdef float tmp + cdef int index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[y - 1, x] + if x == 0: + if y == 0: + v_prev = 0.0 + else: + v_prev = max_neg_val + else: + v_prev = value[y - 1, x - 1] + value[y, x] += max(v_prev, v_cur) + + for y in range(t_y - 1, -1, -1): + path[y, index] = 1 + if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): + index = index - 1 + + +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil: + cdef int b = paths.shape[0] + cdef int i + for i in prange(b, nogil=True): + maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) diff --git a/egs/ljspeech/TTS/vits/monotonic_align/setup.py b/egs/ljspeech/TTS/vits/monotonic_align/setup.py new file mode 100644 index 000000000..33d75e176 --- /dev/null +++ b/egs/ljspeech/TTS/vits/monotonic_align/setup.py @@ -0,0 +1,31 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/setup.py +"""Setup cython code.""" + +from Cython.Build import cythonize +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext as _build_ext + + +class build_ext(_build_ext): + """Overwrite build_ext.""" + + def finalize_options(self): + """Prevent numpy from thinking it is still in its setup process.""" + _build_ext.finalize_options(self) + __builtins__.__NUMPY_SETUP__ = False + import numpy + + self.include_dirs.append(numpy.get_include()) + + +exts = [ + Extension( + name="core", + sources=["core.pyx"], + ) +] +setup( + name="monotonic_align", + ext_modules=cythonize(exts, language_level=3), + cmdclass={"build_ext": build_ext}, +) diff --git a/egs/ljspeech/TTS/vits/posterior_encoder.py b/egs/ljspeech/TTS/vits/posterior_encoder.py new file mode 100644 index 000000000..6b8a5be52 --- /dev/null +++ b/egs/ljspeech/TTS/vits/posterior_encoder.py @@ -0,0 +1,117 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Posterior encoder module in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + +from typing import Optional, Tuple + +import torch + +from icefall.utils import make_pad_mask +from wavenet import WaveNet, Conv1d + + +class PosteriorEncoder(torch.nn.Module): + """Posterior encoder module in VITS. + + This is a module of posterior encoder described in `Conditional Variational + Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_. + + .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`: https://arxiv.org/abs/2006.04558 + """ + + def __init__( + self, + in_channels: int = 513, + out_channels: int = 192, + hidden_channels: int = 192, + kernel_size: int = 5, + layers: int = 16, + stacks: int = 1, + base_dilation: int = 1, + global_channels: int = -1, + dropout_rate: float = 0.0, + bias: bool = True, + use_weight_norm: bool = True, + ): + """Initilialize PosteriorEncoder module. 
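Both the Cython and the numba implementations above compute the same Viterbi-style alignment. A toy invocation of `maximum_path` (assumes it is run from `egs/ljspeech/TTS/vits` so the `monotonic_align` package above is importable; the tensors are made up): every feature frame is assigned exactly one text token, and the assigned index never decreases over time.

```python
import torch
from monotonic_align import maximum_path  # the package defined above

torch.manual_seed(0)
B, T_feats, T_text = 1, 6, 3
neg_x_ent = torch.randn(B, T_feats, T_text)   # higher value = better frame/token match
attn_mask = torch.ones(B, T_feats, T_text)    # no padding in this toy example

path = maximum_path(neg_x_ent, attn_mask)     # (B, T_feats, T_text), entries 0/1

assert torch.all(path.sum(dim=2) == 1)        # one token per feature frame
token_idx = path.argmax(dim=2)                # (B, T_feats)
assert torch.all(token_idx[:, 1:] >= token_idx[:, :-1])  # monotonic alignment
print(token_idx)
```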
+ + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size in WaveNet. + layers (int): Number of layers of WaveNet. + stacks (int): Number of repeat stacking of WaveNet. + base_dilation (int): Base dilation factor. + global_channels (int): Number of global conditioning channels. + dropout_rate (float): Dropout rate. + bias (bool): Whether to use bias parameters in conv. + use_weight_norm (bool): Whether to apply weight norm. + + """ + super().__init__() + + # define modules + self.input_conv = Conv1d(in_channels, hidden_channels, 1) + self.encoder = WaveNet( + in_channels=-1, + out_channels=-1, + kernel_size=kernel_size, + layers=layers, + stacks=stacks, + base_dilation=base_dilation, + residual_channels=hidden_channels, + aux_channels=-1, + gate_channels=hidden_channels * 2, + skip_channels=hidden_channels, + global_channels=global_channels, + dropout_rate=dropout_rate, + bias=bias, + use_weight_norm=use_weight_norm, + use_first_conv=False, + use_last_conv=False, + scale_residual=False, + scale_skip_connect=True, + ) + self.proj = Conv1d(hidden_channels, out_channels * 2, 1) + + def forward( + self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T_feats). + x_lengths (Tensor): Length tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Encoded hidden representation tensor (B, out_channels, T_feats). + Tensor: Projected mean tensor (B, out_channels, T_feats). + Tensor: Projected scale tensor (B, out_channels, T_feats). + Tensor: Mask tensor for input tensor (B, 1, T_feats). + + """ + x_mask = ( + (~make_pad_mask(x_lengths)) + .unsqueeze(1) + .to( + dtype=x.dtype, + device=x.device, + ) + ) + x = self.input_conv(x) * x_mask + x = self.encoder(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = stats.split(stats.size(1) // 2, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + + return z, m, logs, x_mask diff --git a/egs/ljspeech/TTS/vits/residual_coupling.py b/egs/ljspeech/TTS/vits/residual_coupling.py new file mode 100644 index 000000000..2d6807cb7 --- /dev/null +++ b/egs/ljspeech/TTS/vits/residual_coupling.py @@ -0,0 +1,229 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Residual affine coupling modules in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + +from typing import Optional, Tuple, Union + +import torch + +from flow import FlipFlow +from wavenet import WaveNet + + +class ResidualAffineCouplingBlock(torch.nn.Module): + """Residual affine coupling block module. + + This is a module of residual affine coupling block, which used as "Flow" in + `Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`_. + + .. 
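The last lines of `PosteriorEncoder.forward` above are the standard reparameterisation trick: sample `z ~ N(m, exp(logs)^2)` in a way that stays differentiable with respect to `m` and `logs`, then zero out padded frames. A stripped-down sketch with toy shapes (unrelated to the real channel sizes):

```python
import torch

torch.manual_seed(0)
B, C, T = 2, 4, 5
m = torch.randn(B, C, T)            # predicted means
logs = 0.1 * torch.randn(B, C, T)   # predicted log standard deviations
x_mask = torch.ones(B, 1, T)
x_mask[1, :, 3:] = 0.0              # pretend the second utterance is shorter

eps = torch.randn_like(m)
z = (m + eps * torch.exp(logs)) * x_mask   # differentiable sample, padding zeroed

print(z.shape)                          # torch.Size([2, 4, 5])
print(z[1, :, 3:].abs().sum().item())   # 0.0 -> masked frames carry no signal
```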
_`Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`: https://arxiv.org/abs/2006.04558 + + """ + + def __init__( + self, + in_channels: int = 192, + hidden_channels: int = 192, + flows: int = 4, + kernel_size: int = 5, + base_dilation: int = 1, + layers: int = 4, + global_channels: int = -1, + dropout_rate: float = 0.0, + use_weight_norm: bool = True, + bias: bool = True, + use_only_mean: bool = True, + ): + """Initilize ResidualAffineCouplingBlock module. + + Args: + in_channels (int): Number of input channels. + hidden_channels (int): Number of hidden channels. + flows (int): Number of flows. + kernel_size (int): Kernel size for WaveNet. + base_dilation (int): Base dilation factor for WaveNet. + layers (int): Number of layers of WaveNet. + stacks (int): Number of stacks of WaveNet. + global_channels (int): Number of global channels. + dropout_rate (float): Dropout rate. + use_weight_norm (bool): Whether to use weight normalization in WaveNet. + bias (bool): Whether to use bias paramters in WaveNet. + use_only_mean (bool): Whether to estimate only mean. + + """ + super().__init__() + + self.flows = torch.nn.ModuleList() + for i in range(flows): + self.flows += [ + ResidualAffineCouplingLayer( + in_channels=in_channels, + hidden_channels=hidden_channels, + kernel_size=kernel_size, + base_dilation=base_dilation, + layers=layers, + stacks=1, + global_channels=global_channels, + dropout_rate=dropout_rate, + use_weight_norm=use_weight_norm, + bias=bias, + use_only_mean=use_only_mean, + ) + ] + self.flows += [FlipFlow()] + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + inverse: bool = False, + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T). + x_lengths (Tensor): Length tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Output tensor (B, in_channels, T). + + """ + if not inverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, inverse=inverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, inverse=inverse) + return x + + +class ResidualAffineCouplingLayer(torch.nn.Module): + """Residual affine coupling layer.""" + + def __init__( + self, + in_channels: int = 192, + hidden_channels: int = 192, + kernel_size: int = 5, + base_dilation: int = 1, + layers: int = 5, + stacks: int = 1, + global_channels: int = -1, + dropout_rate: float = 0.0, + use_weight_norm: bool = True, + bias: bool = True, + use_only_mean: bool = True, + ): + """Initialzie ResidualAffineCouplingLayer module. + + Args: + in_channels (int): Number of input channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size for WaveNet. + base_dilation (int): Base dilation factor for WaveNet. + layers (int): Number of layers of WaveNet. + stacks (int): Number of stacks of WaveNet. + global_channels (int): Number of global channels. + dropout_rate (float): Dropout rate. + use_weight_norm (bool): Whether to use weight normalization in WaveNet. + bias (bool): Whether to use bias paramters in WaveNet. + use_only_mean (bool): Whether to estimate only mean. 
+ + """ + assert in_channels % 2 == 0, "in_channels should be divisible by 2" + super().__init__() + self.half_channels = in_channels // 2 + self.use_only_mean = use_only_mean + + # define modules + self.input_conv = torch.nn.Conv1d( + self.half_channels, + hidden_channels, + 1, + ) + self.encoder = WaveNet( + in_channels=-1, + out_channels=-1, + kernel_size=kernel_size, + layers=layers, + stacks=stacks, + base_dilation=base_dilation, + residual_channels=hidden_channels, + aux_channels=-1, + gate_channels=hidden_channels * 2, + skip_channels=hidden_channels, + global_channels=global_channels, + dropout_rate=dropout_rate, + bias=bias, + use_weight_norm=use_weight_norm, + use_first_conv=False, + use_last_conv=False, + scale_residual=False, + scale_skip_connect=True, + ) + if use_only_mean: + self.proj = torch.nn.Conv1d( + hidden_channels, + self.half_channels, + 1, + ) + else: + self.proj = torch.nn.Conv1d( + hidden_channels, + self.half_channels * 2, + 1, + ) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + inverse: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T). + x_lengths (Tensor): Length tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Output tensor (B, in_channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. + + """ + xa, xb = x.split(x.size(1) // 2, dim=1) + h = self.input_conv(xa) * x_mask + h = self.encoder(h, x_mask, g=g) + stats = self.proj(h) * x_mask + if not self.use_only_mean: + m, logs = stats.split(stats.size(1) // 2, dim=1) + else: + m = stats + logs = torch.zeros_like(m) + + if not inverse: + xb = m + xb * torch.exp(logs) * x_mask + x = torch.cat([xa, xb], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + xb = (xb - m) * torch.exp(-logs) * x_mask + x = torch.cat([xa, xb], 1) + return x diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py new file mode 100755 index 000000000..8acca7c02 --- /dev/null +++ b/egs/ljspeech/TTS/vits/test_onnx.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
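Because `ResidualAffineCouplingLayer` above only shifts and scales one half of the channels given the other half, it can be inverted exactly, which is what the `inverse=True` branch exploits. A toy round-trip check of just the affine arithmetic (standalone tensors; the WaveNet that would normally predict `m` and `logs` is replaced by random values):

```python
import torch

torch.manual_seed(0)
B, C, T = 2, 8, 10
x = torch.randn(B, C, T)

# Stand-ins for the coupling network's outputs, which depend only on xa.
m = torch.randn(B, C // 2, T)
logs = 0.1 * torch.randn(B, C // 2, T)

xa, xb = x.split(C // 2, dim=1)

# Forward: xb' = m + xb * exp(logs); xa passes through unchanged.
y = torch.cat([xa, m + xb * torch.exp(logs)], dim=1)
logdet = torch.sum(logs, dim=[1, 2])        # log|det J|, as returned by the layer

# Inverse: xb = (xb' - m) * exp(-logs).
ya, yb = y.split(C // 2, dim=1)
x_rec = torch.cat([ya, (yb - m) * torch.exp(-logs)], dim=1)

print(torch.allclose(x, x_rec, atol=1e-6))  # True: the flow is exactly invertible
print(logdet.shape)                         # torch.Size([2])
```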
+ +""" +This script is used to test the exported onnx model by vits/export-onnx.py + +Use the onnx model to generate a wav: +./vits/test_onnx.py \ + --model-filename vits/exp/vits-epoch-1000.onnx \ + --tokens data/tokens.txt +""" + + +import argparse +import logging +import onnxruntime as ort +import torch +import torchaudio + +from tokenizer import Tokenizer + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the onnx model.", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +class OnnxModel: + def __init__(self, model_filename: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 + + self.session_opts = session_opts + + self.model = ort.InferenceSession( + model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") + + def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor: + """ + Args: + tokens: + A 1-D tensor of shape (1, T) + Returns: + A tensor of shape (1, T') + """ + noise_scale = torch.tensor([0.667], dtype=torch.float32) + noise_scale_dur = torch.tensor([0.8], dtype=torch.float32) + alpha = torch.tensor([1.0], dtype=torch.float32) + + out = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: tokens.numpy(), + self.model.get_inputs()[1].name: tokens_lens.numpy(), + self.model.get_inputs()[2].name: noise_scale.numpy(), + self.model.get_inputs()[3].name: noise_scale_dur.numpy(), + self.model.get_inputs()[4].name: alpha.numpy(), + }, + )[0] + return torch.from_numpy(out) + + +def main(): + args = get_parser().parse_args() + + tokenizer = Tokenizer(args.tokens) + + logging.info("About to create onnx model") + model = OnnxModel(args.model_filename) + + text = "I went there to see the land, the people and how their system works, end quote." + tokens = tokenizer.texts_to_token_ids([text]) + tokens = torch.tensor(tokens) # (1, T) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) + audio = model(tokens, tokens_lens) # (1, T') + + torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050) + logging.info("Saved to test_onnx.wav") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py new file mode 100644 index 000000000..9f337e45b --- /dev/null +++ b/egs/ljspeech/TTS/vits/text_encoder.py @@ -0,0 +1,662 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
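`OnnxModel.__call__` above feeds the five inputs by position. If the export order is ever in doubt, onnxruntime can report the graph's input and output names and shapes directly; a small sketch (the model path is only an example):

```python
import onnxruntime as ort

# Example path; point this at whatever --model-filename you exported.
sess = ort.InferenceSession(
    "vits/exp/vits-epoch-1000.onnx",
    providers=["CPUExecutionProvider"],
)

for inp in sess.get_inputs():
    print("input :", inp.name, inp.shape, inp.type)
for out in sess.get_outputs():
    print("output:", out.name, out.shape, out.type)
```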
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Text encoder module in VITS. + +This code is based on + - https://github.com/jaywalnut310/vits + - https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py + - https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/conformer.py +""" + +import copy +import math +from typing import Optional, Tuple + +import torch +from torch import Tensor, nn + +from icefall.utils import is_jit_tracing, make_pad_mask + + +class TextEncoder(torch.nn.Module): + """Text encoder module in VITS. + + This is a module of text encoder described in `Conditional Variational Autoencoder + with Adversarial Learning for End-to-End Text-to-Speech`. + """ + + def __init__( + self, + vocabs: int, + d_model: int = 192, + num_heads: int = 2, + dim_feedforward: int = 768, + cnn_module_kernel: int = 5, + num_layers: int = 6, + dropout: float = 0.1, + ): + """Initialize TextEncoder module. + + Args: + vocabs (int): Vocabulary size. + d_model (int): attention dimension + num_heads (int): number of attention heads + dim_feedforward (int): feedforward dimention + cnn_module_kernel (int): convolution kernel size + num_layers (int): number of encoder layers + dropout (float): dropout rate + """ + super().__init__() + self.d_model = d_model + + # define modules + self.emb = torch.nn.Embedding(vocabs, d_model) + torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5) + + # We use conformer as text encoder + self.encoder = Transformer( + d_model=d_model, + num_heads=num_heads, + dim_feedforward=dim_feedforward, + cnn_module_kernel=cnn_module_kernel, + num_layers=num_layers, + dropout=dropout, + ) + + self.proj = torch.nn.Conv1d(d_model, d_model * 2, 1) + + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input index tensor (B, T_text). + x_lengths (Tensor): Length tensor (B,). + + Returns: + Tensor: Encoded hidden representation (B, attention_dim, T_text). + Tensor: Projected mean tensor (B, attention_dim, T_text). + Tensor: Projected scale tensor (B, attention_dim, T_text). + Tensor: Mask tensor for input tensor (B, 1, T_text). 
+ + """ + # (B, T_text, embed_dim) + x = self.emb(x) * math.sqrt(self.d_model) + + assert x.size(1) == x_lengths.max().item() + + # (B, T_text) + pad_mask = make_pad_mask(x_lengths) + + # encoder assume the channel last (B, T_text, embed_dim) + x = self.encoder(x, key_padding_mask=pad_mask) + + # convert the channel first (B, embed_dim, T_text) + x = x.transpose(1, 2) + non_pad_mask = (~pad_mask).unsqueeze(1) + stats = self.proj(x) * non_pad_mask + m, logs = stats.split(stats.size(1) // 2, dim=1) + + return x, m, logs, non_pad_mask + + +class Transformer(nn.Module): + """ + Args: + d_model (int): attention dimension + num_heads (int): number of attention heads + dim_feedforward (int): feedforward dimention + cnn_module_kernel (int): convolution kernel size + num_layers (int): number of encoder layers + dropout (float): dropout rate + """ + + def __init__( + self, + d_model: int = 192, + num_heads: int = 2, + dim_feedforward: int = 768, + cnn_module_kernel: int = 5, + num_layers: int = 6, + dropout: float = 0.1, + ) -> None: + super().__init__() + + self.num_layers = num_layers + self.d_model = d_model + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + num_heads=num_heads, + dim_feedforward=dim_feedforward, + cnn_module_kernel=cnn_module_kernel, + dropout=dropout, + ) + self.encoder = TransformerEncoder(encoder_layer, num_layers) + self.after_norm = nn.LayerNorm(d_model) + + def forward( + self, x: Tensor, key_padding_mask: Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + lengths: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + """ + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + x = self.encoder( + x, pos_emb, key_padding_mask=key_padding_mask + ) # (T, N, C) + + x = self.after_norm(x) + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + return x + + +class TransformerEncoderLayer(nn.Module): + """ + TransformerEncoderLayer is made up of self-attn and feedforward. + + Args: + d_model: the number of expected features in the input. + num_heads: the number of heads in the multi-head attention models. + dim_feedforward: the dimension of the feed-forward network model. + dropout: the dropout value (default=0.1). 
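`make_pad_mask` above comes from `icefall.utils`. For readers following along without icefall, here is a hand-rolled equivalent of the mask the encoder expects (my own helper, not the library function): `True` marks padded positions, so `~pad_mask` selects real frames.

```python
import torch

def pad_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
    """(B,) lengths -> (B, T_max) bool mask that is True at padded positions."""
    max_len = int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)  # (T_max,)
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)     # (B, T_max)

lengths = torch.tensor([5, 3, 1])
print(pad_mask_from_lengths(lengths).int())
# tensor([[0, 0, 0, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 1, 1, 1, 1]], dtype=torch.int32)
```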
+ """ + + def __init__( + self, + d_model: int, + num_heads: int, + dim_feedforward: int, + cnn_module_kernel: int, + dropout: float = 0.1, + ) -> None: + super(TransformerEncoderLayer, self).__init__() + + self.feed_forward_macaron = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout) + + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + + self.feed_forward = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_mha = nn.LayerNorm(d_model) # for the MHA module + self.norm_conv = nn.LayerNorm(d_model) # for the CNN module + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + + self.ff_scale = 0.5 + self.dropout = nn.Dropout(dropout) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the transformer encoder layer. + + Args: + src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim). + pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim). + key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len) + """ + # macaron style feed-forward module + src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src))) + + # multi-head self-attention module + src_attn = self.self_attn( + self.norm_mha(src), + pos_emb=pos_emb, + key_padding_mask=key_padding_mask, + ) + src = src + self.dropout(src_attn) + + # convolution module + src = src + self.dropout(self.conv_module(self.norm_conv(src))) + + # feed-forward module + src = src + self.dropout(self.feed_forward(self.norm_ff(src))) + + src = self.norm_final(src) + + return src + + +class TransformerEncoder(nn.Module): + r"""TransformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the TransformerEncoderLayer class. + num_layers: the number of sub-encoder-layers in the encoder. + """ + + def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: + super().__init__() + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim). + pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim). + key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len) + """ + output = src + + for layer_index, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + key_padding_mask=key_padding_mask, + ) + + return output + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. 
+ max_len: Maximum input length. + + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: Tensor) -> None: + """Reset the positional encodings.""" + x_size = x.size(1) + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x_size * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vector and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x.size(1) + + 1 : self.pe.size(1) // 2 # noqa E203 + + x.size(1), + ] + return self.dropout(x), self.dropout(pos_emb) + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + # linear transformation for positional encoding. + self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + nn.init.xavier_uniform_(self.in_proj.weight) + nn.init.constant_(self.in_proj.bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x: Tensor) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, seq_len, 2*seq_len-1). 
+ + Returns: + Tensor: tensor of shape (batch, head, seq_len, seq_len) + """ + (batch_size, num_heads, seq_len, n) = x.shape + + if not is_jit_tracing(): + assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1" + + if is_jit_tracing(): + rows = torch.arange(start=seq_len - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + + x = x.reshape(-1, n) + x = torch.gather(x, dim=1, index=indexes) + x = x.reshape(batch_size, num_heads, seq_len, seq_len) + return x + else: + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, seq_len, seq_len), + (batch_stride, head_stride, time_stride - n_stride, n_stride), + storage_offset=n_stride * (seq_len - 1), + ) + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Args: + x: Input tensor of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, (1, 2*seq_len-1, pos_dim) + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + Its shape is (batch_size, seq_len). + + Outputs: + A tensor of shape (seq_len, batch_size, embed_dim). + """ + seq_len, batch_size, _ = x.shape + scaling = float(self.head_dim) ** -0.5 + + q, k, v = self.in_proj(x).chunk(3, dim=-1) + + q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim) + k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim) + v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + + q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim) + + p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim) + # (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1) + p = p.permute(0, 2, 3, 1) + + # (batch_size, num_head, seq_len, head_dim) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch_size, num_head, seq_len, seq_len) + + # compute matrix b and matrix d + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch_size, num_head, seq_len, 2*seq_len-1) + matrix_bd = self.rel_shift(matrix_bd) # (batch_size, num_head, seq_len, seq_len) + + # (batch_size, num_head, seq_len, seq_len) + attn_output_weights = (matrix_ac + matrix_bd) * scaling + attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len) + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, seq_len) + attn_output_weights = attn_output_weights.view( + batch_size, self.num_heads, seq_len, seq_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + batch_size * self.num_heads, seq_len, seq_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + 
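The `as_strided` trick in `rel_shift` above is compact but easy to misread. The sketch below reproduces the same re-indexing with an explicit `gather` on a toy tensor, making the mapping `out[i, j] = x[i, (seq_len - 1) - i + j]` visible (this mirrors the JIT-tracing branch of the method, not a new algorithm):

```python
import torch

batch, heads, seq_len = 1, 1, 4
n = 2 * seq_len - 1
x = torch.arange(batch * heads * seq_len * n, dtype=torch.float32)
x = x.view(batch, heads, seq_len, n)

# out[b, h, i, j] = x[b, h, i, (seq_len - 1) - i + j]
rows = torch.arange(seq_len - 1, -1, -1).unsqueeze(-1)        # (seq_len, 1)
cols = torch.arange(seq_len).unsqueeze(0)                     # (1, seq_len)
index = (rows + cols).expand(batch, heads, seq_len, seq_len)  # (B, H, T, T)
shifted = torch.gather(x, dim=3, index=index)

# Brute-force reference: row i keeps columns seq_len-1-i ... 2*seq_len-2-i.
ref = torch.stack(
    [x[0, 0, i, seq_len - 1 - i : 2 * seq_len - 1 - i] for i in range(seq_len)]
)
print(torch.equal(shifted[0, 0], ref))  # True
```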
attn_output_weights = nn.functional.dropout( + attn_output_weights, p=self.dropout, training=self.training + ) + + # (batch_size * num_head, seq_len, head_dim) + attn_output = torch.bmm(attn_output_weights, v) + assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim) + + attn_output = ( + attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim) + ) + # (seq_len, batch_size, embed_dim) + attn_output = self.out_proj(attn_output) + + return attn_output + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + """ + + def __init__( + self, + channels: int, + kernel_size: int, + bias: bool = True, + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + padding = (kernel_size - 1) // 2 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias=bias, + ) + self.norm = nn.LayerNorm(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = Swish() + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + x = self.depthwise_conv(x) + # x is (batch, channels, time) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) + + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + +class Swish(nn.Module): + """Construct an Swish object.""" + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + +def _test_text_encoder(): + vocabs = 500 + d_model = 192 + batch_size = 5 + seq_len = 100 + + m = TextEncoder(vocabs=vocabs, d_model=d_model) + x, m, logs, mask = m( + x=torch.randint(low=0, high=vocabs, size=(batch_size, seq_len)), + x_lengths=torch.full((batch_size,), seq_len), + ) + print(x.shape, m.shape, logs.shape, mask.shape) + + +if __name__ == "__main__": + _test_text_encoder() diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py new file mode 100644 index 000000000..0678b26fe --- /dev/null +++ b/egs/ljspeech/TTS/vits/tokenizer.py @@ -0,0 +1,106 @@ +# Copyright 2023 Xiaomi Corp. 
(authors: Zengwei Yao) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List + +import g2p_en +import tacotron_cleaner.cleaners +from utils import intersperse + + +class Tokenizer(object): + def __init__(self, tokens: str): + """ + Args: + tokens: the file that maps tokens to ids + """ + # Parse token file + self.token2id: Dict[str, int] = {} + with open(tokens, "r", encoding="utf-8") as f: + for line in f.readlines(): + info = line.rstrip().split() + if len(info) == 1: + # case of space + token = " " + id = int(info[0]) + else: + token, id = info[0], int(info[1]) + self.token2id[token] = id + + self.blank_id = self.token2id[""] + self.oov_id = self.token2id[""] + self.vocab_size = len(self.token2id) + + self.g2p = g2p_en.G2p() + + def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True): + """ + Args: + texts: + A list of transcripts. + intersperse_blank: + Whether to intersperse blanks in the token sequence. + + Returns: + Return a list of token id list [utterance][token_id] + """ + token_ids_list = [] + + for text in texts: + # Text normalization + text = tacotron_cleaner.cleaners.custom_english_cleaners(text) + # Convert to phonemes + tokens = self.g2p(text) + token_ids = [] + for t in tokens: + if t in self.token2id: + token_ids.append(self.token2id[t]) + else: + token_ids.append(self.oov_id) + + if intersperse_blank: + token_ids = intersperse(token_ids, self.blank_id) + + token_ids_list.append(token_ids) + + return token_ids_list + + def tokens_to_token_ids(self, tokens_list: List[str], intersperse_blank: bool = True): + """ + Args: + tokens_list: + A list of token list, each corresponding to one utterance. + intersperse_blank: + Whether to intersperse blanks in the token sequence. + + Returns: + Return a list of token id list [utterance][token_id] + """ + token_ids_list = [] + + for tokens in tokens_list: + token_ids = [] + for t in tokens: + if t in self.token2id: + token_ids.append(self.token2id[t]) + else: + token_ids.append(self.oov_id) + + if intersperse_blank: + token_ids = intersperse(token_ids, self.blank_id) + token_ids_list.append(token_ids) + + return token_ids_list diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py new file mode 100755 index 000000000..eb43a4cc9 --- /dev/null +++ b/egs/ljspeech/TTS/vits/train.py @@ -0,0 +1,893 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
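`intersperse` is imported from the recipe's `utils.py`, which is not part of this hunk. A minimal stand-in with the behaviour the tokenizer above relies on (every token separated by, and surrounded with, the blank id) could look like this; treat it as an illustration rather than the actual helper:

```python
from typing import List

def intersperse(sequence: List[int], item: int) -> List[int]:
    """[a, b, c] -> [item, a, item, b, item, c, item]"""
    result = [item] * (2 * len(sequence) + 1)
    result[1::2] = sequence
    return result

print(intersperse([7, 8, 9], 0))  # [0, 7, 0, 8, 0, 9, 0]
```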
+# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import numpy as np +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from torch.optim import Optimizer +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, setup_logger, str2bool + +from tokenizer import Tokenizer +from tts_datamodule import LJSpeechTtsDataModule +from utils import MetricsTracker, plot_feature, save_checkpoint +from vits import VITS + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=1000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--lr", type=float, default=2.0e-4, help="The base learning rate." + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=20, + help="""Save checkpoint after processing this number of epochs" + periodically. We save checkpoint to exp-dir/ whenever + params.cur_epoch % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. + Since it will take around 1000 epochs, we suggest using a large + save_every_n to save disk space. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. 
+ + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + # training params + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 50, + "valid_interval": 200, + "env_info": get_env_info(), + "sampling_rate": 22050, + "frame_shift": 256, + "frame_length": 1024, + "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length + "n_mels": 80, + "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss + "lambda_mel": 45.0, # loss scaling coefficient for Mel loss + "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss + "lambda_dur": 1.0, # loss scaling coefficient for duration loss + "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
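For reference, the spectral settings hard-coded in `get_params` above are internally consistent for 22.05 kHz LJSpeech audio (assuming an FFT size of 1024, which is what the `feature_dim` comment implies):

```python
sampling_rate = 22050
n_fft = 1024         # assumed from the "1024 is fft_length" comment
frame_shift = 256    # samples
frame_length = 1024  # samples

print(n_fft // 2 + 1)                       # 513 -> feature_dim (linear-spectrogram bins)
print(1000 * frame_shift / sampling_rate)   # ~11.6 ms hop
print(1000 * frame_length / sampling_rate)  # ~46.4 ms analysis window
```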
+ + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def get_model(params: AttributeDict) -> nn.Module: + mel_loss_params = { + "n_mels": params.n_mels, + "frame_length": params.frame_length, + "frame_shift": params.frame_shift, + } + model = VITS( + vocab_size=params.vocab_size, + feature_dim=params.feature_dim, + sampling_rate=params.sampling_rate, + mel_loss_params=mel_loss_params, + lambda_adv=params.lambda_adv, + lambda_mel=params.lambda_mel, + lambda_feat_match=params.lambda_feat_match, + lambda_dur=params.lambda_dur, + lambda_kl=params.lambda_kl, + ) + return model + + +def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): + """Parse batch data""" + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + tokens = batch["tokens"] + + tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) + + return audio, audio_lens, features, features_lens, tokens, tokens_lens + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + optimizer_g: Optimizer, + optimizer_d: Optimizer, + scheduler_g: LRSchedulerType, + scheduler_d: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + tokenizer: + Used to convert text to phonemes. + optimizer_g: + The optimizer for generator. + optimizer_d: + The optimizer for discriminator. + scheduler_g: + The learning rate scheduler for generator, we call step() every epoch. + scheduler_d: + The learning rate scheduler for discriminator, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
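The `prepare_input` helper above relies on k2's ragged tensors to turn variable-length token lists into a dense `(B, T)` batch plus a length tensor. A small worked example of that conversion (requires k2; the token ids and blank id are made up):

```python
import k2

token_ids = [[5, 7, 2], [9]]  # two utterances with different lengths
ragged = k2.RaggedTensor(token_ids)

row_splits = ragged.shape.row_splits(1)         # tensor([0, 3, 4], dtype=torch.int32)
tokens_lens = row_splits[1:] - row_splits[:-1]  # tensor([3, 1])

blank_id = 0  # in the recipe this comes from the Tokenizer
padded = ragged.pad(mode="constant", padding_value=blank_id)
print(padded)       # tensor([[5, 7, 2],
                    #         [9, 0, 0]], dtype=torch.int32)
print(tokens_lens)
```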
+ """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + + batch_size = len(batch["tokens"]) + audio, audio_lens, features, features_lens, tokens, tokens_lens = \ + prepare_input(batch, tokenizer, device) + + loss_info = MetricsTracker() + loss_info['samples'] = batch_size + + try: + with autocast(enabled=params.use_fp16): + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=False, + ) + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + # update discriminator + optimizer_d.zero_grad() + scaler.scale(loss_d).backward() + scaler.step(optimizer_d) + + with autocast(enabled=params.use_fp16): + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + return_sample=params.batch_idx_train % params.log_interval == 0, + ) + for k, v in stats_g.items(): + if "returned_sample" not in k: + loss_info[k] = v * batch_size + # update generator + optimizer_g.zero_grad() + scaler.scale(loss_g).backward() + scaler.step(optimizer_g) + scaler.update() + + # summary stats + tot_loss = tot_loss + loss_info + except: # noqa + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
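Before the gradient-scale bookkeeping that follows, note the overall shape of the `try` block above: one discriminator update and then one generator update per batch, both under `autocast`, with a single `scaler.update()` at the end. Stripped of the VITS-specific losses, the pattern looks roughly like this (toy linear modules and placeholder losses stand in for the real generator and discriminator):

```python
import torch
from torch.cuda.amp import GradScaler, autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
use_fp16 = device == "cuda"

gen = torch.nn.Linear(8, 8).to(device)    # stand-in for the VITS generator
disc = torch.nn.Linear(8, 1).to(device)   # stand-in for the discriminator
opt_g = torch.optim.AdamW(gen.parameters(), lr=2e-4)
opt_d = torch.optim.AdamW(disc.parameters(), lr=2e-4)
scaler = GradScaler(enabled=use_fp16)

x = torch.randn(4, 8, device=device)
for _ in range(3):
    # discriminator turn: generator outputs are treated as constants
    with autocast(enabled=use_fp16):
        loss_d = disc(gen(x).detach()).pow(2).mean()  # placeholder loss
    opt_d.zero_grad()
    scaler.scale(loss_d).backward()
    scaler.step(opt_d)

    # generator turn: gradients flow back through the generator
    with autocast(enabled=use_fp16):
        loss_g = (1.0 - disc(gen(x))).pow(2).mean()   # placeholder loss
    opt_g.zero_grad()
    scaler.scale(loss_g).backward()
    scaler.step(opt_g)

    scaler.update()  # one update per batch, as in the loop above
```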
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_lr_g = max(scheduler_g.get_last_lr()) + cur_lr_d = max(scheduler_d.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate_g", cur_lr_g, params.batch_idx_train + ) + tb_writer.add_scalar( + "train/learning_rate_d", cur_lr_d, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + if "returned_sample" in stats_g: + speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] + tb_writer.add_audio( + "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate + ) + tb_writer.add_audio( + "train/speech_", speech_, params.batch_idx_train, params.sampling_rate + ) + tb_writer.add_image( + "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC' + ) + tb_writer.add_image( + "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC' + ) + + if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info, (speech_hat, speech) = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + tb_writer.add_audio( + "train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate + ) + tb_writer.add_audio( + "train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate + ) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, + rank: int = 0, +) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: + """Run the validation process.""" + model.eval() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = 
MetricsTracker() + returned_sample = None + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + batch_size = len(batch["tokens"]) + audio, audio_lens, features, features_lens, tokens, tokens_lens = \ + prepare_input(batch, tokenizer, device) + + loss_info = MetricsTracker() + loss_info['samples'] = batch_size + + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=False, + ) + assert loss_d.requires_grad is False + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + ) + assert loss_g.requires_grad is False + for k, v in stats_g.items(): + loss_info[k] = v * batch_size + + # summary stats + tot_loss = tot_loss + loss_info + + # infer for first batch: + if batch_idx == 0 and rank == 0: + inner_model = model.module if isinstance(model, DDP) else model + audio_pred, _, duration = inner_model.inference( + text=tokens[0, :tokens_lens[0].item()] + ) + audio_pred = audio_pred.data.cpu().numpy() + audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() + assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred)) + audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy() + returned_sample = (audio_pred, audio_gt) + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss, returned_sample + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + tokenizer: Tokenizer, + optimizer_g: torch.optim.Optimizer, + optimizer_d: torch.optim.Optimizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + audio, audio_lens, features, features_lens, tokens, tokens_lens = \ + prepare_input(batch, tokenizer, device) + try: + # for discriminator + with autocast(enabled=params.use_fp16): + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=False, + ) + optimizer_d.zero_grad() + loss_d.backward() + # for generator + with autocast(enabled=params.use_fp16): + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + ) + optimizer_g.zero_grad() + loss_g.backward() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + generator = model.generator + discriminator = model.discriminator + + num_param_g = sum([p.numel() for p in generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer_g = torch.optim.AdamW( + generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + optimizer_d = torch.optim.AdamW( + discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) + + if checkpoints is not None: + # load state_dict for optimizers + if "optimizer_g" in checkpoints: + logging.info("Loading optimizer_g state dict") + optimizer_g.load_state_dict(checkpoints["optimizer_g"]) + if "optimizer_d" in checkpoints: + logging.info("Loading optimizer_d state dict") + optimizer_d.load_state_dict(checkpoints["optimizer_d"]) + + # load state_dict for schedulers + if "scheduler_g" in checkpoints: + logging.info("Loading scheduler_g state dict") + scheduler_g.load_state_dict(checkpoints["scheduler_g"]) + if "scheduler_d" in checkpoints: + logging.info("Loading scheduler_d state dict") + scheduler_d.load_state_dict(checkpoints["scheduler_d"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + ljspeech = LJSpeechTtsDataModule(args) + + train_cuts = ljspeech.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # 
the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_dl = ljspeech.train_dataloaders(train_cuts) + + valid_cuts = ljspeech.valid_cuts() + valid_dl = ljspeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + # step per epoch + scheduler_g.step() + scheduler_d.step() + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/TTS/vits/transform.py b/egs/ljspeech/TTS/vits/transform.py new file mode 100644 index 000000000..c20d13130 --- /dev/null +++ b/egs/ljspeech/TTS/vits/transform.py @@ -0,0 +1,218 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py + +"""Flow-related transformation. + +This code is derived from https://github.com/bayesiains/nflows. 
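The spline utilities defined in this file are the building blocks of the VITS normalizing flows: given per-element unnormalized bin widths, heights and knot derivatives, they apply a monotonic piecewise rational-quadratic map and return the transformed values together with log|det J|. A shape-level usage sketch (assuming this file is importable as `transform`; all tensors are random placeholders):

```python
import torch
from transform import piecewise_rational_quadratic_transform  # this file

B, T, num_bins = 2, 5, 10
x = torch.rand(B, T) * 2.0 - 1.0                # inputs inside the tail bound [-1, 1]
widths = torch.randn(B, T, num_bins)            # unnormalized bin widths
heights = torch.randn(B, T, num_bins)           # unnormalized bin heights
derivs = torch.randn(B, T, num_bins - 1)        # unnormalized inner-knot derivatives

y, logabsdet = piecewise_rational_quadratic_transform(
    x, widths, heights, derivs, inverse=False, tails="linear", tail_bound=1.0
)
# the map is invertible, so running it backwards recovers the input
x_rec, _ = piecewise_rational_quadratic_transform(
    y, widths, heights, derivs, inverse=True, tails="linear", tail_bound=1.0
)
print(torch.allclose(x, x_rec, atol=1e-4))  # True (up to numerical error)
```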
+ +""" + +import numpy as np +import torch +from torch.nn import functional as F + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +# TODO(kan-bayashi): Documentation and type hint +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs + ) + return outputs, logabsdet + + +# TODO(kan-bayashi): Documentation and type hint +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == "linear": + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError("{} tails are not implemented.".format(tails)) + + ( + outputs[inside_interval_mask], + logabsdet[inside_interval_mask], + ) = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + return outputs, logabsdet + + +# TODO(kan-bayashi): Documentation and type hint +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError("Input to a transform is not within its domain") + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError("Minimal bin width too large for the number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * 
widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = _searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = _searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * ( + input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta + ) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet + + +def _searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py new file mode 100644 index 000000000..0fcbb92c1 --- /dev/null +++ b/egs/ljspeech/TTS/vits/tts_datamodule.py @@ -0,0 +1,325 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei 
Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + SpeechSynthesisDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LJSpeechTtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/spectrogram"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. 
Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
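In other words, the base seed is drawn once in the main process (whose RNG was already seeded via `fix_random_seed`), and each dataloader worker then derives its own deterministic seed from it, which is what the next few lines set up. The same pattern in isolation (this sketch seeds only torch, whereas lhotse's `fix_random_seed` also covers Python and NumPy):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class _TinyDataset(Dataset):
    """Toy dataset whose items involve randomness we want reproducible per worker."""

    def __len__(self) -> int:
        return 8

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.rand(1)


class SeedWorkers:
    """Each worker gets base_seed + worker_id, like _SeedWorkers above."""

    def __init__(self, base_seed: int) -> None:
        self.base_seed = base_seed

    def __call__(self, worker_id: int) -> None:
        torch.manual_seed(self.base_seed + worker_id)


if __name__ == "__main__":
    base_seed = torch.randint(0, 100000, ()).item()  # drawn once in the main process
    loader = DataLoader(
        _TinyDataset(),
        batch_size=2,
        num_workers=2,
        worker_init_fn=SeedWorkers(base_seed),
    )
    print(next(iter(loader)))
```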
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz" + ) diff --git a/egs/ljspeech/TTS/vits/utils.py b/egs/ljspeech/TTS/vits/utils.py new file mode 100644 index 000000000..2a3dae900 --- /dev/null +++ b/egs/ljspeech/TTS/vits/utils.py @@ -0,0 +1,265 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Optional, Tuple, Union +import collections +import logging + +import torch +import torch.nn as nn +import torch.distributed as dist +from lhotse.dataset.sampling.base import CutSampler +from pathlib import Path +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter + + +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py +def get_random_segments( + x: torch.Tensor, + x_lengths: torch.Tensor, + segment_size: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Get random segments. + + Args: + x (Tensor): Input tensor (B, C, T). + x_lengths (Tensor): Length tensor (B,). + segment_size (int): Segment size. + + Returns: + Tensor: Segmented tensor (B, C, segment_size). + Tensor: Start index tensor (B,). + + """ + b, c, t = x.size() + max_start_idx = x_lengths - segment_size + max_start_idx[max_start_idx < 0] = 0 + start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to( + dtype=torch.long, + ) + segments = get_segments(x, start_idxs, segment_size) + + return segments, start_idxs + + +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py +def get_segments( + x: torch.Tensor, + start_idxs: torch.Tensor, + segment_size: int, +) -> torch.Tensor: + """Get segments. + + Args: + x (Tensor): Input tensor (B, C, T). + start_idxs (Tensor): Start index tensor (B,). + segment_size (int): Segment size. + + Returns: + Tensor: Segmented tensor (B, C, segment_size). + + """ + b, c, t = x.size() + segments = x.new_zeros(b, c, segment_size) + for i, start_idx in enumerate(start_idxs): + segments[i] = x[i, :, start_idx : start_idx + segment_size] + return segments + + +# from https://github.com/jaywalnut310/vit://github.com/jaywalnut310/vits/blob/main/commons.py +def intersperse(sequence, item=0): + result = [item] * (len(sequence) * 2 + 1) + result[1::2] = sequence + return result + + +# from https://github.com/jaywalnut310/vits/blob/main/utils.py +MATPLOTLIB_FLAG = False + + +def plot_feature(spectrogram): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger('matplotlib') + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +class MetricsTracker(collections.defaultdict): + def __init__(self): + # Passing the type 'int' to the base-class constructor + # makes undefined items default to int() which is zero. + # This class will play a role as metrics tracker. 
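Returning briefly to the segment helpers defined earlier in this file: they implement the random windowing VITS relies on, so that the waveform decoder and the discriminators only ever see short crops instead of full utterances. A quick behavioural check (assuming this file is importable as `utils`):

```python
import torch
from utils import get_random_segments, get_segments  # this file

x = torch.arange(20, dtype=torch.float32).view(2, 1, 10)  # (B, C, T)
x_lengths = torch.tensor([10, 6])

segments, start_idxs = get_random_segments(x, x_lengths, segment_size=4)
print(segments.shape)  # torch.Size([2, 1, 4]); each crop stays inside the valid length

fixed = get_segments(x, start_idxs=torch.tensor([0, 2]), segment_size=4)
print(fixed[1, 0])     # tensor([12., 13., 14., 15.])
```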
+ # It can record many metrics, including but not limited to loss. + super(MetricsTracker, self).__init__(int) + + def __add__(self, other: "MetricsTracker") -> "MetricsTracker": + ans = MetricsTracker() + for k, v in self.items(): + ans[k] = v + for k, v in other.items(): + ans[k] = ans[k] + v + return ans + + def __mul__(self, alpha: float) -> "MetricsTracker": + ans = MetricsTracker() + for k, v in self.items(): + ans[k] = v * alpha + return ans + + def __str__(self) -> str: + ans = "" + for k, v in self.norm_items(): + norm_value = "%.4g" % v + ans += str(k) + "=" + str(norm_value) + ", " + samples = "%.2f" % self["samples"] + ans += "over " + str(samples) + " samples." + return ans + + def norm_items(self) -> List[Tuple[str, float]]: + """ + Returns a list of pairs, like: + [('loss_1', 0.1), ('loss_2', 0.07)] + """ + samples = self["samples"] if "samples" in self else 1 + ans = [] + for k, v in self.items(): + if k == "samples": + continue + norm_value = float(v) / samples + ans.append((k, norm_value)) + return ans + + def reduce(self, device): + """ + Reduce using torch.distributed, which I believe ensures that + all processes get the total. + """ + keys = sorted(self.keys()) + s = torch.tensor([float(self[k]) for k in keys], device=device) + dist.all_reduce(s, op=dist.ReduceOp.SUM) + for k, v in zip(keys, s.cpu().tolist()): + self[k] = v + + def write_summary( + self, + tb_writer: SummaryWriter, + prefix: str, + batch_idx: int, + ) -> None: + """Add logging information to a TensorBoard writer. + + Args: + tb_writer: a TensorBoard writer + prefix: a prefix for the name of the loss, e.g. "train/valid_", + or "train/current_" + batch_idx: The current batch index, used as the x-axis of the plot. + """ + for k, v in self.norm_items(): + tb_writer.add_scalar(prefix + k, v, batch_idx) + + +# checkpoint saving and loading +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def save_checkpoint( + filename: Path, + model: Union[nn.Module, DDP], + params: Optional[Dict[str, Any]] = None, + optimizer_g: Optional[Optimizer] = None, + optimizer_d: Optional[Optimizer] = None, + scheduler_g: Optional[LRSchedulerType] = None, + scheduler_d: Optional[LRSchedulerType] = None, + scaler: Optional[GradScaler] = None, + sampler: Optional[CutSampler] = None, + rank: int = 0, +) -> None: + """Save training information to a file. + + Args: + filename: + The checkpoint filename. + model: + The model to be saved. We only save its `state_dict()`. + model_avg: + The stored model averaged from the start of training. + params: + User defined parameters, e.g., epoch, loss. + optimizer_g: + The optimizer for generator used in the training. + Its `state_dict` will be saved. + optimizer_d: + The optimizer for discriminator used in the training. + Its `state_dict` will be saved. + scheduler_g: + The learning rate scheduler for generator used in the training. + Its `state_dict` will be saved. + scheduler_d: + The learning rate scheduler for discriminator used in the training. + Its `state_dict` will be saved. + scalar: + The GradScaler to be saved. We only save its `state_dict()`. + rank: + Used in DDP. We save checkpoint only for the node whose rank is 0. + Returns: + Return None. 
+ """ + if rank != 0: + return + + logging.info(f"Saving checkpoint to {filename}") + + if isinstance(model, DDP): + model = model.module + + checkpoint = { + "model": model.state_dict(), + "optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None, + "optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None, + "scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None, + "scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None, + "grad_scaler": scaler.state_dict() if scaler is not None else None, + "sampler": sampler.state_dict() if sampler is not None else None, + } + + if params: + for k, v in params.items(): + assert k not in checkpoint + checkpoint[k] = v + + torch.save(checkpoint, filename) diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py new file mode 100644 index 000000000..d5e20a578 --- /dev/null +++ b/egs/ljspeech/TTS/vits/vits.py @@ -0,0 +1,610 @@ +# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""VITS module for GAN-TTS task.""" + +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn +from torch.cuda.amp import autocast + +from hifigan import ( + HiFiGANMultiPeriodDiscriminator, + HiFiGANMultiScaleDiscriminator, + HiFiGANMultiScaleMultiPeriodDiscriminator, + HiFiGANPeriodDiscriminator, + HiFiGANScaleDiscriminator, +) +from loss import ( + DiscriminatorAdversarialLoss, + FeatureMatchLoss, + GeneratorAdversarialLoss, + KLDivergenceLoss, + MelSpectrogramLoss, +) +from utils import get_segments +from generator import VITSGenerator + + +AVAILABLE_GENERATERS = { + "vits_generator": VITSGenerator, +} +AVAILABLE_DISCRIMINATORS = { + "hifigan_period_discriminator": HiFiGANPeriodDiscriminator, + "hifigan_scale_discriminator": HiFiGANScaleDiscriminator, + "hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator, + "hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator, + "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA +} + + +class VITS(nn.Module): + """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech` + """ + + def __init__( + self, + # generator related + vocab_size: int, + feature_dim: int = 513, + sampling_rate: int = 22050, + generator_type: str = "vits_generator", + generator_params: Dict[str, Any] = { + "hidden_channels": 192, + "spks": None, + "langs": None, + "spk_embed_dim": None, + "global_channels": -1, + "segment_size": 32, + "text_encoder_attention_heads": 2, + "text_encoder_ffn_expand": 4, + "text_encoder_cnn_module_kernel": 5, + "text_encoder_blocks": 6, + "text_encoder_dropout_rate": 0.1, + "decoder_kernel_size": 7, + "decoder_channels": 512, + "decoder_upsample_scales": [8, 8, 2, 2], + "decoder_upsample_kernel_sizes": [16, 16, 4, 4], + "decoder_resblock_kernel_sizes": [3, 7, 11], + "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "use_weight_norm_in_decoder": True, + "posterior_encoder_kernel_size": 5, + "posterior_encoder_layers": 16, + "posterior_encoder_stacks": 1, + "posterior_encoder_base_dilation": 1, + "posterior_encoder_dropout_rate": 0.0, + "use_weight_norm_in_posterior_encoder": True, + "flow_flows": 4, + "flow_kernel_size": 5, + "flow_base_dilation": 1, + "flow_layers": 4, + "flow_dropout_rate": 0.0, + "use_weight_norm_in_flow": True, + 
"use_only_mean_in_flow": True, + "stochastic_duration_predictor_kernel_size": 3, + "stochastic_duration_predictor_dropout_rate": 0.5, + "stochastic_duration_predictor_flows": 4, + "stochastic_duration_predictor_dds_conv_layers": 3, + }, + # discriminator related + discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator", + discriminator_params: Dict[str, Any] = { + "scales": 1, + "scale_downsample_pooling": "AvgPool1d", + "scale_downsample_pooling_params": { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + "scale_discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 128, + "max_downsample_channels": 1024, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 4, 4, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + "follow_official_norm": False, + "periods": [2, 3, 5, 7, 11], + "period_discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 32, + "downsample_scales": [3, 3, 3, 3, 1], + "max_downsample_channels": 1024, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + }, + # loss related + generator_adv_loss_params: Dict[str, Any] = { + "average_by_discriminators": False, + "loss_type": "mse", + }, + discriminator_adv_loss_params: Dict[str, Any] = { + "average_by_discriminators": False, + "loss_type": "mse", + }, + feat_match_loss_params: Dict[str, Any] = { + "average_by_discriminators": False, + "average_by_layers": False, + "include_final_outputs": True, + }, + mel_loss_params: Dict[str, Any] = { + "frame_shift": 256, + "frame_length": 1024, + "n_mels": 80, + }, + lambda_adv: float = 1.0, + lambda_mel: float = 45.0, + lambda_feat_match: float = 2.0, + lambda_dur: float = 1.0, + lambda_kl: float = 1.0, + cache_generator_outputs: bool = True, + ): + """Initialize VITS module. + + Args: + idim (int): Input vocabrary size. + odim (int): Acoustic feature dimension. The actual output channels will + be 1 since VITS is the end-to-end text-to-wave model but for the + compatibility odim is used to indicate the acoustic feature dimension. + sampling_rate (int): Sampling rate, not used for the training but it will + be referred in saving waveform during the inference. + generator_type (str): Generator type. + generator_params (Dict[str, Any]): Parameter dict for generator. + discriminator_type (str): Discriminator type. + discriminator_params (Dict[str, Any]): Parameter dict for discriminator. + generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator + adversarial loss. + discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for + discriminator adversarial loss. + feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss. + mel_loss_params (Dict[str, Any]): Parameter dict for mel loss. + lambda_adv (float): Loss scaling coefficient for adversarial loss. + lambda_mel (float): Loss scaling coefficient for mel spectrogram loss. + lambda_feat_match (float): Loss scaling coefficient for feat match loss. + lambda_dur (float): Loss scaling coefficient for duration loss. + lambda_kl (float): Loss scaling coefficient for KL divergence loss. + cache_generator_outputs (bool): Whether to cache generator outputs. 
+ + """ + super().__init__() + + # define modules + generator_class = AVAILABLE_GENERATERS[generator_type] + if generator_type == "vits_generator": + # NOTE(kan-bayashi): Update parameters for the compatibility. + # The idim and odim is automatically decided from input data, + # where idim represents #vocabularies and odim represents + # the input acoustic feature dimension. + generator_params.update(vocabs=vocab_size, aux_channels=feature_dim) + self.generator = generator_class( + **generator_params, + ) + discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type] + self.discriminator = discriminator_class( + **discriminator_params, + ) + self.generator_adv_loss = GeneratorAdversarialLoss( + **generator_adv_loss_params, + ) + self.discriminator_adv_loss = DiscriminatorAdversarialLoss( + **discriminator_adv_loss_params, + ) + self.feat_match_loss = FeatureMatchLoss( + **feat_match_loss_params, + ) + mel_loss_params.update(sampling_rate=sampling_rate) + self.mel_loss = MelSpectrogramLoss( + **mel_loss_params, + ) + self.kl_loss = KLDivergenceLoss() + + # coefficients + self.lambda_adv = lambda_adv + self.lambda_mel = lambda_mel + self.lambda_kl = lambda_kl + self.lambda_feat_match = lambda_feat_match + self.lambda_dur = lambda_dur + + # cache + self.cache_generator_outputs = cache_generator_outputs + self._cache = None + + # store sampling rate for saving wav file + # (not used for the training) + self.sampling_rate = sampling_rate + + # store parameters for test compatibility + self.spks = self.generator.spks + self.langs = self.generator.langs + self.spk_embed_dim = self.generator.spk_embed_dim + + def forward( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + return_sample: bool = False, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + forward_generator: bool = True, + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + """Perform generator forward. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + forward_generator (bool): Whether to forward generator. + + Returns: + - loss (Tensor): Loss scalar tensor. + - stats (Dict[str, float]): Statistics to be monitored. 
+ """ + if forward_generator: + return self._forward_generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + speech=speech, + speech_lengths=speech_lengths, + return_sample=return_sample, + sids=sids, + spembs=spembs, + lids=lids, + ) + else: + return self._forward_discrminator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + speech=speech, + speech_lengths=speech_lengths, + sids=sids, + spembs=spembs, + lids=lids, + ) + + def _forward_generator( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + return_sample: bool = False, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + """Perform generator forward. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + + Returns: + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. + """ + # setup + feats = feats.transpose(1, 2) + speech = speech.unsqueeze(1) + + # calculate generator outputs + reuse_cache = True + if not self.cache_generator_outputs or self._cache is None: + reuse_cache = False + outs = self.generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + sids=sids, + spembs=spembs, + lids=lids, + ) + else: + outs = self._cache + + # store cache + if self.training and self.cache_generator_outputs and not reuse_cache: + self._cache = outs + + # parse outputs + speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs + _, z_p, m_p, logs_p, _, logs_q = outs_ + speech_ = get_segments( + x=speech, + start_idxs=start_idxs * self.generator.upsample_factor, + segment_size=self.generator.segment_size * self.generator.upsample_factor, + ) + + # calculate discriminator outputs + p_hat = self.discriminator(speech_hat_) + with torch.no_grad(): + # do not store discriminator gradient in generator turn + p = self.discriminator(speech_) + + # calculate losses + with autocast(enabled=False): + if not return_sample: + mel_loss = self.mel_loss(speech_hat_, speech_) + else: + mel_loss, (mel_hat_, mel_) = self.mel_loss( + speech_hat_, speech_, return_mel=True + ) + kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask) + dur_loss = torch.sum(dur_nll.float()) + adv_loss = self.generator_adv_loss(p_hat) + feat_match_loss = self.feat_match_loss(p_hat, p) + + mel_loss = mel_loss * self.lambda_mel + kl_loss = kl_loss * self.lambda_kl + dur_loss = dur_loss * self.lambda_dur + adv_loss = adv_loss * self.lambda_adv + feat_match_loss = feat_match_loss * self.lambda_feat_match + loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss + + stats = dict( + generator_loss=loss.item(), + generator_mel_loss=mel_loss.item(), + generator_kl_loss=kl_loss.item(), + generator_dur_loss=dur_loss.item(), + generator_adv_loss=adv_loss.item(), + generator_feat_match_loss=feat_match_loss.item(), + 
) + + if return_sample: + stats["returned_sample"] = ( + speech_hat_[0].data.cpu().numpy(), + speech_[0].data.cpu().numpy(), + mel_hat_[0].data.cpu().numpy(), + mel_[0].data.cpu().numpy(), + ) + + # reset cache + if reuse_cache or not self.training: + self._cache = None + + return loss, stats + + def _forward_discrminator( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + """Perform discriminator forward. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + + Returns: + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. + """ + # setup + feats = feats.transpose(1, 2) + speech = speech.unsqueeze(1) + + # calculate generator outputs + reuse_cache = True + if not self.cache_generator_outputs or self._cache is None: + reuse_cache = False + outs = self.generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + sids=sids, + spembs=spembs, + lids=lids, + ) + else: + outs = self._cache + + # store cache + if self.cache_generator_outputs and not reuse_cache: + self._cache = outs + + # parse outputs + speech_hat_, _, _, start_idxs, *_ = outs + speech_ = get_segments( + x=speech, + start_idxs=start_idxs * self.generator.upsample_factor, + segment_size=self.generator.segment_size * self.generator.upsample_factor, + ) + + # calculate discriminator outputs + p_hat = self.discriminator(speech_hat_.detach()) + p = self.discriminator(speech_) + + # calculate losses + with autocast(enabled=False): + real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p) + loss = real_loss + fake_loss + + stats = dict( + discriminator_loss=loss.item(), + discriminator_real_loss=real_loss.item(), + discriminator_fake_loss=fake_loss.item(), + ) + + # reset cache + if reuse_cache or not self.training: + self._cache = None + + return loss, stats + + def inference( + self, + text: torch.Tensor, + feats: Optional[torch.Tensor] = None, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + durations: Optional[torch.Tensor] = None, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + max_len: Optional[int] = None, + use_teacher_forcing: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run inference for single sample. + + Args: + text (Tensor): Input text index tensor (T_text,). + feats (Tensor): Feature tensor (T_feats, aux_channels). + sids (Tensor): Speaker index tensor (1,). + spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,). + lids (Tensor): Language index tensor (1,). + durations (Tensor): Ground-truth duration tensor (T_text,). + noise_scale (float): Noise scale value for flow. 
+ noise_scale_dur (float): Noise scale value for duration predictor. + alpha (float): Alpha parameter to control the speed of generated speech. + max_len (Optional[int]): Maximum length. + use_teacher_forcing (bool): Whether to use teacher forcing. + + Returns: + * wav (Tensor): Generated waveform tensor (T_wav,). + * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text). + * duration (Tensor): Predicted duration tensor (T_text,). + """ + # setup + text = text[None] + text_lengths = torch.tensor( + [text.size(1)], + dtype=torch.long, + device=text.device, + ) + if sids is not None: + sids = sids.view(1) + if lids is not None: + lids = lids.view(1) + if durations is not None: + durations = durations.view(1, 1, -1) + + # inference + if use_teacher_forcing: + assert feats is not None + feats = feats[None].transpose(1, 2) + feats_lengths = torch.tensor( + [feats.size(2)], + dtype=torch.long, + device=feats.device, + ) + wav, att_w, dur = self.generator.inference( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + sids=sids, + spembs=spembs, + lids=lids, + max_len=max_len, + use_teacher_forcing=use_teacher_forcing, + ) + else: + wav, att_w, dur = self.generator.inference( + text=text, + text_lengths=text_lengths, + sids=sids, + spembs=spembs, + lids=lids, + dur=durations, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + alpha=alpha, + max_len=max_len, + ) + return wav.view(-1), att_w[0], dur[0] + + def inference_batch( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + durations: Optional[torch.Tensor] = None, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + max_len: Optional[int] = None, + use_teacher_forcing: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run inference for one batch. + + Args: + text (Tensor): Input text index tensor (B, T_text). + text_lengths (Tensor): Input text index tensor (B,). + sids (Tensor): Speaker index tensor (B,). + noise_scale (float): Noise scale value for flow. + noise_scale_dur (float): Noise scale value for duration predictor. + alpha (float): Alpha parameter to control the speed of generated speech. + max_len (Optional[int]): Maximum length. + + Returns: + * wav (Tensor): Generated waveform tensor (B, T_wav). + * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text). + * duration (Tensor): Predicted duration tensor (B, T_text). + """ + # inference + wav, att_w, dur = self.generator.inference( + text=text, + text_lengths=text_lengths, + sids=sids, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + alpha=alpha, + max_len=max_len, + ) + return wav, att_w, dur diff --git a/egs/ljspeech/TTS/vits/wavenet.py b/egs/ljspeech/TTS/vits/wavenet.py new file mode 100644 index 000000000..fbe1be52b --- /dev/null +++ b/egs/ljspeech/TTS/vits/wavenet.py @@ -0,0 +1,349 @@ +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""WaveNet modules. + +This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. 
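+
+In this recipe the WaveNet stack is used as the backbone of the posterior
+encoder and of the residual affine coupling flows.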
+ +""" + +import math +import logging + +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + + +class WaveNet(torch.nn.Module): + """WaveNet with global conditioning.""" + + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + kernel_size: int = 3, + layers: int = 30, + stacks: int = 3, + base_dilation: int = 2, + residual_channels: int = 64, + aux_channels: int = -1, + gate_channels: int = 128, + skip_channels: int = 64, + global_channels: int = -1, + dropout_rate: float = 0.0, + bias: bool = True, + use_weight_norm: bool = True, + use_first_conv: bool = False, + use_last_conv: bool = False, + scale_residual: bool = False, + scale_skip_connect: bool = False, + ): + """Initialize WaveNet module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (int): Kernel size of dilated convolution. + layers (int): Number of residual block layers. + stacks (int): Number of stacks i.e., dilation cycles. + base_dilation (int): Base dilation factor. + residual_channels (int): Number of channels in residual conv. + gate_channels (int): Number of channels in gated conv. + skip_channels (int): Number of channels in skip conv. + aux_channels (int): Number of channels for local conditioning feature. + global_channels (int): Number of channels for global conditioning feature. + dropout_rate (float): Dropout rate. 0.0 means no dropout applied. + bias (bool): Whether to use bias parameter in conv layer. + use_weight_norm (bool): Whether to use weight norm. If set to true, it will + be applied to all of the conv layers. + use_first_conv (bool): Whether to use the first conv layers. + use_last_conv (bool): Whether to use the last conv layers. + scale_residual (bool): Whether to scale the residual outputs. + scale_skip_connect (bool): Whether to scale the skip connection outputs. + + """ + super().__init__() + self.layers = layers + self.stacks = stacks + self.kernel_size = kernel_size + self.base_dilation = base_dilation + self.use_first_conv = use_first_conv + self.use_last_conv = use_last_conv + self.scale_skip_connect = scale_skip_connect + + # check the number of layers and stacks + assert layers % stacks == 0 + layers_per_stack = layers // stacks + + # define first convolution + if self.use_first_conv: + self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True) + + # define residual blocks + self.conv_layers = torch.nn.ModuleList() + for layer in range(layers): + dilation = base_dilation ** (layer % layers_per_stack) + conv = ResidualBlock( + kernel_size=kernel_size, + residual_channels=residual_channels, + gate_channels=gate_channels, + skip_channels=skip_channels, + aux_channels=aux_channels, + global_channels=global_channels, + dilation=dilation, + dropout_rate=dropout_rate, + bias=bias, + scale_residual=scale_residual, + ) + self.conv_layers += [conv] + + # define output layers + if self.use_last_conv: + self.last_conv = torch.nn.Sequential( + torch.nn.ReLU(inplace=True), + Conv1d1x1(skip_channels, skip_channels, bias=True), + torch.nn.ReLU(inplace=True), + Conv1d1x1(skip_channels, out_channels, bias=True), + ) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + def forward( + self, + x: torch.Tensor, + x_mask: Optional[torch.Tensor] = None, + c: Optional[torch.Tensor] = None, + g: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Calculate forward propagation. 
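+
+        Skip outputs from all residual blocks are summed, optionally scaled by
+        sqrt(1 / num_layers) when scale_skip_connect=True, and passed through
+        the final output layers if use_last_conv=True.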
+ + Args: + x (Tensor): Input noise signal (B, 1, T) if use_first_conv else + (B, residual_channels, T). + x_mask (Optional[Tensor]): Mask tensor (B, 1, T). + c (Optional[Tensor]): Local conditioning features (B, aux_channels, T). + g (Optional[Tensor]): Global conditioning features (B, global_channels, 1). + + Returns: + Tensor: Output tensor (B, out_channels, T) if use_last_conv else + (B, residual_channels, T). + + """ + # encode to hidden representation + if self.use_first_conv: + x = self.first_conv(x) + + # residual block + skips = 0.0 + for f in self.conv_layers: + x, h = f(x, x_mask=x_mask, c=c, g=g) + skips = skips + h + x = skips + if self.scale_skip_connect: + x = x * math.sqrt(1.0 / len(self.conv_layers)) + + # apply final layers + if self.use_last_conv: + x = self.last_conv(x) + + return x + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m: torch.nn.Module): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + @staticmethod + def _get_receptive_field_size( + layers: int, + stacks: int, + kernel_size: int, + base_dilation: int, + ) -> int: + assert layers % stacks == 0 + layers_per_cycle = layers // stacks + dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)] + return (kernel_size - 1) * sum(dilations) + 1 + + @property + def receptive_field_size(self) -> int: + """Return receptive field size.""" + return self._get_receptive_field_size( + self.layers, self.stacks, self.kernel_size, self.base_dilation + ) + + +class Conv1d(torch.nn.Conv1d): + """Conv1d module with customized initialization.""" + + def __init__(self, *args, **kwargs): + """Initialize Conv1d module.""" + super().__init__(*args, **kwargs) + + def reset_parameters(self): + """Reset parameters.""" + torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu") + if self.bias is not None: + torch.nn.init.constant_(self.bias, 0.0) + + +class Conv1d1x1(Conv1d): + """1x1 Conv1d with customized initialization.""" + + def __init__(self, in_channels: int, out_channels: int, bias: bool): + """Initialize 1x1 Conv1d module.""" + super().__init__( + in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias + ) + + +class ResidualBlock(torch.nn.Module): + """Residual block module in WaveNet.""" + + def __init__( + self, + kernel_size: int = 3, + residual_channels: int = 64, + gate_channels: int = 128, + skip_channels: int = 64, + aux_channels: int = 80, + global_channels: int = -1, + dropout_rate: float = 0.0, + dilation: int = 1, + bias: bool = True, + scale_residual: bool = False, + ): + """Initialize ResidualBlock module. + + Args: + kernel_size (int): Kernel size of dilation convolution layer. + residual_channels (int): Number of channels for residual connection. + skip_channels (int): Number of channels for skip connection. + aux_channels (int): Number of local conditioning channels. + dropout (float): Dropout probability. + dilation (int): Dilation factor. 
+ bias (bool): Whether to add bias parameter in convolution layers. + scale_residual (bool): Whether to scale the residual outputs. + + """ + super().__init__() + self.dropout_rate = dropout_rate + self.residual_channels = residual_channels + self.skip_channels = skip_channels + self.scale_residual = scale_residual + + # check + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." + assert gate_channels % 2 == 0 + + # dilation conv + padding = (kernel_size - 1) // 2 * dilation + self.conv = Conv1d( + residual_channels, + gate_channels, + kernel_size, + padding=padding, + dilation=dilation, + bias=bias, + ) + + # local conditioning + if aux_channels > 0: + self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False) + else: + self.conv1x1_aux = None + + # global conditioning + if global_channels > 0: + self.conv1x1_glo = Conv1d1x1(global_channels, gate_channels, bias=False) + else: + self.conv1x1_glo = None + + # conv output is split into two groups + gate_out_channels = gate_channels // 2 + + # NOTE(kan-bayashi): concat two convs into a single conv for the efficiency + # (integrate res 1x1 + skip 1x1 convs) + self.conv1x1_out = Conv1d1x1( + gate_out_channels, residual_channels + skip_channels, bias=bias + ) + + def forward( + self, + x: torch.Tensor, + x_mask: Optional[torch.Tensor] = None, + c: Optional[torch.Tensor] = None, + g: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, residual_channels, T). + x_mask Optional[torch.Tensor]: Mask tensor (B, 1, T). + c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Output tensor for residual connection (B, residual_channels, T). + Tensor: Output tensor for skip connection (B, skip_channels, T). 
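+
+        Note:
+            The returned residual tensor already includes the input
+            (x + residual) and is scaled by math.sqrt(0.5) when
+            scale_residual=True.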
+ + """ + residual = x + x = F.dropout(x, p=self.dropout_rate, training=self.training) + x = self.conv(x) + + # split into two part for gated activation + splitdim = 1 + xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) + + # local conditioning + if c is not None: + c = self.conv1x1_aux(c) + ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) + xa, xb = xa + ca, xb + cb + + # global conditioning + if g is not None: + g = self.conv1x1_glo(g) + ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim) + xa, xb = xa + ga, xb + gb + + x = torch.tanh(xa) * torch.sigmoid(xb) + + # residual + skip 1x1 conv + x = self.conv1x1_out(x) + if x_mask is not None: + x = x * x_mask + + # split integrated conv results + x, s = x.split([self.residual_channels, self.skip_channels], dim=1) + + # for residual connection + x = x + residual + if self.scale_residual: + x = x * math.sqrt(0.5) + + return x, s diff --git a/pyproject.toml b/pyproject.toml index c40143fb9..435256416 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,4 +14,5 @@ exclude = ''' | icefall\/diagnostics\.py | icefall\/profiler\.py | egs\/librispeech\/ASR\/zipformer + | egs\/ljspeech\/TTS\/vits ''' From f08af2fa2217e226394b6f03442952d104bd984e Mon Sep 17 00:00:00 2001 From: LoganLiu66 <2319277867@qq.com> Date: Mon, 4 Dec 2023 22:29:42 +0800 Subject: [PATCH 04/46] fix initial states (#1398) Co-authored-by: liujiawang02 --- .../pruned_transducer_stateless7_streaming/decode_stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py index 0d7e86fcf..2c4b144fc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py @@ -82,12 +82,12 @@ class DecodeStream(object): self.pad_length = 7 if params.decoding_method == "greedy_search": - self.hyp = [params.blank_id] * params.context_size + self.hyp = [-1] * (params.context_size - 1) + [params.blank_id] elif params.decoding_method == "modified_beam_search": self.hyps = HypothesisList() self.hyps.add( Hypothesis( - ys=[params.blank_id] * params.context_size, + ys=[-1] * (params.context_size - 1) + [params.blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), ) ) From 735fb9a73dea7d27e95056add6598ae7a282d6f9 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 6 Dec 2023 09:59:19 +0800 Subject: [PATCH 05/46] A TTS recipe VITS on VCTK dataset (#1380) * init * isort formatted * minor updates * Create shared * Update prepare_tokens_vctk.py * Update prepare_tokens_vctk.py * Update prepare_tokens_vctk.py * Update prepare.sh * updated * Update train.py * Update train.py * Update tts_datamodule.py * Update train.py * Update train.py * Update train.py * Update train.py * Update train.py * Update train.py * fixed formatting issue * Update infer.py * removed redundant files * Create monotonic_align * removed redundant files * created symlinks * Update prepare.sh * minor adjustments * Create requirements_tts.txt * Update requirements_tts.txt added version constraints * Update infer.py * Update infer.py * Update infer.py * updated docs * Update export-onnx.py * Update export-onnx.py * Update test_onnx.py * updated requirements.txt * Update test_onnx.py * Update test_onnx.py * docs updated * docs fixed * minor updates --- docs/source/recipes/TTS/index.rst | 1 + docs/source/recipes/TTS/ljspeech/vits.rst | 12 +- 
docs/source/recipes/TTS/vctk/vits.rst | 125 +++ egs/ljspeech/TTS/prepare.sh | 16 +- egs/ljspeech/TTS/vits/duration_predictor.py | 1 - egs/ljspeech/TTS/vits/export-onnx.py | 8 +- egs/ljspeech/TTS/vits/flow.py | 1 - egs/ljspeech/TTS/vits/generator.py | 5 +- egs/ljspeech/TTS/vits/infer.py | 29 +- egs/ljspeech/TTS/vits/loss.py | 1 - egs/ljspeech/TTS/vits/posterior_encoder.py | 2 +- egs/ljspeech/TTS/vits/residual_coupling.py | 1 - egs/ljspeech/TTS/vits/test_onnx.py | 2 +- egs/ljspeech/TTS/vits/text_encoder.py | 48 +- egs/ljspeech/TTS/vits/tokenizer.py | 4 +- egs/ljspeech/TTS/vits/train.py | 93 +- egs/ljspeech/TTS/vits/tts_datamodule.py | 2 +- egs/ljspeech/TTS/vits/utils.py | 14 +- egs/ljspeech/TTS/vits/vits.py | 9 +- egs/ljspeech/TTS/vits/wavenet.py | 3 +- .../TTS/local/compute_spectrogram_vctk.py | 107 ++ .../TTS/local/display_manifest_statistics.py | 83 ++ egs/vctk/TTS/local/prepare_token_file.py | 104 ++ egs/vctk/TTS/local/prepare_tokens_vctk.py | 61 + egs/vctk/TTS/local/validate_manifest.py | 70 ++ egs/vctk/TTS/prepare.sh | 131 +++ egs/vctk/TTS/shared | 1 + egs/vctk/TTS/vits/duration_predictor.py | 1 + egs/vctk/TTS/vits/export-onnx.py | 284 +++++ egs/vctk/TTS/vits/flow.py | 1 + egs/vctk/TTS/vits/generator.py | 1 + egs/vctk/TTS/vits/hifigan.py | 1 + egs/vctk/TTS/vits/infer.py | 272 +++++ egs/vctk/TTS/vits/loss.py | 1 + egs/vctk/TTS/vits/monotonic_align | 1 + egs/vctk/TTS/vits/posterior_encoder.py | 1 + egs/vctk/TTS/vits/residual_coupling.py | 1 + egs/vctk/TTS/vits/test_onnx.py | 138 +++ egs/vctk/TTS/vits/text_encoder.py | 1 + egs/vctk/TTS/vits/tokenizer.py | 1 + egs/vctk/TTS/vits/train.py | 1000 +++++++++++++++++ egs/vctk/TTS/vits/transform.py | 1 + egs/vctk/TTS/vits/tts_datamodule.py | 338 ++++++ egs/vctk/TTS/vits/utils.py | 1 + egs/vctk/TTS/vits/vits.py | 1 + egs/vctk/TTS/vits/wavenet.py | 1 + requirements-tts.txt | 6 + requirements.txt | 2 + 48 files changed, 2904 insertions(+), 84 deletions(-) create mode 100644 docs/source/recipes/TTS/vctk/vits.rst create mode 100755 egs/vctk/TTS/local/compute_spectrogram_vctk.py create mode 100755 egs/vctk/TTS/local/display_manifest_statistics.py create mode 100755 egs/vctk/TTS/local/prepare_token_file.py create mode 100755 egs/vctk/TTS/local/prepare_tokens_vctk.py create mode 100755 egs/vctk/TTS/local/validate_manifest.py create mode 100755 egs/vctk/TTS/prepare.sh create mode 120000 egs/vctk/TTS/shared create mode 120000 egs/vctk/TTS/vits/duration_predictor.py create mode 100755 egs/vctk/TTS/vits/export-onnx.py create mode 120000 egs/vctk/TTS/vits/flow.py create mode 120000 egs/vctk/TTS/vits/generator.py create mode 120000 egs/vctk/TTS/vits/hifigan.py create mode 100755 egs/vctk/TTS/vits/infer.py create mode 120000 egs/vctk/TTS/vits/loss.py create mode 120000 egs/vctk/TTS/vits/monotonic_align create mode 120000 egs/vctk/TTS/vits/posterior_encoder.py create mode 120000 egs/vctk/TTS/vits/residual_coupling.py create mode 100755 egs/vctk/TTS/vits/test_onnx.py create mode 120000 egs/vctk/TTS/vits/text_encoder.py create mode 120000 egs/vctk/TTS/vits/tokenizer.py create mode 100755 egs/vctk/TTS/vits/train.py create mode 120000 egs/vctk/TTS/vits/transform.py create mode 100644 egs/vctk/TTS/vits/tts_datamodule.py create mode 120000 egs/vctk/TTS/vits/utils.py create mode 120000 egs/vctk/TTS/vits/vits.py create mode 120000 egs/vctk/TTS/vits/wavenet.py create mode 100644 requirements-tts.txt diff --git a/docs/source/recipes/TTS/index.rst b/docs/source/recipes/TTS/index.rst index aa891c072..80d67a2f3 100644 --- a/docs/source/recipes/TTS/index.rst +++ 
b/docs/source/recipes/TTS/index.rst @@ -5,3 +5,4 @@ TTS :maxdepth: 2 ljspeech/vits + vctk/vits \ No newline at end of file diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst index 385fd3c70..d08aa0f47 100644 --- a/docs/source/recipes/TTS/ljspeech/vits.rst +++ b/docs/source/recipes/TTS/ljspeech/vits.rst @@ -4,6 +4,10 @@ VITS This tutorial shows you how to train an VITS model with the `LJSpeech `_ dataset. +.. note:: + + TTS related recipes require packages in ``requirements-tts.txt``. + .. note:: The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_ @@ -27,6 +31,12 @@ To run stage 1 to stage 5, use Build Monotonic Alignment Search -------------------------------- +.. code-block:: bash + + $ ./prepare.sh --stage -1 --stop_stage -1 + +or + .. code-block:: bash $ cd vits/monotonic_align @@ -74,7 +84,7 @@ training part first. It will save the ground-truth and generated wavs to the dir $ ./vits/infer.py \ --epoch 1000 \ --exp-dir vits/exp \ - --tokens data/tokens.txt + --tokens data/tokens.txt \ --max-duration 500 .. note:: diff --git a/docs/source/recipes/TTS/vctk/vits.rst b/docs/source/recipes/TTS/vctk/vits.rst new file mode 100644 index 000000000..34024a5ea --- /dev/null +++ b/docs/source/recipes/TTS/vctk/vits.rst @@ -0,0 +1,125 @@ +VITS +=============== + +This tutorial shows you how to train an VITS model +with the `VCTK `_ dataset. + +.. note:: + + TTS related recipes require packages in ``requirements-tts.txt``. + +.. note:: + + The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_ + + +Data preparation +---------------- + +.. code-block:: bash + + $ cd egs/vctk/TTS + $ ./prepare.sh + +To run stage 1 to stage 6, use + +.. code-block:: bash + + $ ./prepare.sh --stage 1 --stop_stage 6 + + +Build Monotonic Alignment Search +-------------------------------- + +To build the monotonic alignment search, use the following commands: + +.. code-block:: bash + + $ ./prepare.sh --stage -1 --stop_stage -1 + +or + +.. code-block:: bash + + $ cd vits/monotonic_align + $ python setup.py build_ext --inplace + $ cd ../../ + + +Training +-------- + +.. code-block:: bash + + $ export CUDA_VISIBLE_DEVICES="0,1,2,3" + $ ./vits/train.py \ + --world-size 4 \ + --num-epochs 1000 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + --max-duration 350 + +.. note:: + + You can adjust the hyper-parameters to control the size of the VITS model and + the training configurations. For more details, please run ``./vits/train.py --help``. + +.. note:: + + The training can take a long time (usually a couple of days). + +Training logs, checkpoints and tensorboard logs are saved in ``vits/exp``. + + +Inference +--------- + +The inference part uses checkpoints saved by the training part, so you have to run the +training part first. It will save the ground-truth and generated wavs to the directory +``vits/exp/infer/epoch-*/wav``, e.g., ``vits/exp/infer/epoch-1000/wav``. + +.. code-block:: bash + + $ export CUDA_VISIBLE_DEVICES="0" + $ ./vits/infer.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt \ + --max-duration 500 + +.. note:: + + For more details, please run ``./vits/infer.py --help``. + + +Export models +------------- + +Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``: +``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``. + +.. 
code-block:: bash + + $ ./vits/export-onnx.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + +You can test the exported ONNX model with: + +.. code-block:: bash + + $ ./vits/test_onnx.py \ + --model-filename vits/exp/vits-epoch-1000.onnx \ + --tokens data/tokens.txt + + +Download pretrained models +-------------------------- + +If you don't want to train from scratch, you can download the pretrained models +by visiting the following link: + + - ``_ diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index 8ee40896e..ed0a07f5e 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -5,8 +5,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail -nj=1 -stage=-1 +stage=0 stop_stage=100 dl_dir=$PWD/download @@ -25,6 +24,17 @@ log() { log "dl_dir: $dl_dir" +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: build monotonic_align lib" + if [ ! -d vits/monotonic_align/build ]; then + cd vits/monotonic_align + python setup.py build_ext --inplace + cd ../../ + else + log "monotonic_align lib already built" + fi +fi + if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Download data" @@ -113,5 +123,3 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then --tokens data/tokens.txt fi fi - - diff --git a/egs/ljspeech/TTS/vits/duration_predictor.py b/egs/ljspeech/TTS/vits/duration_predictor.py index c29a28479..1a8190014 100644 --- a/egs/ljspeech/TTS/vits/duration_predictor.py +++ b/egs/ljspeech/TTS/vits/duration_predictor.py @@ -14,7 +14,6 @@ from typing import Optional import torch import torch.nn.functional as F - from flow import ( ConvFlow, DilatedDepthSeparableConv, diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index 154de4bf4..2068adeea 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -180,7 +180,13 @@ def export_model_onnx( model_filename, verbose=False, opset_version=opset_version, - input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"], + input_names=[ + "tokens", + "tokens_lens", + "noise_scale", + "noise_scale_dur", + "alpha", + ], output_names=["audio"], dynamic_axes={ "tokens": {0: "N", 1: "T"}, diff --git a/egs/ljspeech/TTS/vits/flow.py b/egs/ljspeech/TTS/vits/flow.py index 206bd5e3e..2b84f6434 100644 --- a/egs/ljspeech/TTS/vits/flow.py +++ b/egs/ljspeech/TTS/vits/flow.py @@ -13,7 +13,6 @@ import math from typing import Optional, Tuple, Union import torch - from transform import piecewise_rational_quadratic_transform diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py index efb0e254c..66c8cedb1 100644 --- a/egs/ljspeech/TTS/vits/generator.py +++ b/egs/ljspeech/TTS/vits/generator.py @@ -16,9 +16,6 @@ from typing import List, Optional, Tuple import numpy as np import torch import torch.nn.functional as F - -from icefall.utils import make_pad_mask - from duration_predictor import StochasticDurationPredictor from hifigan import HiFiGANGenerator from posterior_encoder import PosteriorEncoder @@ -26,6 +23,8 @@ from residual_coupling import ResidualAffineCouplingBlock from text_encoder import TextEncoder from utils import get_random_segments +from icefall.utils import make_pad_mask + class VITSGenerator(torch.nn.Module): """Generator module in VITS, `Conditional Variational Autoencoder diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py index 91a35e360..cf0d20ae2 100755 --- a/egs/ljspeech/TTS/vits/infer.py +++ 
b/egs/ljspeech/TTS/vits/infer.py @@ -36,13 +36,12 @@ import k2 import torch import torch.nn as nn import torchaudio - -from train import get_model, get_params from tokenizer import Tokenizer +from train import get_model, get_params +from tts_datamodule import LJSpeechTtsDataModule from icefall.checkpoint import load_checkpoint from icefall.utils import AttributeDict, setup_logger -from tts_datamodule import LJSpeechTtsDataModule def get_parser(): @@ -107,12 +106,12 @@ def infer_dataset( for i in range(batch_size): torchaudio.save( str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"), - audio[i:i + 1, :audio_lens[i]], + audio[i : i + 1, : audio_lens[i]], sample_rate=params.sampling_rate, ) torchaudio.save( str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"), - audio_pred[i:i + 1, :audio_lens_pred[i]], + audio_pred[i : i + 1, : audio_lens_pred[i]], sample_rate=params.sampling_rate, ) @@ -144,14 +143,24 @@ def infer_dataset( audio_lens = batch["audio_lens"].tolist() cut_ids = [cut.id for cut in batch["cut"]] - audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens) + audio_pred, _, durations = model.inference_batch( + text=tokens, text_lengths=tokens_lens + ) audio_pred = audio_pred.detach().cpu() # convert to samples - audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() + audio_lens_pred = ( + (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() + ) futures.append( executor.submit( - _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred + _save_worker, + batch_size, + cut_ids, + audio, + audio_pred, + audio_lens, + audio_lens_pred, ) ) @@ -160,7 +169,9 @@ def infer_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) # return results for f in futures: f.result() diff --git a/egs/ljspeech/TTS/vits/loss.py b/egs/ljspeech/TTS/vits/loss.py index 21aaad6e7..2f4dc9bc0 100644 --- a/egs/ljspeech/TTS/vits/loss.py +++ b/egs/ljspeech/TTS/vits/loss.py @@ -14,7 +14,6 @@ from typing import List, Tuple, Union import torch import torch.distributions as D import torch.nn.functional as F - from lhotse.features.kaldi import Wav2LogFilterBank diff --git a/egs/ljspeech/TTS/vits/posterior_encoder.py b/egs/ljspeech/TTS/vits/posterior_encoder.py index 6b8a5be52..1104fb864 100644 --- a/egs/ljspeech/TTS/vits/posterior_encoder.py +++ b/egs/ljspeech/TTS/vits/posterior_encoder.py @@ -12,9 +12,9 @@ This code is based on https://github.com/jaywalnut310/vits. from typing import Optional, Tuple import torch +from wavenet import Conv1d, WaveNet from icefall.utils import make_pad_mask -from wavenet import WaveNet, Conv1d class PosteriorEncoder(torch.nn.Module): diff --git a/egs/ljspeech/TTS/vits/residual_coupling.py b/egs/ljspeech/TTS/vits/residual_coupling.py index 2d6807cb7..f9a2a3786 100644 --- a/egs/ljspeech/TTS/vits/residual_coupling.py +++ b/egs/ljspeech/TTS/vits/residual_coupling.py @@ -12,7 +12,6 @@ This code is based on https://github.com/jaywalnut310/vits. 
from typing import Optional, Tuple, Union import torch - from flow import FlipFlow from wavenet import WaveNet diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py index 8acca7c02..686fee2a0 100755 --- a/egs/ljspeech/TTS/vits/test_onnx.py +++ b/egs/ljspeech/TTS/vits/test_onnx.py @@ -28,10 +28,10 @@ Use the onnx model to generate a wav: import argparse import logging + import onnxruntime as ort import torch import torchaudio - from tokenizer import Tokenizer diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py index 9f337e45b..fcbae7103 100644 --- a/egs/ljspeech/TTS/vits/text_encoder.py +++ b/egs/ljspeech/TTS/vits/text_encoder.py @@ -169,9 +169,7 @@ class Transformer(nn.Module): x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - x = self.encoder( - x, pos_emb, key_padding_mask=key_padding_mask - ) # (T, N, C) + x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C) x = self.after_norm(x) @@ -207,7 +205,9 @@ class TransformerEncoderLayer(nn.Module): nn.Linear(dim_feedforward, d_model), ) - self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout) + self.self_attn = RelPositionMultiheadAttention( + d_model, num_heads, dropout=dropout + ) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) @@ -242,7 +242,9 @@ class TransformerEncoderLayer(nn.Module): key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len) """ # macaron style feed-forward module - src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src))) + src = src + self.ff_scale * self.dropout( + self.feed_forward_macaron(self.norm_ff_macaron(src)) + ) # multi-head self-attention module src_attn = self.self_attn( @@ -490,11 +492,17 @@ class RelPositionMultiheadAttention(nn.Module): q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim) k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim) - v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + v = ( + v.contiguous() + .view(seq_len, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + ) q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim) - p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim) + p = self.linear_pos(pos_emb).view( + pos_emb.size(0), -1, self.num_heads, self.head_dim + ) # (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1) p = p.permute(0, 2, 3, 1) @@ -506,15 +514,23 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch_size, num_head, seq_len, seq_len) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch_size, num_head, seq_len, seq_len) # compute matrix b and matrix d - matrix_bd = torch.matmul(q_with_bias_v, p) # (batch_size, num_head, seq_len, 2*seq_len-1) - matrix_bd = self.rel_shift(matrix_bd) # (batch_size, num_head, seq_len, seq_len) + matrix_bd = torch.matmul( + q_with_bias_v, p + ) # (batch_size, num_head, seq_len, 2*seq_len-1) + matrix_bd = self.rel_shift( + matrix_bd + ) # (batch_size, num_head, seq_len, seq_len) # (batch_size, num_head, seq_len, seq_len) attn_output_weights = (matrix_ac + matrix_bd) * scaling - 
attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len) + attn_output_weights = attn_output_weights.view( + batch_size * self.num_heads, seq_len, seq_len + ) if key_padding_mask is not None: assert key_padding_mask.shape == (batch_size, seq_len) @@ -536,10 +552,16 @@ class RelPositionMultiheadAttention(nn.Module): # (batch_size * num_head, seq_len, head_dim) attn_output = torch.bmm(attn_output_weights, v) - assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim) + assert attn_output.shape == ( + batch_size * self.num_heads, + seq_len, + self.head_dim, + ) attn_output = ( - attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, batch_size, self.embed_dim) ) # (seq_len, batch_size, embed_dim) attn_output = self.out_proj(attn_output) diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py index 0678b26fe..70f1240b4 100644 --- a/egs/ljspeech/TTS/vits/tokenizer.py +++ b/egs/ljspeech/TTS/vits/tokenizer.py @@ -78,7 +78,9 @@ class Tokenizer(object): return token_ids_list - def tokens_to_token_ids(self, tokens_list: List[str], intersperse_blank: bool = True): + def tokens_to_token_ids( + self, tokens_list: List[str], intersperse_blank: bool = True + ): """ Args: tokens_list: diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py index eb43a4cc9..71c4224fa 100755 --- a/egs/ljspeech/TTS/vits/train.py +++ b/egs/ljspeech/TTS/vits/train.py @@ -18,21 +18,25 @@ import argparse import logging -import numpy as np from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union import k2 +import numpy as np import torch import torch.multiprocessing as mp import torch.nn as nn from lhotse.cut import Cut from lhotse.utils import fix_random_seed -from torch.optim import Optimizer +from tokenizer import Tokenizer from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import LJSpeechTtsDataModule +from utils import MetricsTracker, plot_feature, save_checkpoint +from vits import VITS from icefall import diagnostics from icefall.checkpoint import load_checkpoint @@ -41,11 +45,6 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, setup_logger, str2bool -from tokenizer import Tokenizer -from tts_datamodule import LJSpeechTtsDataModule -from utils import MetricsTracker, plot_feature, save_checkpoint -from vits import VITS - LRSchedulerType = torch.optim.lr_scheduler._LRScheduler @@ -385,11 +384,12 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["tokens"]) - audio, audio_lens, features, features_lens, tokens, tokens_lens = \ - prepare_input(batch, tokenizer, device) + audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( + batch, tokenizer, device + ) loss_info = MetricsTracker() - loss_info['samples'] = batch_size + loss_info["samples"] = batch_size try: with autocast(enabled=params.use_fp16): @@ -446,7 +446,9 @@ def train_one_epoch( # behavior depending on the current grad scale. 
cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0): + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: if not saved_bad_model: @@ -482,9 +484,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if params.use_fp16: tb_writer.add_scalar( "train/grad_scale", cur_grad_scale, params.batch_idx_train @@ -492,19 +492,34 @@ def train_one_epoch( if "returned_sample" in stats_g: speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] tb_writer.add_audio( - "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate + "train/speech_hat_", + speech_hat_, + params.batch_idx_train, + params.sampling_rate, ) tb_writer.add_audio( - "train/speech_", speech_, params.batch_idx_train, params.sampling_rate + "train/speech_", + speech_, + params.batch_idx_train, + params.sampling_rate, ) tb_writer.add_image( - "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC' + "train/mel_hat_", + plot_feature(mel_hat_), + params.batch_idx_train, + dataformats="HWC", ) tb_writer.add_image( - "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC' + "train/mel_", + plot_feature(mel_), + params.batch_idx_train, + dataformats="HWC", ) - if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics: + if ( + params.batch_idx_train % params.valid_interval == 0 + and not params.print_diagnostics + ): logging.info("Computing validation loss") valid_info, (speech_hat, speech) = compute_validation_loss( params=params, @@ -523,10 +538,16 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) tb_writer.add_audio( - "train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate + "train/valdi_speech_hat", + speech_hat, + params.batch_idx_train, + params.sampling_rate, ) tb_writer.add_audio( - "train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate + "train/valdi_speech", + speech, + params.batch_idx_train, + params.sampling_rate, ) loss_value = tot_loss["generator_loss"] / tot_loss["samples"] @@ -555,11 +576,17 @@ def compute_validation_loss( with torch.no_grad(): for batch_idx, batch in enumerate(valid_dl): batch_size = len(batch["tokens"]) - audio, audio_lens, features, features_lens, tokens, tokens_lens = \ - prepare_input(batch, tokenizer, device) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + ) = prepare_input(batch, tokenizer, device) loss_info = MetricsTracker() - loss_info['samples'] = batch_size + loss_info["samples"] = batch_size # forward discriminator loss_d, stats_d = model( @@ -596,12 +623,17 @@ def compute_validation_loss( if batch_idx == 0 and rank == 0: inner_model = model.module if isinstance(model, DDP) else model audio_pred, _, duration = inner_model.inference( - text=tokens[0, :tokens_lens[0].item()] + text=tokens[0, : tokens_lens[0].item()] ) audio_pred = audio_pred.data.cpu().numpy() - audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() - assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred)) - audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy() + audio_len_pred = ( + 
(duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() + ) + assert audio_len_pred == len(audio_pred), ( + audio_len_pred, + len(audio_pred), + ) + audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy() returned_sample = (audio_pred, audio_gt) if world_size > 1: @@ -632,8 +664,9 @@ def scan_pessimistic_batches_for_oom( batches, crit_values = find_pessimistic_batches(train_dl.sampler) for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] - audio, audio_lens, features, features_lens, tokens, tokens_lens = \ - prepare_input(batch, tokenizer, device) + audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( + batch, tokenizer, device + ) try: # for discriminator with autocast(enabled=params.use_fp16): diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py index 0fcbb92c1..81bb9ed13 100644 --- a/egs/ljspeech/TTS/vits/tts_datamodule.py +++ b/egs/ljspeech/TTS/vits/tts_datamodule.py @@ -29,10 +29,10 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures CutConcatenate, CutMix, DynamicBucketingSampler, - SpeechSynthesisDataset, PrecomputedFeatures, SimpleCutSampler, SpecAugment, + SpeechSynthesisDataset, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, diff --git a/egs/ljspeech/TTS/vits/utils.py b/egs/ljspeech/TTS/vits/utils.py index 2a3dae900..6a067f596 100644 --- a/egs/ljspeech/TTS/vits/utils.py +++ b/egs/ljspeech/TTS/vits/utils.py @@ -14,15 +14,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Union import collections import logging +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union import torch -import torch.nn as nn import torch.distributed as dist +import torch.nn as nn from lhotse.dataset.sampling.base import CutSampler -from pathlib import Path from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer @@ -97,23 +97,23 @@ def plot_feature(spectrogram): global MATPLOTLIB_FLAG if not MATPLOTLIB_FLAG: import matplotlib + matplotlib.use("Agg") MATPLOTLIB_FLAG = True - mpl_logger = logging.getLogger('matplotlib') + mpl_logger = logging.getLogger("matplotlib") mpl_logger.setLevel(logging.WARNING) import matplotlib.pylab as plt import numpy as np fig, ax = plt.subplots(figsize=(10, 2)) - im = ax.imshow(spectrogram, aspect="auto", origin="lower", - interpolation='none') + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") plt.colorbar(im, ax=ax) plt.xlabel("Frames") plt.ylabel("Channels") plt.tight_layout() fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) plt.close() return data diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py index d5e20a578..b4f0c21e6 100644 --- a/egs/ljspeech/TTS/vits/vits.py +++ b/egs/ljspeech/TTS/vits/vits.py @@ -9,8 +9,7 @@ from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn -from torch.cuda.amp import autocast - +from generator import VITSGenerator from hifigan import ( HiFiGANMultiPeriodDiscriminator, HiFiGANMultiScaleDiscriminator, @@ -25,9 +24,8 @@ from loss import 
( KLDivergenceLoss, MelSpectrogramLoss, ) +from torch.cuda.amp import autocast from utils import get_segments -from generator import VITSGenerator - AVAILABLE_GENERATERS = { "vits_generator": VITSGenerator, @@ -42,8 +40,7 @@ AVAILABLE_DISCRIMINATORS = { class VITS(nn.Module): - """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech` - """ + """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`""" def __init__( self, diff --git a/egs/ljspeech/TTS/vits/wavenet.py b/egs/ljspeech/TTS/vits/wavenet.py index fbe1be52b..5db461d5c 100644 --- a/egs/ljspeech/TTS/vits/wavenet.py +++ b/egs/ljspeech/TTS/vits/wavenet.py @@ -9,9 +9,8 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. """ -import math import logging - +import math from typing import Optional, Tuple import torch diff --git a/egs/vctk/TTS/local/compute_spectrogram_vctk.py b/egs/vctk/TTS/local/compute_spectrogram_vctk.py new file mode 100755 index 000000000..440ac1245 --- /dev/null +++ b/egs/vctk/TTS/local/compute_spectrogram_vctk.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the VCTK dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/spectrogram. +""" + +import logging +import os +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + LilcomChunkyWriter, + Spectrogram, + SpectrogramConfig, + load_manifest, +) +from lhotse.audio import RecordingSet +from lhotse.supervision import SupervisionSet + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_spectrogram_vctk(): + src_dir = Path("data/manifests") + output_dir = Path("data/spectrogram") + num_jobs = min(32, os.cpu_count()) + + sampling_rate = 22050 + frame_length = 1024 / sampling_rate # (in second) + frame_shift = 256 / sampling_rate # (in second) + use_fft_mag = True + + prefix = "vctk" + suffix = "jsonl.gz" + partition = "all" + + recordings = load_manifest( + src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet + ).resample(sampling_rate=sampling_rate) + supervisions = load_manifest( + src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet + ) + + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=frame_length, + frame_shift=frame_shift, + use_fft_mag=use_fft_mag, + ) + extractor = Spectrogram(config) + + with get_executor() as ex: # Initialize the executor only once. + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + return + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=recordings, supervisions=supervisions + ) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_spectrogram_vctk() diff --git a/egs/vctk/TTS/local/display_manifest_statistics.py b/egs/vctk/TTS/local/display_manifest_statistics.py new file mode 100755 index 000000000..0472e2cea --- /dev/null +++ b/egs/vctk/TTS/local/display_manifest_statistics.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +See the function `remove_short_and_long_utt()` in vits/train.py +for usage. 
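+
+The script takes no arguments:
+
+    ./local/display_manifest_statistics.py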
+""" + + +from lhotse import load_manifest_lazy + + +def main(): + path = "./data/spectrogram/vctk_cuts_all.jsonl.gz" + cuts = load_manifest_lazy(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 43873 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 41:02:18 │ +├───────────────────────────┼──────────┤ +│ mean │ 3.4 │ +├───────────────────────────┼──────────┤ +│ std │ 1.2 │ +├───────────────────────────┼──────────┤ +│ min │ 1.2 │ +├───────────────────────────┼──────────┤ +│ 25% │ 2.6 │ +├───────────────────────────┼──────────┤ +│ 50% │ 3.1 │ +├───────────────────────────┼──────────┤ +│ 75% │ 3.8 │ +├───────────────────────────┼──────────┤ +│ 99% │ 8.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 9.1 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 12.1 │ +├───────────────────────────┼──────────┤ +│ max │ 16.6 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 43873 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 43873 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 43873 │ +╘═══════════════════════════╧══════════╛ +SUPERVISION custom fields: +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 41:02:18 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 41:02:18 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:01 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +""" diff --git a/egs/vctk/TTS/local/prepare_token_file.py b/egs/vctk/TTS/local/prepare_token_file.py new file mode 100755 index 000000000..c6636c3ad --- /dev/null +++ b/egs/vctk/TTS/local/prepare_token_file.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and generates the file that maps tokens to IDs. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict + +from lhotse import load_manifest + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--manifest-file", + type=Path, + default=Path("data/spectrogram/vctk_cuts_all.jsonl.gz"), + help="Path to the manifest file", + ) + + parser.add_argument( + "--tokens", + type=Path, + default=Path("data/tokens.txt"), + help="Path to the tokens", + ) + + return parser.parse_args() + + +def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: + """Write a symbol to ID mapping to a file. + + Note: + No need to implement `read_mapping` as it can be done + through :func:`k2.SymbolTable.from_file`. 
+ + Args: + filename: + Filename to save the mapping. + sym2id: + A dict mapping symbols to IDs. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf-8") as f: + for sym, i in sym2id.items(): + f.write(f"{sym} {i}\n") + + +def get_token2id(manifest_file: Path) -> Dict[str, int]: + """Return a dict that maps token to IDs.""" + extra_tokens = [ + "", # 0 for blank + "", # 1 for sos and eos symbols. + "", # 2 for OOV + ] + all_tokens = set() + + cut_set = load_manifest(manifest_file) + + for cut in cut_set: + # Each cut only contain one supervision + assert len(cut.supervisions) == 1, len(cut.supervisions) + for t in cut.tokens: + all_tokens.add(t) + + all_tokens = extra_tokens + list(all_tokens) + + token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)} + return token2id + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + manifest_file = Path(args.manifest_file) + out_file = Path(args.tokens) + + token2id = get_token2id(manifest_file) + write_mapping(out_file, token2id) diff --git a/egs/vctk/TTS/local/prepare_tokens_vctk.py b/egs/vctk/TTS/local/prepare_tokens_vctk.py new file mode 100755 index 000000000..32e1c7dfa --- /dev/null +++ b/egs/vctk/TTS/local/prepare_tokens_vctk.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and save the new cuts with phoneme tokens. 
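+Phonemes are obtained with g2p_en after Tacotron-style text normalization
+(tacotron_cleaner.cleaners.custom_english_cleaners).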
+""" + +import logging +from pathlib import Path + +import g2p_en +import tacotron_cleaner.cleaners +from lhotse import CutSet, load_manifest +from tqdm.auto import tqdm + + +def prepare_tokens_vctk(): + output_dir = Path("data/spectrogram") + prefix = "vctk" + suffix = "jsonl.gz" + partition = "all" + + cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + g2p = g2p_en.G2p() + + new_cuts = [] + for cut in tqdm(cut_set): + # Each cut only contains one supervision + assert len(cut.supervisions) == 1, len(cut.supervisions) + text = cut.supervisions[0].text + # Text normalization + text = tacotron_cleaner.cleaners.custom_english_cleaners(text) + # Convert to phonemes + cut.tokens = g2p(text) + new_cuts.append(cut) + + new_cut_set = CutSet.from_cuts(new_cuts) + new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + prepare_tokens_vctk() diff --git a/egs/vctk/TTS/local/validate_manifest.py b/egs/vctk/TTS/local/validate_manifest.py new file mode 100755 index 000000000..cd466303e --- /dev/null +++ b/egs/vctk/TTS/local/validate_manifest.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut + +We will add more checks later if needed. + +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/spectrogram/ljspeech_cuts_all.jsonl.gz + +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset.speech_synthesis import validate_for_tts + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + manifest = args.manifest + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest_lazy(manifest) + assert isinstance(cut_set, CutSet) + + validate_for_tts(cut_set) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/vctk/TTS/prepare.sh b/egs/vctk/TTS/prepare.sh new file mode 100755 index 000000000..87150ad31 --- /dev/null +++ b/egs/vctk/TTS/prepare.sh @@ -0,0 +1,131 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=0 +stop_stage=100 + +dl_dir=$PWD/download + +. 
shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: build monotonic_align lib" + if [ ! -d vits/monotonic_align/build ]; then + cd vits/monotonic_align + python setup.py build_ext --inplace + cd ../../ + else + log "monotonic_align lib already built" + fi +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/VCTK, + # you can create a symlink + # + # ln -sfv /path/to/VCTK $dl_dir/VCTK + # + if [ ! -d $dl_dir/VCTK ]; then + lhotse download vctk $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare VCTK manifest" + # We assume that you have downloaded the VCTK corpus + # to $dl_dir/VCTK + mkdir -p data/manifests + if [ ! -e data/manifests/.vctk.done ]; then + lhotse prepare vctk --use-edinburgh-vctk-url true $dl_dir/VCTK data/manifests + touch data/manifests/.vctk.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compute spectrogram for VCTK" + mkdir -p data/spectrogram + if [ ! -e data/spectrogram/.vctk.done ]; then + ./local/compute_spectrogram_vctk.py + touch data/spectrogram/.vctk.done + fi + + if [ ! -e data/spectrogram/.vctk-validated.done ]; then + log "Validating data/fbank for VCTK" + ./local/validate_manifest.py \ + data/spectrogram/vctk_cuts_all.jsonl.gz + touch data/spectrogram/.vctk-validated.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Prepare phoneme tokens for VCTK" + if [ ! -e data/spectrogram/.vctk_with_token.done ]; then + ./local/prepare_tokens_vctk.py + mv data/spectrogram/vctk_cuts_with_tokens_all.jsonl.gz \ + data/spectrogram/vctk_cuts_all.jsonl.gz + touch data/spectrogram/.vctk_with_token.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Split the VCTK cuts into train, valid and test sets" + if [ ! -e data/spectrogram/.vctk_split.done ]; then + lhotse subset --last 600 \ + data/spectrogram/vctk_cuts_all.jsonl.gz \ + data/spectrogram/vctk_cuts_validtest.jsonl.gz + lhotse subset --first 100 \ + data/spectrogram/vctk_cuts_validtest.jsonl.gz \ + data/spectrogram/vctk_cuts_valid.jsonl.gz + lhotse subset --last 500 \ + data/spectrogram/vctk_cuts_validtest.jsonl.gz \ + data/spectrogram/vctk_cuts_test.jsonl.gz + + rm data/spectrogram/vctk_cuts_validtest.jsonl.gz + + n=$(( $(gunzip -c data/spectrogram/vctk_cuts_all.jsonl.gz | wc -l) - 600 )) + lhotse subset --first $n \ + data/spectrogram/vctk_cuts_all.jsonl.gz \ + data/spectrogram/vctk_cuts_train.jsonl.gz + touch data/spectrogram/.vctk_split.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Generate token file" + # We assume you have installed g2p_en and espnet_tts_frontend. + # If not, please install them with: + # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p + # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ + if [ ! 
-e data/tokens.txt ]; then + ./local/prepare_token_file.py \ + --manifest-file data/spectrogram/vctk_cuts_train.jsonl.gz \ + --tokens data/tokens.txt + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Generate speakers file" + if [ ! -e data/speakers.txt ]; then + gunzip -c data/manifests/vctk_supervisions_all.jsonl.gz \ + | jq '.speaker' | sed 's/"//g' \ + | sort | uniq > data/speakers.txt + fi +fi diff --git a/egs/vctk/TTS/shared b/egs/vctk/TTS/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/vctk/TTS/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/vctk/TTS/vits/duration_predictor.py b/egs/vctk/TTS/vits/duration_predictor.py new file mode 120000 index 000000000..9972b476f --- /dev/null +++ b/egs/vctk/TTS/vits/duration_predictor.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/duration_predictor.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py new file mode 100755 index 000000000..7c9664cc1 --- /dev/null +++ b/egs/vctk/TTS/vits/export-onnx.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script exports a VITS model from PyTorch to ONNX. + +Export the model to ONNX: +./vits/export-onnx.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + +It will generate two files inside vits/exp: + - vits-epoch-1000.onnx + - vits-epoch-1000.int8.onnx (quantizated model) + +See ./test_onnx.py for how to use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import torch +import torch.nn as nn +from onnxruntime.quantization import QuantType, quantize_dynamic +from tokenizer import Tokenizer +from train import get_model, get_params + +from icefall.checkpoint import load_checkpoint + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--speakers", + type=Path, + default=Path("data/speakers.txt"), + help="Path to speakers.txt file.", + ) + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. 
+ """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxModel(nn.Module): + """A wrapper for VITS generator.""" + + def __init__(self, model: nn.Module): + """ + Args: + model: + A VITS generator. + """ + super().__init__() + self.model = model + + def forward( + self, + tokens: torch.Tensor, + tokens_lens: torch.Tensor, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + speaker: int = 20, + alpha: float = 1.0, + ) -> torch.Tensor: + """Please see the help information of the wrapped generator's inference method. + + Args: + tokens: + Input text token indexes (1, T_text) + tokens_lens: + Number of tokens of shape (1,) + noise_scale (float): + Noise scale parameter for flow. + noise_scale_dur (float): + Noise scale parameter for duration predictor. + speaker (int): + Speaker ID. + alpha (float): + Alpha parameter to control the speed of generated speech. + + Returns: + Return the generated waveform tensor of shape (B, T_wav). + """ + audio, _, _ = self.model.inference( + text=tokens, + text_lengths=tokens_lens, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + sids=speaker, + alpha=alpha, + ) + return audio + + +def export_model_onnx( + model: nn.Module, + model_filename: str, + opset_version: int = 11, +) -> None: + """Export the given generator model to ONNX format. + The exported model has the following inputs: + + - tokens, a tensor of shape (1, T_text); dtype is torch.int64 + - tokens_lens, a tensor of shape (1,); dtype is torch.int64 + - noise_scale, noise_scale_dur and alpha, float32 scalars + - speaker, a tensor of shape (1,); dtype is torch.int64 + + and it has one output: + + - audio, a tensor of shape (1, T'); dtype is torch.float32 + + Args: + model: + The VITS generator. + model_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
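+
+    Note: the tensors constructed in the body below are only dummy example
+    inputs used to trace the model; their batch and token-length dimensions
+    are declared dynamic via ``dynamic_axes``, so the exported model accepts
+    inputs of arbitrary length.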
+ """ + tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) + noise_scale = torch.tensor([1], dtype=torch.float32) + noise_scale_dur = torch.tensor([1], dtype=torch.float32) + alpha = torch.tensor([1], dtype=torch.float32) + speaker = torch.tensor([1], dtype=torch.int64) + + torch.onnx.export( + model, + (tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha), + model_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "tokens", + "tokens_lens", + "noise_scale", + "noise_scale_dur", + "speaker", + "alpha", + ], + output_names=["audio"], + dynamic_axes={ + "tokens": {0: "N", 1: "T"}, + "tokens_lens": {0: "N"}, + "audio": {0: "N", 1: "T"}, + "speaker": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "VITS", + "version": "1", + "model_author": "k2-fsa", + "comment": "VITS generator", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=model_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + with open(args.speakers) as f: + speaker_map = {line.strip(): i for i, line in enumerate(f)} + params.num_spks = len(speaker_map) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model = model.generator + model.to("cpu") + model.eval() + + model = OnnxModel(model=model) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"generator parameters: {num_param}") + + suffix = f"epoch-{params.epoch}" + + opset_version = 13 + + logging.info("Exporting encoder") + model_filename = params.exp_dir / f"vits-{suffix}.onnx" + export_model_onnx( + model, + model_filename, + opset_version=opset_version, + ) + logging.info(f"Exported generator to {model_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx" + quantize_dynamic( + model_input=model_filename, + model_output=model_filename_int8, + weight_type=QuantType.QUInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/vctk/TTS/vits/flow.py b/egs/vctk/TTS/vits/flow.py new file mode 120000 index 000000000..e65d91ea7 --- /dev/null +++ b/egs/vctk/TTS/vits/flow.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/flow.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/generator.py b/egs/vctk/TTS/vits/generator.py new file mode 120000 index 000000000..611679bfa --- /dev/null +++ b/egs/vctk/TTS/vits/generator.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/generator.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/hifigan.py b/egs/vctk/TTS/vits/hifigan.py new file mode 120000 index 000000000..5ac025de7 --- /dev/null +++ b/egs/vctk/TTS/vits/hifigan.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/hifigan.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/infer.py b/egs/vctk/TTS/vits/infer.py new 
file mode 100755 index 000000000..06c25f02e --- /dev/null +++ b/egs/vctk/TTS/vits/infer.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script performs model inference on test set. + +Usage: +./vits/infer.py \ + --epoch 1000 \ + --exp-dir ./vits/exp \ + --max-duration 500 +""" + + +import argparse +import logging +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Dict, List + +import k2 +import torch +import torch.nn as nn +import torchaudio +from tokenizer import Tokenizer +from train import get_model, get_params +from tts_datamodule import VctkTtsDataModule + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +def infer_dataset( + dl: torch.utils.data.DataLoader, + subset: str, + params: AttributeDict, + model: nn.Module, + tokenizer: Tokenizer, + speaker_map: Dict[str, int], +) -> None: + """Decode dataset. + The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + tokenizer: + Used to convert text to phonemes. + """ + + # Background worker save audios to disk. + def _save_worker( + subset: str, + batch_size: int, + cut_ids: List[str], + audio: torch.Tensor, + audio_pred: torch.Tensor, + audio_lens: List[int], + audio_lens_pred: List[int], + ): + for i in range(batch_size): + torchaudio.save( + str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"), + audio[i : i + 1, : audio_lens[i]], + sample_rate=params.sampling_rate, + ) + torchaudio.save( + str(params.save_wav_dir / subset / f"{cut_ids[i]}_pred.wav"), + audio_pred[i : i + 1, : audio_lens_pred[i]], + sample_rate=params.sampling_rate, + ) + + device = next(model.parameters()).device + num_cuts = 0 + log_interval = 5 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
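+    # The futures list below collects the pending writes submitted to the
+    # single background worker thread; they are waited on after the loop so
+    # that every wav file is fully written before this function returns.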
+ + futures = [] + with ThreadPoolExecutor(max_workers=1) as executor: + for batch_idx, batch in enumerate(dl): + batch_size = len(batch["tokens"]) + + tokens = batch["tokens"] + tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) + speakers = ( + torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]) + .int() + .to(device) + ) + + audio = batch["audio"] + audio_lens = batch["audio_lens"].tolist() + cut_ids = [cut.id for cut in batch["cut"]] + + audio_pred, _, durations = model.inference_batch( + text=tokens, + text_lengths=tokens_lens, + sids=speakers, + ) + audio_pred = audio_pred.detach().cpu() + # convert to samples + audio_lens_pred = ( + (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() + ) + + futures.append( + executor.submit( + _save_worker, + subset, + batch_size, + cut_ids, + audio, + audio_pred, + audio_lens, + audio_lens_pred, + ) + ) + + num_cuts += batch_size + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + # return results + for f in futures: + f.result() + + +@torch.no_grad() +def main(): + parser = get_parser() + VctkTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.suffix = f"epoch-{params.epoch}" + + params.res_dir = params.exp_dir / "infer" / params.suffix + params.save_wav_dir = params.res_dir / "wav" + params.save_wav_dir.mkdir(parents=True, exist_ok=True) + + setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") + logging.info("Infer started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + vctk = VctkTtsDataModule(args) + speaker_map = vctk.speakers() + params.num_spks = len(speaker_map) + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model.to(device) + model.eval() + + num_param_g = sum([p.numel() for p in model.generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in model.discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + test_cuts = vctk.test_cuts() + test_dl = vctk.test_dataloaders(test_cuts) + + valid_cuts = vctk.valid_cuts() + valid_dl = vctk.valid_dataloaders(valid_cuts) + + infer_sets = {"test": test_dl, "valid": valid_dl} + + for subset, dl in infer_sets.items(): + save_wav_dir = params.res_dir / "wav" / subset + save_wav_dir.mkdir(parents=True, exist_ok=True) + + logging.info(f"Processing {subset} set, saving to {save_wav_dir}") + + infer_dataset( + dl=dl, + subset=subset, + params=params, + model=model, + tokenizer=tokenizer, + speaker_map=speaker_map, + ) + + logging.info(f"Wav files are saved to {params.save_wav_dir}") + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/vctk/TTS/vits/loss.py b/egs/vctk/TTS/vits/loss.py new file mode 120000 index 000000000..672e5ff68 --- /dev/null +++ b/egs/vctk/TTS/vits/loss.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/loss.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/monotonic_align b/egs/vctk/TTS/vits/monotonic_align new file mode 120000 index 000000000..71934e7cc --- /dev/null +++ b/egs/vctk/TTS/vits/monotonic_align @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/monotonic_align \ No newline at end of file diff --git a/egs/vctk/TTS/vits/posterior_encoder.py b/egs/vctk/TTS/vits/posterior_encoder.py new file mode 120000 index 000000000..41d64a3a6 --- /dev/null +++ b/egs/vctk/TTS/vits/posterior_encoder.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/posterior_encoder.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/residual_coupling.py b/egs/vctk/TTS/vits/residual_coupling.py new file mode 120000 index 000000000..f979adbf0 --- /dev/null +++ b/egs/vctk/TTS/vits/residual_coupling.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/residual_coupling.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/test_onnx.py b/egs/vctk/TTS/vits/test_onnx.py new file mode 100755 index 000000000..757e67fc1 --- /dev/null +++ b/egs/vctk/TTS/vits/test_onnx.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +This script is used to test the exported onnx model by vits/export-onnx.py + +Use the onnx model to generate a wav: +./vits/test_onnx.py \ + --model-filename vits/exp/vits-epoch-1000.onnx \ + --tokens data/tokens.txt +""" + + +import argparse +import logging +from pathlib import Path + +import onnxruntime as ort +import torch +import torchaudio +from tokenizer import Tokenizer + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the onnx model.", + ) + + parser.add_argument( + "--speakers", + type=Path, + default=Path("data/speakers.txt"), + help="Path to speakers.txt file.", + ) + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +class OnnxModel: + def __init__(self, model_filename: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 + + self.session_opts = session_opts + + self.model = ort.InferenceSession( + model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") + + def __call__( + self, tokens: torch.Tensor, tokens_lens: torch.Tensor, speaker: torch.Tensor + ) -> torch.Tensor: + """ + Args: + tokens: + A 1-D tensor of shape (1, T) + Returns: + A tensor of shape (1, T') + """ + noise_scale = torch.tensor([0.667], dtype=torch.float32) + noise_scale_dur = torch.tensor([0.8], dtype=torch.float32) + alpha = torch.tensor([1.0], dtype=torch.float32) + + out = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: tokens.numpy(), + self.model.get_inputs()[1].name: tokens_lens.numpy(), + self.model.get_inputs()[2].name: noise_scale.numpy(), + self.model.get_inputs()[3].name: noise_scale_dur.numpy(), + self.model.get_inputs()[4].name: speaker.numpy(), + self.model.get_inputs()[5].name: alpha.numpy(), + }, + )[0] + return torch.from_numpy(out) + + +def main(): + args = get_parser().parse_args() + + tokenizer = Tokenizer(args.tokens) + + with open(args.speakers) as f: + speaker_map = {line.strip(): i for i, line in enumerate(f)} + args.num_spks = len(speaker_map) + + logging.info("About to create onnx model") + model = OnnxModel(args.model_filename) + + text = "I went there to see the land, the people and how their system works, end quote." 
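+    # (The sentence above is only an arbitrary smoke-test prompt; any text the
+    # tokenizer can handle works. The speaker id used below is likewise an
+    # arbitrary entry of the speaker map built from speakers.txt.)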
+ tokens = tokenizer.texts_to_token_ids([text]) + tokens = torch.tensor(tokens) # (1, T) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) + speaker = torch.tensor([1], dtype=torch.int64) # (1, ) + audio = model(tokens, tokens_lens, speaker) # (1, T') + + torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050) + logging.info("Saved to test_onnx.wav") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/vctk/TTS/vits/text_encoder.py b/egs/vctk/TTS/vits/text_encoder.py new file mode 120000 index 000000000..0efba277e --- /dev/null +++ b/egs/vctk/TTS/vits/text_encoder.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/text_encoder.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/tokenizer.py b/egs/vctk/TTS/vits/tokenizer.py new file mode 120000 index 000000000..057b0dc4b --- /dev/null +++ b/egs/vctk/TTS/vits/tokenizer.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/tokenizer.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/train.py b/egs/vctk/TTS/vits/train.py new file mode 100755 index 000000000..56f167a17 --- /dev/null +++ b/egs/vctk/TTS/vits/train.py @@ -0,0 +1,1000 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
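+
+"""
+Example usage (a sketch only -- the flag values below are illustrative; all of
+the flags shown are defined in get_parser() and VctkTtsDataModule.add_arguments()):
+
+    export CUDA_VISIBLE_DEVICES="0,1,2,3"
+    ./vits/train.py \
+      --world-size 4 \
+      --num-epochs 1000 \
+      --exp-dir vits/exp \
+      --tokens data/tokens.txt \
+      --max-duration 350
+"""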
+ + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import numpy as np +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from tokenizer import Tokenizer +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import VctkTtsDataModule +from utils import MetricsTracker, plot_feature, save_checkpoint +from vits import VITS + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, setup_logger, str2bool + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=1000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--lr", type=float, default=2.0e-4, help="The base learning rate." + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=20, + help="""Save checkpoint after processing this number of epochs" + periodically. We save checkpoint to exp-dir/ whenever + params.cur_epoch % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. + Since it will take around 1000 epochs, we suggest using a large + save_every_n to save disk space. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. 
+ + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + # training params + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 50, + "valid_interval": 200, + "env_info": get_env_info(), + "sampling_rate": 22050, + "frame_shift": 256, + "frame_length": 1024, + "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length + "n_mels": 80, + "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss + "lambda_mel": 45.0, # loss scaling coefficient for Mel loss + "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss + "lambda_dur": 1.0, # loss scaling coefficient for duration loss + "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+ + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def get_model(params: AttributeDict) -> nn.Module: + mel_loss_params = { + "n_mels": params.n_mels, + "frame_length": params.frame_length, + "frame_shift": params.frame_shift, + } + generator_params = { + "hidden_channels": 192, + "spks": params.num_spks, + "langs": None, + "spk_embed_dim": None, + "global_channels": 256, + "segment_size": 32, + "text_encoder_attention_heads": 2, + "text_encoder_ffn_expand": 4, + "text_encoder_cnn_module_kernel": 5, + "text_encoder_blocks": 6, + "text_encoder_dropout_rate": 0.1, + "decoder_kernel_size": 7, + "decoder_channels": 512, + "decoder_upsample_scales": [8, 8, 2, 2], + "decoder_upsample_kernel_sizes": [16, 16, 4, 4], + "decoder_resblock_kernel_sizes": [3, 7, 11], + "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "use_weight_norm_in_decoder": True, + "posterior_encoder_kernel_size": 5, + "posterior_encoder_layers": 16, + "posterior_encoder_stacks": 1, + "posterior_encoder_base_dilation": 1, + "posterior_encoder_dropout_rate": 0.0, + "use_weight_norm_in_posterior_encoder": True, + "flow_flows": 4, + "flow_kernel_size": 5, + "flow_base_dilation": 1, + "flow_layers": 4, + "flow_dropout_rate": 0.0, + "use_weight_norm_in_flow": True, + "use_only_mean_in_flow": True, + "stochastic_duration_predictor_kernel_size": 3, + "stochastic_duration_predictor_dropout_rate": 0.5, + "stochastic_duration_predictor_flows": 4, + "stochastic_duration_predictor_dds_conv_layers": 3, + } + model = VITS( + vocab_size=params.vocab_size, + feature_dim=params.feature_dim, + sampling_rate=params.sampling_rate, + generator_params=generator_params, + mel_loss_params=mel_loss_params, + lambda_adv=params.lambda_adv, + lambda_mel=params.lambda_mel, + lambda_feat_match=params.lambda_feat_match, + lambda_dur=params.lambda_dur, + lambda_kl=params.lambda_kl, + ) + return model + + +def prepare_input( + batch: dict, + tokenizer: Tokenizer, + device: torch.device, + speaker_map: Dict[str, int], +): + """Parse batch data""" + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + tokens = batch["tokens"] + speakers = ( + torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device) + ) + + tokens = tokenizer.tokens_to_token_ids(tokens) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) + + return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + optimizer_g: Optimizer, + optimizer_d: Optimizer, + scheduler_g: LRSchedulerType, + scheduler_d: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + speaker_map: Dict[str, int], + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. 
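+    Each batch is passed through the model twice: once with
+    forward_generator=False to update the discriminator and once with
+    forward_generator=True to update the generator, i.e. the usual
+    alternating GAN update scheme.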
It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + tokenizer: + Used to convert text to phonemes. + optimizer_g: + The optimizer for generator. + optimizer_d: + The optimizer for discriminator. + scheduler_g: + The learning rate scheduler for generator, we call step() every epoch. + scheduler_d: + The learning rate scheduler for discriminator, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + + batch_size = len(batch["tokens"]) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + speakers, + ) = prepare_input(batch, tokenizer, device, speaker_map) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + try: + with autocast(enabled=params.use_fp16): + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=False, + ) + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + # update discriminator + optimizer_d.zero_grad() + scaler.scale(loss_d).backward() + scaler.step(optimizer_d) + + with autocast(enabled=params.use_fp16): + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=True, + return_sample=params.batch_idx_train % params.log_interval == 0, + ) + for k, v in stats_g.items(): + if "returned_sample" not in k: + loss_info[k] = v * batch_size + # update generator + optimizer_g.zero_grad() + scaler.scale(loss_g).backward() + scaler.step(optimizer_g) + scaler.update() + + # summary stats + tot_loss = tot_loss + loss_info + except: # noqa + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
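+            # Note that scaler._scale is a private attribute of GradScaler;
+            # reading it here is a pragmatic way to monitor the current scale
+            # so that it can be grown again after overflows have shrunk it.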
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_lr_g = max(scheduler_g.get_last_lr()) + cur_lr_d = max(scheduler_d.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate_g", cur_lr_g, params.batch_idx_train + ) + tb_writer.add_scalar( + "train/learning_rate_d", cur_lr_d, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + if "returned_sample" in stats_g: + speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] + tb_writer.add_audio( + "train/speech_hat_", + speech_hat_, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + "train/speech_", + speech_, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_image( + "train/mel_hat_", + plot_feature(mel_hat_), + params.batch_idx_train, + dataformats="HWC", + ) + tb_writer.add_image( + "train/mel_", + plot_feature(mel_), + params.batch_idx_train, + dataformats="HWC", + ) + + if ( + params.batch_idx_train % params.valid_interval == 0 + and not params.print_diagnostics + ): + logging.info("Computing validation loss") + valid_info, (speech_hat, speech) = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + speaker_map=speaker_map, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + tb_writer.add_audio( + "train/valdi_speech_hat", + speech_hat, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + "train/valdi_speech", + speech, + params.batch_idx_train, + params.sampling_rate, + ) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + speaker_map: Dict[str, int], + world_size: int = 1, + rank: int = 0, +) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: + """Run the validation process.""" + model.eval() + device = model.device if 
isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + returned_sample = None + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + batch_size = len(batch["tokens"]) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + speakers, + ) = prepare_input(batch, tokenizer, device, speaker_map) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=False, + ) + assert loss_d.requires_grad is False + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=True, + ) + assert loss_g.requires_grad is False + for k, v in stats_g.items(): + loss_info[k] = v * batch_size + + # summary stats + tot_loss = tot_loss + loss_info + + # infer for first batch: + if batch_idx == 0 and rank == 0: + inner_model = model.module if isinstance(model, DDP) else model + audio_pred, _, duration = inner_model.inference( + text=tokens[0, : tokens_lens[0].item()], + sids=speakers[0], + ) + audio_pred = audio_pred.data.cpu().numpy() + audio_len_pred = ( + (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() + ) + assert audio_len_pred == len(audio_pred), ( + audio_len_pred, + len(audio_pred), + ) + audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy() + returned_sample = (audio_pred, audio_gt) + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss, returned_sample + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + tokenizer: Tokenizer, + optimizer_g: torch.optim.Optimizer, + optimizer_d: torch.optim.Optimizer, + speaker_map: Dict[str, int], + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + speakers, + ) = prepare_input(batch, tokenizer, device, speaker_map) + try: + # for discriminator + with autocast(enabled=params.use_fp16): + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=False, + ) + optimizer_d.zero_grad() + loss_d.backward() + # for generator + with autocast(enabled=params.use_fp16): + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + sids=speakers, + forward_generator=True, + ) + optimizer_g.zero_grad() + loss_g.backward() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + vctk = VctkTtsDataModule(args) + + train_cuts = vctk.train_cuts() + speaker_map = vctk.speakers() + params.num_spks = len(speaker_map) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + generator = model.generator + discriminator = model.discriminator + + num_param_g = sum([p.numel() for p in generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer_g = torch.optim.AdamW( + generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + optimizer_d = torch.optim.AdamW( + discriminator.parameters(), 
lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) + + if checkpoints is not None: + # load state_dict for optimizers + if "optimizer_g" in checkpoints: + logging.info("Loading optimizer_g state dict") + optimizer_g.load_state_dict(checkpoints["optimizer_g"]) + if "optimizer_d" in checkpoints: + logging.info("Loading optimizer_d state dict") + optimizer_d.load_state_dict(checkpoints["optimizer_d"]) + + # load state_dict for schedulers + if "scheduler_g" in checkpoints: + logging.info("Loading scheduler_g state dict") + scheduler_g.load_state_dict(checkpoints["scheduler_g"]) + if "scheduler_d" in checkpoints: + logging.info("Loading scheduler_d state dict") + scheduler_d.load_state_dict(checkpoints["scheduler_d"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_dl = vctk.train_dataloaders(train_cuts) + + valid_cuts = vctk.valid_cuts() + valid_dl = vctk.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + speaker_map=speaker_map, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + train_dl=train_dl, + valid_dl=valid_dl, + speaker_map=speaker_map, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + 
best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + # step per epoch + scheduler_g.step() + scheduler_d.step() + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + VctkTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/vctk/TTS/vits/transform.py b/egs/vctk/TTS/vits/transform.py new file mode 120000 index 000000000..962647408 --- /dev/null +++ b/egs/vctk/TTS/vits/transform.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/transform.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/tts_datamodule.py b/egs/vctk/TTS/vits/tts_datamodule.py new file mode 100644 index 000000000..8b2a96b09 --- /dev/null +++ b/egs/vctk/TTS/vits/tts_datamodule.py @@ -0,0 +1,338 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class VctkTtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/spectrogram"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--speakers", + type=Path, + default=Path("data/speakers.txt"), + help="Path to speakers.txt file.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=8, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
+ """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + 
feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz") + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz") + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz") + + @lru_cache() + def speakers(self) -> Dict[str, int]: + logging.info("About to get speakers") + with open(self.args.speakers) as f: + speakers = {line.strip(): i for i, line in enumerate(f)} + return speakers diff --git a/egs/vctk/TTS/vits/utils.py b/egs/vctk/TTS/vits/utils.py new file mode 120000 index 000000000..085e764b4 --- /dev/null +++ b/egs/vctk/TTS/vits/utils.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/utils.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/vits.py b/egs/vctk/TTS/vits/vits.py new file mode 120000 index 000000000..1f58cf6fe --- /dev/null +++ b/egs/vctk/TTS/vits/vits.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/vits.py \ No newline at end of file diff --git a/egs/vctk/TTS/vits/wavenet.py b/egs/vctk/TTS/vits/wavenet.py new file mode 120000 index 000000000..28f0a78ee --- /dev/null +++ b/egs/vctk/TTS/vits/wavenet.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/wavenet.py \ No newline at end of file diff --git a/requirements-tts.txt b/requirements-tts.txt new file mode 100644 index 000000000..c30e23d54 --- /dev/null +++ b/requirements-tts.txt @@ -0,0 +1,6 @@ +# for TTS recipes +matplotlib==3.8.2 +cython==3.0.6 +numba==0.58.1 +g2p_en==2.1.0 +espnet_tts_frontend==0.0.3 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 9502fcbd2..a1a46ae64 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,5 @@ tensorboard typeguard dill black==22.3.0 +onnx==1.15.0 +onnxruntime==1.16.3 \ No newline at end of file From b87ed26c09e9f5bb29174dd01f13670fb6124583 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Wed, 6 Dec 2023 14:33:45 +0800 Subject: [PATCH 06/46] Normalize dockerfile (#1400) --- docker/torch1.12.1-cuda11.3.dockerfile | 3 +-- docker/torch1.13.0-cuda11.6.dockerfile | 3 +-- docker/torch1.9.0-cuda10.2.dockerfile | 3 +-- docker/torch2.0.0-cuda11.7.dockerfile | 3 +-- docker/torch2.1.0-cuda11.8.dockerfile | 3 +-- docker/torch2.1.0-cuda12.1.dockerfile | 3 +-- 6 files changed, 6 insertions(+), 12 deletions(-) diff --git a/docker/torch1.12.1-cuda11.3.dockerfile b/docker/torch1.12.1-cuda11.3.dockerfile index ed746abe3..deb5715cc 100644 --- a/docker/torch1.12.1-cuda11.3.dockerfile +++ b/docker/torch1.12.1-cuda11.3.dockerfile @@ -18,7 +18,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ vim \ - libssl-dev \ + libssl-dev \ autoconf \ automake \ bzip2 \ @@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f 
https://csukuangfj.github.io/kaldifeat/cuda.html \ - \ kaldi_native_io \ kaldialign \ kaldifst \ diff --git a/docker/torch1.13.0-cuda11.6.dockerfile b/docker/torch1.13.0-cuda11.6.dockerfile index 9657866e5..afc6c1b84 100644 --- a/docker/torch1.13.0-cuda11.6.dockerfile +++ b/docker/torch1.13.0-cuda11.6.dockerfile @@ -18,7 +18,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ vim \ - libssl-dev \ + libssl-dev \ autoconf \ automake \ bzip2 \ @@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ - \ kaldi_native_io \ kaldialign \ kaldifst \ diff --git a/docker/torch1.9.0-cuda10.2.dockerfile b/docker/torch1.9.0-cuda10.2.dockerfile index a92af7ad0..9ff225b54 100644 --- a/docker/torch1.9.0-cuda10.2.dockerfile +++ b/docker/torch1.9.0-cuda10.2.dockerfile @@ -25,7 +25,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ vim \ - libssl-dev \ + libssl-dev \ autoconf \ automake \ bzip2 \ @@ -58,7 +58,6 @@ RUN pip uninstall -y tqdm && \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ - \ kaldi_native_io \ kaldialign \ kaldifst \ diff --git a/docker/torch2.0.0-cuda11.7.dockerfile b/docker/torch2.0.0-cuda11.7.dockerfile index 07296e6f0..db8076560 100644 --- a/docker/torch2.0.0-cuda11.7.dockerfile +++ b/docker/torch2.0.0-cuda11.7.dockerfile @@ -18,7 +18,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ vim \ - libssl-dev \ + libssl-dev \ autoconf \ automake \ bzip2 \ @@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ - \ kaldi_native_io \ kaldialign \ kaldifst \ diff --git a/docker/torch2.1.0-cuda11.8.dockerfile b/docker/torch2.1.0-cuda11.8.dockerfile index e500e9a6a..b006b0d96 100644 --- a/docker/torch2.1.0-cuda11.8.dockerfile +++ b/docker/torch2.1.0-cuda11.8.dockerfile @@ -18,7 +18,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ vim \ - libssl-dev \ + libssl-dev \ autoconf \ automake \ bzip2 \ @@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ - \ kaldi_native_io \ kaldialign \ kaldifst \ diff --git a/docker/torch2.1.0-cuda12.1.dockerfile b/docker/torch2.1.0-cuda12.1.dockerfile index c3f12323e..1b078dc22 100644 --- a/docker/torch2.1.0-cuda12.1.dockerfile +++ b/docker/torch2.1.0-cuda12.1.dockerfile @@ -18,7 +18,7 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ curl \ vim \ - libssl-dev \ + libssl-dev \ autoconf \ automake \ bzip2 \ @@ -44,7 +44,6 @@ RUN pip install --no-cache-dir \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ - \ kaldi_native_io \ kaldialign \ kaldifst \ From bda72f86fffe591d334630da522dba4cf5c66341 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 8 Dec 2023 06:32:40 +0800 Subject: [PATCH 
07/46] minor adjustments to the VITS recipes for onnx runtime (#1405) --- egs/ljspeech/TTS/vits/export-onnx.py | 4 ++-- egs/ljspeech/TTS/vits/test_onnx.py | 4 ++-- egs/vctk/TTS/vits/export-onnx.py | 4 ++-- egs/vctk/TTS/vits/test_onnx.py | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index 2068adeea..bca6aec99 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -176,7 +176,7 @@ def export_model_onnx( torch.onnx.export( model, - (tokens, tokens_lens, noise_scale, noise_scale_dur, alpha), + (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur), model_filename, verbose=False, opset_version=opset_version, @@ -184,8 +184,8 @@ def export_model_onnx( "tokens", "tokens_lens", "noise_scale", - "noise_scale_dur", "alpha", + "noise_scale_dur", ], output_names=["audio"], dynamic_axes={ diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py index 686fee2a0..fcbc1d663 100755 --- a/egs/ljspeech/TTS/vits/test_onnx.py +++ b/egs/ljspeech/TTS/vits/test_onnx.py @@ -92,8 +92,8 @@ class OnnxModel: self.model.get_inputs()[0].name: tokens.numpy(), self.model.get_inputs()[1].name: tokens_lens.numpy(), self.model.get_inputs()[2].name: noise_scale.numpy(), - self.model.get_inputs()[3].name: noise_scale_dur.numpy(), - self.model.get_inputs()[4].name: alpha.numpy(), + self.model.get_inputs()[3].name: alpha.numpy(), + self.model.get_inputs()[4].name: noise_scale_dur.numpy(), }, )[0] return torch.from_numpy(out) diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py index 7c9664cc1..cfc74fd0a 100755 --- a/egs/vctk/TTS/vits/export-onnx.py +++ b/egs/vctk/TTS/vits/export-onnx.py @@ -187,7 +187,7 @@ def export_model_onnx( torch.onnx.export( model, - (tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha), + (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur, speaker), model_filename, verbose=False, opset_version=opset_version, @@ -195,9 +195,9 @@ def export_model_onnx( "tokens", "tokens_lens", "noise_scale", + "alpha", "noise_scale_dur", "speaker", - "alpha", ], output_names=["audio"], dynamic_axes={ diff --git a/egs/vctk/TTS/vits/test_onnx.py b/egs/vctk/TTS/vits/test_onnx.py index 757e67fc1..d85c0a27b 100755 --- a/egs/vctk/TTS/vits/test_onnx.py +++ b/egs/vctk/TTS/vits/test_onnx.py @@ -101,9 +101,9 @@ class OnnxModel: self.model.get_inputs()[0].name: tokens.numpy(), self.model.get_inputs()[1].name: tokens_lens.numpy(), self.model.get_inputs()[2].name: noise_scale.numpy(), - self.model.get_inputs()[3].name: noise_scale_dur.numpy(), - self.model.get_inputs()[4].name: speaker.numpy(), - self.model.get_inputs()[5].name: alpha.numpy(), + self.model.get_inputs()[3].name: alpha.numpy(), + self.model.get_inputs()[4].name: noise_scale_dur.numpy(), + self.model.get_inputs()[5].name: speaker.numpy(), }, )[0] return torch.from_numpy(out) From e9ec827de76856e38af7a884b878ca3a84f64bb9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 8 Dec 2023 14:29:24 +0800 Subject: [PATCH 08/46] Rename zipformer2 to zipformer_for_ncnn_export_only to avoid confusion. 
(#1407) --- .../do_not_use_it_directly.py | 2 +- .../ASR/pruned_transducer_stateless7_streaming/zipformer2.py | 1 - .../zipformer_for_ncnn_export_only.py | 1 + .../do_not_use_it_directly.py | 2 +- .../ASR/pruned_transducer_stateless7_streaming/zipformer2.py | 1 - .../zipformer_for_ncnn_export_only.py | 1 + .../do_not_use_it_directly.py | 2 +- .../ASR/pruned_transducer_stateless7_streaming/zipformer2.py | 1 - .../zipformer_for_ncnn_export_only.py | 1 + .../do_not_use_it_directly.py | 2 +- .../{zipformer2.py => zipformer_for_ncnn_export_only.py} | 0 .../pruned_transducer_stateless7_streaming_multi/zipformer2.py | 1 - 12 files changed, 7 insertions(+), 8 deletions(-) delete mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py delete mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py create mode 120000 egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py delete mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py rename egs/librispeech/ASR/pruned_transducer_stateless7_streaming/{zipformer2.py => zipformer_for_ncnn_export_only.py} (100%) delete mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 3c13c19c6..0fba3b58f 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -66,7 +66,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from zipformer2 import Zipformer +from zipformer_for_ncnn_export_only import Zipformer from icefall import diagnostics from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py deleted file mode 120000 index 12dbda888..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py new file mode 120000 index 000000000..d301e1f9b --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 61a3f27db..0426bc9a3 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ 
b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -67,7 +67,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from zipformer2 import Zipformer +from zipformer_for_ncnn_export_only import Zipformer from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py deleted file mode 120000 index 12dbda888..000000000 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer2.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py new file mode 120000 index 000000000..d301e1f9b --- /dev/null +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index acde72d80..685f6ece6 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -70,7 +70,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from zipformer2 import Zipformer +from zipformer_for_ncnn_export_only import Zipformer from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py deleted file mode 120000 index 12dbda888..000000000 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer2.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py new file mode 120000 index 000000000..d301e1f9b --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index cd26db6f3..9a6d2155b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -69,7 +69,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import 
DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from zipformer2 import Zipformer +from zipformer_for_ncnn_export_only import Zipformer from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py similarity index 100% rename from egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py rename to egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py deleted file mode 120000 index d3625f478..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/zipformer2.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7_streaming/zipformer2.py \ No newline at end of file From df56aff31ea9b95aa3d9672398a0771dcb8eacc5 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 8 Dec 2023 21:11:31 +0800 Subject: [PATCH 09/46] minor fixes to the vits onnx exportation scripts (#1408) --- egs/ljspeech/TTS/vits/export-onnx.py | 2 +- egs/vctk/TTS/vits/export-onnx.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index bca6aec99..36a9de27f 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -115,8 +115,8 @@ class OnnxModel(nn.Module): tokens: torch.Tensor, tokens_lens: torch.Tensor, noise_scale: float = 0.667, - noise_scale_dur: float = 0.8, alpha: float = 1.0, + noise_scale_dur: float = 0.8, ) -> Tuple[torch.Tensor, torch.Tensor]: """Please see the help information of VITS.inference_batch diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py index cfc74fd0a..667ac284b 100755 --- a/egs/vctk/TTS/vits/export-onnx.py +++ b/egs/vctk/TTS/vits/export-onnx.py @@ -121,9 +121,9 @@ class OnnxModel(nn.Module): tokens: torch.Tensor, tokens_lens: torch.Tensor, noise_scale: float = 0.667, + alpha: float = 1.0, noise_scale_dur: float = 0.8, speaker: int = 20, - alpha: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Please see the help information of VITS.inference_batch From b0f70c9d042da734a5df988b98412e5def6b8072 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 10 Dec 2023 11:38:39 +0800 Subject: [PATCH 10/46] Fix torch.jit.script() export for pruned_transducer_stateless2 (#1410) --- egs/librispeech/ASR/pruned_transducer_stateless2/export.py | 2 ++ egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py | 1 + .../ASR/pruned_transducer_stateless2/scaling_converter.py | 1 + 3 files changed, 4 insertions(+) create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless2/scaling_converter.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index e02afa892..e2db98f73 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -49,6 +49,7 @@ from pathlib import Path import k2 import torch +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, 
get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint @@ -198,6 +199,7 @@ def main(): model.eval() if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py b/egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py new file mode 120000 index 000000000..4f377cd01 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/lstmp.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling_converter.py new file mode 120000 index 000000000..3b667058d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file From 20a82c9abf6664b645c28dd6b3d629cf85e40231 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 12 Dec 2023 18:13:26 +0800 Subject: [PATCH 11/46] first commit (#1411) --- .github/scripts/multi-zh-hans.sh | 88 +++++++++++++++++++ .github/workflows/multi-zh-hans.yml | 84 ++++++++++++++++++ .../ASR/zipformer/export-onnx-streaming.py | 14 ++- 3 files changed, 182 insertions(+), 4 deletions(-) create mode 100755 .github/scripts/multi-zh-hans.sh create mode 100644 .github/workflows/multi-zh-hans.yml diff --git a/.github/scripts/multi-zh-hans.sh b/.github/scripts/multi-zh-hans.sh new file mode 100755 index 000000000..4ede7f43e --- /dev/null +++ b/.github/scripts/multi-zh-hans.sh @@ -0,0 +1,88 @@ +#!/usr/bin/env bash + +set -ex + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "pwd: $PWD" + +cd egs/multi_zh-hans/ASR + +repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05 +log "Downloading pre-trained model from $repo_url" +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +cd exp/ +git lfs pull --include pretrained.pt +rm -fv epoch-20.pt +rm -fv *.onnx +ln -s pretrained.pt epoch-20.pt +cd ../data/lang_bpe_2000 +git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model +popd + +log "----------------------------------------" +log "Export streaming ONNX transducer models " +log "----------------------------------------" + +./zipformer/export-onnx-streaming.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + --causal 1 \ + --avg 1 \ + --epoch 20 \ + --use-averaged-model 0 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --use-ctc 0 + +ls -lh $repo/exp + +log "------------------------------------------------------------" +log "Test export streaming ONNX transducer models (Python code) " +log "------------------------------------------------------------" + +log "test fp32" +./zipformer/onnx_pretrained-streaming.py \ + --encoder-model-filename $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000000.wav + +log "test int8" 
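# (The int8 check below pairs the int8-quantized encoder and joiner with the
# fp32 decoder; apart from those two file names it is identical to the fp32
# test above.)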
+./zipformer/onnx_pretrained-streaming.py \ + --encoder-model-filename $repo/exp/encoder-epoch-20-avg-1-chunk-16-left-128.int8.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-20-avg-1-chunk-16-left-128.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-20-avg-1-chunk-16-left-128.int8.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000000.wav + +log "Upload models to huggingface" +git config --global user.name "k2-fsa" +git config --global user.email "xxx@gmail.com" + +url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-multi-zh-hans-2023-12-12 +GIT_LFS_SKIP_SMUDGE=1 git clone $url +dst=$(basename $url) +cp -v $repo/exp/*.onnx $dst +cp -v $repo/data/lang_bpe_2000/tokens.txt $dst +mkdir -p $dst/test_wavs +cp -v $repo/test_wavs/*.wav $dst/test_wavs +cd $dst +git lfs track "*.onnx" +git add . +git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true + +log "Upload models to https://github.com/k2-fsa/sherpa-onnx" +rm -rf .git +rm -fv .gitattributes +cd .. +tar cjfv $dst.tar.bz2 $dst +mv -v $dst.tar.bz2 ../../../ diff --git a/.github/workflows/multi-zh-hans.yml b/.github/workflows/multi-zh-hans.yml new file mode 100644 index 000000000..439300b5f --- /dev/null +++ b/.github/workflows/multi-zh-hans.yml @@ -0,0 +1,84 @@ +name: run-multi-zh-hans + +on: + push: + branches: + - master + - upload-ctc-model + + pull_request: + branches: + - master + + workflow_dispatch: + +concurrency: + group: run-multi-zh-hans-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: write + +jobs: + multi-zh-hans: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2023-05-22 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: export-model + shell: bash + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + sudo apt-get -qq install git-lfs tree + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/multi-zh-hans.sh + ls -lh + + - name: upload model to https://github.com/k2-fsa/sherpa-onnx + uses: svenstaro/upload-release-action@v2 + with: + file_glob: true + file: ./*.tar.bz2 + overwrite: true + repo_name: k2-fsa/sherpa-onnx + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} + tag: asr-models diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index e2c7d7d95..6bc9b1858 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -614,7 +614,9 @@ def main(): ) logging.info(f"averaging {filenames}") model.to(device) - 
model.load_state_dict(average_checkpoints(filenames, device=device)) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) elif params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: @@ -625,7 +627,9 @@ def main(): filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) else: if params.iter > 0: filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ @@ -653,7 +657,8 @@ def main(): filename_start=filename_start, filename_end=filename_end, device=device, - ) + ), + strict=False, ) else: assert params.avg > 0, params.avg @@ -671,7 +676,8 @@ def main(): filename_start=filename_start, filename_end=filename_end, device=device, - ) + ), + strict=False, ) model.to("cpu") From 9e9fe7954d13c5ee8f10e990b358cf8c752a24e6 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 12 Dec 2023 18:57:04 +0800 Subject: [PATCH 12/46] Upload gigaspeech zipformer models in CI (#1412) --- .github/scripts/multi-zh-hans.sh | 3 +- .../run-gigaspeech-zipformer-2023-10-17.sh | 74 +++++++++++++++++-- .github/workflows/multi-zh-hans.yml | 5 -- .../run-gigaspeech-zipformer-2023-10-17.yml | 14 ++++ 4 files changed, 85 insertions(+), 11 deletions(-) diff --git a/.github/scripts/multi-zh-hans.sh b/.github/scripts/multi-zh-hans.sh index 4ede7f43e..2dd1bce42 100755 --- a/.github/scripts/multi-zh-hans.sh +++ b/.github/scripts/multi-zh-hans.sh @@ -45,7 +45,7 @@ log "----------------------------------------" ls -lh $repo/exp log "------------------------------------------------------------" -log "Test export streaming ONNX transducer models (Python code) " +log "Test exported streaming ONNX transducer models (Python code)" log "------------------------------------------------------------" log "test fp32" @@ -73,6 +73,7 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $url dst=$(basename $url) cp -v $repo/exp/*.onnx $dst cp -v $repo/data/lang_bpe_2000/tokens.txt $dst +cp -v $repo/data/lang_bpe_2000/bpe.model $dst mkdir -p $dst/test_wavs cp -v $repo/test_wavs/*.wav $dst/test_wavs cd $dst diff --git a/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh b/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh index 6bb0b9ebc..329896ef6 100755 --- a/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh +++ b/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh @@ -26,16 +26,80 @@ git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "data/lang_bpe_500/tokens.txt" git lfs pull --include "exp/jit_script.pt" git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt +rm epoch-30.pt +ln -s pretrained.pt epoch-30.pt +rm *.onnx +ls -lh popd +log "----------------------------------------" +log "Export ONNX transducer models " +log "----------------------------------------" + +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 30 \ + --avg 1 \ + --exp-dir $repo/exp + +ls -lh $repo/exp + +log "------------------------------------------------------------" +log "Test exported ONNX transducer models (Python code) " +log "------------------------------------------------------------" + +log "test fp32" +./zipformer/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-30-avg-1.onnx \ + --decoder-model-filename 
$repo/exp/decoder-epoch-30-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-30-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "test int8" +./zipformer/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-30-avg-1.int8.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-30-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-30-avg-1.int8.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "Upload models to huggingface" +git config --global user.name "k2-fsa" +git config --global user.email "xxx@gmail.com" + +url=https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-gigaspeech-2023-12-12 +GIT_LFS_SKIP_SMUDGE=1 git clone $url +dst=$(basename $url) +cp -v $repo/exp/*.onnx $dst +cp -v $repo/data/lang_bpe_500/tokens.txt $dst +cp -v $repo/data/lang_bpe_500/bpe.model $dst +mkdir -p $dst/test_wavs +cp -v $repo/test_wavs/*.wav $dst/test_wavs +cd $dst +git lfs track "*.onnx" +git add . +git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true + +log "Upload models to https://github.com/k2-fsa/sherpa-onnx" +rm -rf .git +rm -fv .gitattributes +cd .. +tar cjfv $dst.tar.bz2 $dst +ls -lh +mv -v $dst.tar.bz2 ../../../ + log "Export to torchscript model" ./zipformer/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 99 \ + --epoch 30 \ --avg 1 \ --jit 1 @@ -67,7 +131,7 @@ echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then mkdir -p zipformer/exp - ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt + ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-30.pt ln -s $PWD/$repo/data/lang_bpe_500 data/ ls -lh data @@ -83,7 +147,7 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == ./zipformer/decode.py \ --decoding-method $method \ - --epoch 999 \ + --epoch 30 \ --avg 1 \ --use-averaged-model 0 \ --max-duration $max_duration \ diff --git a/.github/workflows/multi-zh-hans.yml b/.github/workflows/multi-zh-hans.yml index 439300b5f..9081047de 100644 --- a/.github/workflows/multi-zh-hans.yml +++ b/.github/workflows/multi-zh-hans.yml @@ -2,11 +2,6 @@ name: run-multi-zh-hans on: push: - branches: - - master - - upload-ctc-model - - pull_request: branches: - master diff --git a/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml b/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml index 7572f4b5f..87090e310 100644 --- a/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml +++ b/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml @@ -21,6 +21,7 @@ on: push: branches: - master + pull_request: types: [labeled] @@ -33,6 +34,8 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" + workflow_dispatch: + concurrency: group: run_gigaspeech_2023_10_17_zipformer-${{ github.ref }} cancel-in-progress: true @@ -85,6 +88,7 @@ jobs: env: GITHUB_EVENT_NAME: ${{ github.event_name }} GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | mkdir -p egs/gigaspeech/ASR/data ln -sfv ~/tmp/fbank-libri 
egs/gigaspeech/ASR/data/fbank @@ -97,6 +101,16 @@ jobs: .github/scripts/run-gigaspeech-zipformer-2023-10-17.sh + - name: upload model to https://github.com/k2-fsa/sherpa-onnx + uses: svenstaro/upload-release-action@v2 + with: + file_glob: true + file: ./*.tar.bz2 + overwrite: true + repo_name: k2-fsa/sherpa-onnx + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} + tag: asr-models + - name: Display decoding results for gigaspeech zipformer if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' shell: bash From d0da509055468f600808a3d7cd8885649d12b96d Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 13 Dec 2023 10:33:28 +0800 Subject: [PATCH 13/46] Support ONNX export for Streaming CTC Encoder (#1413) * Create export-onnx-streaming-ctc.py * doc_str updated Co-authored-by: Fangjun Kuang --------- Co-authored-by: Fangjun Kuang --- .../zipformer/export-onnx-streaming-ctc.py | 570 ++++++++++++++++++ 1 file changed, 570 insertions(+) create mode 100755 egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py new file mode 100755 index 000000000..3c0f74005 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py @@ -0,0 +1,570 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang, Zengrui Jin) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script exports a CTC model from PyTorch to ONNX. + + +1. Download the pre-trained streaming model with CTC head + +2. Export the model to ONNX + +./zipformer/export-onnx-streaming-ctc.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal True \ + --chunk-size 16 \ + --left-context-frames 64 \ + --use-ctc 1 + +The --chunk-size in training is "16,32,64,-1", so we select one of them +(excluding -1) during streaming export. The same applies to `--left-context`, +whose value is "64,128,256,-1". + +It will generate the following file inside $repo/exp: + + - ctc-epoch-99-avg-1-chunk-16-left-64.onnx + +See ./onnx_pretrained-streaming-ctc.py for how to use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, List, Tuple + +import k2 +import onnx +import torch +import torch.nn as nn +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_model, get_params +from zipformer import Zipformer2 + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. 
+ Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxModel(nn.Module): + """A wrapper for Zipformer and the ctc_head""" + + def __init__( + self, + encoder: Zipformer2, + encoder_embed: nn.Module, + ctc_output: nn.Module, + ): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + ctc_output: + The ctc head. 
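            (There is no separate joiner/projection input here: ctc_output maps
            each encoder frame directly to per-token log-probabilities, which
            the exported graph exposes as the "log_probs" output.)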
+ """ + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.ctc_output = ctc_output + self.chunk_size = encoder.chunk_size[0] + self.left_context_len = encoder.left_context_frames[0] + self.pad_length = 7 + 2 * 3 + + def forward( + self, + x: torch.Tensor, + states: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + N = x.size(0) + T = self.chunk_size * 2 + self.pad_length + x_lens = torch.tensor([T] * N, device=x.device) + left_context_len = self.left_context_len + + cached_embed_left_pad = states[-2] + x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( + x=x, + x_lens=x_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size) + + src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) + encoder_states = states[:-2] + logging.info(f"len_encoder_states={len(encoder_states)}") + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out = self.ctc_output(encoder_out) + # Now encoder_out is of shape (N, T, ctc_output_dim) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + + return encoder_out, new_states + + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. 
+ """ + states = self.encoder.get_init_states(batch_size, device) + + embed_states = self.encoder_embed.get_init_states(batch_size, device) + + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device) + states.append(processed_lens) + + return states + + +def export_streaming_ctc_model_onnx( + model: OnnxModel, + encoder_filename: str, + opset_version: int = 11, +) -> None: + model.encoder.__class__.forward = model.encoder.__class__.streaming_forward + + decode_chunk_len = model.chunk_size * 2 + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + T = decode_chunk_len + model.pad_length + + x = torch.rand(1, T, 80, dtype=torch.float32) + init_state = model.get_init_states() + num_encoders = len(model.encoder.encoder_dim) + logging.info(f"num_encoders: {num_encoders}") + logging.info(f"len(init_state): {len(init_state)}") + + inputs = {} + input_names = ["x"] + + outputs = {} + output_names = ["log_probs"] + + def build_inputs_outputs(tensors, i): + assert len(tensors) == 6, len(tensors) + + # (downsample_left, batch_size, key_dim) + name = f"cached_key_{i}" + logging.info(f"{name}.shape: {tensors[0].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (1, batch_size, downsample_left, nonlin_attn_head_dim) + name = f"cached_nonlin_attn_{i}" + logging.info(f"{name}.shape: {tensors[1].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val1_{i}" + logging.info(f"{name}.shape: {tensors[2].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val2_{i}" + logging.info(f"{name}.shape: {tensors[3].shape}") + inputs[name] = {1: "N"} + outputs[f"new_{name}"] = {1: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv1_{i}" + logging.info(f"{name}.shape: {tensors[4].shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv2_{i}" + logging.info(f"{name}.shape: {tensors[5].shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + num_encoder_layers = ",".join(map(str, model.encoder.num_encoder_layers)) + encoder_dims = ",".join(map(str, model.encoder.encoder_dim)) + cnn_module_kernels = ",".join(map(str, model.encoder.cnn_module_kernel)) + ds = model.encoder.downsampling_factor + left_context_len = model.left_context_len + left_context_len = [left_context_len // k for k in ds] + left_context_len = ",".join(map(str, left_context_len)) + query_head_dims = ",".join(map(str, model.encoder.query_head_dim)) + value_head_dims = ",".join(map(str, model.encoder.value_head_dim)) + num_heads = ",".join(map(str, model.encoder.num_heads)) + + meta_data = { + "model_type": "zipformer2", + "version": "1", + "model_author": "k2-fsa", + "comment": "streaming ctc zipformer2", + "decode_chunk_len": str(decode_chunk_len), # 32 + "T": str(T), # 32+7+2*3=45 + "num_encoder_layers": num_encoder_layers, + 
"encoder_dims": encoder_dims, + "cnn_module_kernels": cnn_module_kernels, + "left_context_len": left_context_len, + "query_head_dims": query_head_dims, + "value_head_dims": value_head_dims, + "num_heads": num_heads, + } + logging.info(f"meta_data: {meta_data}") + + for i in range(len(init_state[:-2]) // 6): + build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i) + + # (batch_size, channels, left_pad, freq) + embed_states = init_state[-2] + name = "embed_states" + logging.info(f"{name}.shape: {embed_states.shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + # (batch_size,) + processed_lens = init_state[-1] + name = "processed_lens" + logging.info(f"{name}.shape: {processed_lens.shape}") + inputs[name] = {0: "N"} + outputs[f"new_{name}"] = {0: "N"} + input_names.append(name) + output_names.append(f"new_{name}") + + logging.info(inputs) + logging.info(outputs) + logging.info(input_names) + logging.info(output_names) + + torch.onnx.export( + model, + (x, init_state), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + "x": {0: "N"}, + "log_probs": {0: "N"}, + **inputs, + **outputs, + }, + ) + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + 
average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + + model = OnnxModel( + encoder=model.encoder, + encoder_embed=model.encoder_embed, + ctc_output=model.ctc_output, + ) + + total_num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + suffix += f"-chunk-{params.chunk_size}" + suffix += f"-left-{params.left_context_frames}" + + opset_version = 13 + + logging.info("Exporting model") + model_filename = params.exp_dir / f"ctc-{suffix}.onnx" + export_streaming_ctc_model_onnx( + model, + model_filename, + opset_version=opset_version, + ) + logging.info(f"Exported model to {model_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + model_filename_int8 = params.exp_dir / f"ctc-{suffix}.int8.onnx" + quantize_dynamic( + model_input=model_filename, + model_output=model_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() From f85f0252a9f05c10e6782a693c87a1fede2f83c7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 13 Dec 2023 17:34:12 +0800 Subject: [PATCH 14/46] Add greedy search for streaming zipformer CTC. 
(#1415) --- .github/scripts/multi-zh-hans.sh | 79 +++- .../onnx_pretrained-streaming-ctc.py | 426 ++++++++++++++++++ .../zipformer/onnx_pretrained-streaming.py | 2 +- .../zipformer/export-onnx-streaming-ctc.py | 1 + .../onnx_pretrained-streaming-ctc.py | 1 + 5 files changed, 503 insertions(+), 6 deletions(-) create mode 100755 egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming-ctc.py create mode 120000 egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming-ctc.py diff --git a/.github/scripts/multi-zh-hans.sh b/.github/scripts/multi-zh-hans.sh index 2dd1bce42..427d8887b 100755 --- a/.github/scripts/multi-zh-hans.sh +++ b/.github/scripts/multi-zh-hans.sh @@ -2,6 +2,10 @@ set -ex +git config --global user.name "k2-fsa" +git config --global user.email "csukuangfj@gmail.com" +git config --global lfs.allowincompletepush true + log() { # This function is from espnet local fname=${BASH_SOURCE[1]##*/} @@ -24,9 +28,73 @@ rm -fv epoch-20.pt rm -fv *.onnx ln -s pretrained.pt epoch-20.pt cd ../data/lang_bpe_2000 +ls -lh git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model +git lfs pull --include "*.model" +ls -lh popd +log "----------------------------------------" +log "Export streaming ONNX CTC models " +log "----------------------------------------" +./zipformer/export-onnx-streaming-ctc.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + --causal 1 \ + --avg 1 \ + --epoch 20 \ + --use-averaged-model 0 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --use-ctc 1 + +ls -lh $repo/exp/ + +log "------------------------------------------------------------" +log "Test exported streaming ONNX CTC models (greedy search) " +log "------------------------------------------------------------" + +test_wavs=( +DEV_T0000000000.wav +DEV_T0000000001.wav +DEV_T0000000002.wav +TEST_MEETING_T0000000113.wav +TEST_MEETING_T0000000219.wav +TEST_MEETING_T0000000351.wav +) + +for w in ${test_wavs[@]}; do + ./zipformer/onnx_pretrained-streaming-ctc.py \ + --model-filename $repo/exp/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/$w +done + +log "Upload onnx CTC models to huggingface" +url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 +GIT_LFS_SKIP_SMUDGE=1 git clone $url +dst=$(basename $url) +cp -v $repo/exp/ctc*.onnx $dst +cp -v $repo/data/lang_bpe_2000/tokens.txt $dst +cp -v $repo/data/lang_bpe_2000/bpe.model $dst +mkdir -p $dst/test_wavs +cp -v $repo/test_wavs/*.wav $dst/test_wavs +cd $dst +git lfs track "*.onnx" "bpe.model" +ls -lh +file bpe.model +git status +git add . +git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true + +log "Upload models to https://github.com/k2-fsa/sherpa-onnx" +rm -rf .git +rm -fv .gitattributes +cd .. 
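# The tarball built below is moved to the repository root (see the final mv);
# the workflow step "upload model to https://github.com/k2-fsa/sherpa-onnx"
# then publishes every ./*.tar.bz2 as a release asset under the asr-models tag.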
+tar cjfv $dst.tar.bz2 $dst +ls -lh *.tar.bz2 +mv -v $dst.tar.bz2 ../../../ + log "----------------------------------------" log "Export streaming ONNX transducer models " log "----------------------------------------" @@ -64,20 +132,20 @@ log "test int8" --tokens $repo/data/lang_bpe_2000/tokens.txt \ $repo/test_wavs/DEV_T0000000000.wav -log "Upload models to huggingface" -git config --global user.name "k2-fsa" -git config --global user.email "xxx@gmail.com" +log "Upload onnx transducer models to huggingface" url=https://huggingface.co/k2-fsa/sherpa-onnx-streaming-zipformer-multi-zh-hans-2023-12-12 GIT_LFS_SKIP_SMUDGE=1 git clone $url dst=$(basename $url) -cp -v $repo/exp/*.onnx $dst +cp -v $repo/exp/encoder*.onnx $dst +cp -v $repo/exp/decoder*.onnx $dst +cp -v $repo/exp/joiner*.onnx $dst cp -v $repo/data/lang_bpe_2000/tokens.txt $dst cp -v $repo/data/lang_bpe_2000/bpe.model $dst mkdir -p $dst/test_wavs cp -v $repo/test_wavs/*.wav $dst/test_wavs cd $dst -git lfs track "*.onnx" +git lfs track "*.onnx" bpe.model git add . git commit -m "upload model" && git push https://k2-fsa:${HF_TOKEN}@huggingface.co/k2-fsa/$dst main || true @@ -86,4 +154,5 @@ rm -rf .git rm -fv .gitattributes cd .. tar cjfv $dst.tar.bz2 $dst +ls -lh *.tar.bz2 mv -v $dst.tar.bz2 ../../../ diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py new file mode 100755 index 000000000..44546cae5 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) + +""" +This script loads ONNX models exported by ./export-onnx-streaming-ctc.py +and uses them to decode waves. + +We use the pre-trained model from +https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx-streaming-ctc.py \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --causal True \ + --chunk-size 16 \ + --left-context-frames 128 \ + --use-ctc 1 + +It will generate the following 2 files inside $repo/exp: + + - ctc-epoch-99-avg-1-chunk-16-left-128.int8.onnx + - ctc-epoch-99-avg-1-chunk-16-left-128.onnx + +You can use either the ``int8.onnx`` model or just the ``.onnx`` model. + +3. Run this file with the exported ONNX models + +./zipformer/onnx_pretrained-streaming-ctc.py \ + --model-filename $repo/exp/ctc-epoch-99-avg-1-chunk-16-left-128.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000001.wav + +Note: Even though this script only supports decoding a single file, +the exported ONNX models do support batch processing. 
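+
+For example, a minimal batch-decoding sketch (an illustration only, not part
+of this script; ``x`` is a hypothetical batched fbank tensor of shape
+(N, T, 80) and the model path is a placeholder) could look like:
+
+    model = OnnxModel(model_filename="/path/to/ctc.onnx")  # hypothetical path
+    model.init_states(batch_size=2)
+    log_probs = model(x)  # (2, T', vocab_size)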
+""" + +import argparse +import logging +from typing import Dict, List, Tuple + +import numpy as np +import onnxruntime as ort +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +class OnnxModel: + def __init__( + self, + model_filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.init_model(model_filename) + + def init_model(self, model_filename: str): + self.model = ort.InferenceSession( + model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + self.init_states() + + def init_states(self, batch_size: int = 1): + meta = self.model.get_modelmeta().custom_metadata_map + logging.info(f"meta={meta}") + + model_type = meta["model_type"] + assert model_type == "zipformer2", model_type + + decode_chunk_len = int(meta["decode_chunk_len"]) + T = int(meta["T"]) + + num_encoder_layers = meta["num_encoder_layers"] + encoder_dims = meta["encoder_dims"] + cnn_module_kernels = meta["cnn_module_kernels"] + left_context_len = meta["left_context_len"] + query_head_dims = meta["query_head_dims"] + value_head_dims = meta["value_head_dims"] + num_heads = meta["num_heads"] + + def to_int_list(s): + return list(map(int, s.split(","))) + + num_encoder_layers = to_int_list(num_encoder_layers) + encoder_dims = to_int_list(encoder_dims) + cnn_module_kernels = to_int_list(cnn_module_kernels) + left_context_len = to_int_list(left_context_len) + query_head_dims = to_int_list(query_head_dims) + value_head_dims = to_int_list(value_head_dims) + num_heads = to_int_list(num_heads) + + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"T: {T}") + logging.info(f"num_encoder_layers: {num_encoder_layers}") + logging.info(f"encoder_dims: {encoder_dims}") + logging.info(f"cnn_module_kernels: {cnn_module_kernels}") + logging.info(f"left_context_len: {left_context_len}") + logging.info(f"query_head_dims: {query_head_dims}") + logging.info(f"value_head_dims: {value_head_dims}") + logging.info(f"num_heads: {num_heads}") + + num_encoders = len(num_encoder_layers) + + self.states = [] + for i in range(num_encoders): + num_layers = num_encoder_layers[i] + key_dim = query_head_dims[i] * num_heads[i] + embed_dim = encoder_dims[i] + nonlin_attn_head_dim = 3 * embed_dim // 4 + value_dim = value_head_dims[i] * num_heads[i] + conv_left_pad = cnn_module_kernels[i] // 2 + + for layer in range(num_layers): + cached_key = torch.zeros( + left_context_len[i], batch_size, key_dim + ).numpy() + cached_nonlin_attn = torch.zeros( + 1, batch_size, left_context_len[i], nonlin_attn_head_dim + ).numpy() + cached_val1 = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + cached_val2 = torch.zeros( + left_context_len[i], batch_size, value_dim + ).numpy() + cached_conv1 = torch.zeros(batch_size, embed_dim, 
conv_left_pad).numpy() + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() + self.states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + embed_states = torch.zeros(batch_size, 128, 3, 19).numpy() + self.states.append(embed_states) + processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy() + self.states.append(processed_lens) + + self.num_encoders = num_encoders + + self.segment = T + self.offset = decode_chunk_len + + def _build_model_input_output( + self, + x: torch.Tensor, + ) -> Tuple[Dict[str, np.ndarray], List[str]]: + model_input = {"x": x.numpy()} + model_output = ["log_probs"] + + def build_inputs_outputs(tensors, i): + assert len(tensors) == 6, len(tensors) + + # (downsample_left, batch_size, key_dim) + name = f"cached_key_{i}" + model_input[name] = tensors[0] + model_output.append(f"new_{name}") + + # (1, batch_size, downsample_left, nonlin_attn_head_dim) + name = f"cached_nonlin_attn_{i}" + model_input[name] = tensors[1] + model_output.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val1_{i}" + model_input[name] = tensors[2] + model_output.append(f"new_{name}") + + # (downsample_left, batch_size, value_dim) + name = f"cached_val2_{i}" + model_input[name] = tensors[3] + model_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv1_{i}" + model_input[name] = tensors[4] + model_output.append(f"new_{name}") + + # (batch_size, embed_dim, conv_left_pad) + name = f"cached_conv2_{i}" + model_input[name] = tensors[5] + model_output.append(f"new_{name}") + + for i in range(len(self.states[:-2]) // 6): + build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i) + + # (batch_size, channels, left_pad, freq) + name = "embed_states" + embed_states = self.states[-2] + model_input[name] = embed_states + model_output.append(f"new_{name}") + + # (batch_size,) + name = "processed_lens" + processed_lens = self.states[-1] + model_input[name] = processed_lens + model_output.append(f"new_{name}") + + return model_input, model_output + + def _update_states(self, states: List[np.ndarray]): + self.states = states + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor containing log_probs. Its shape is (N, T, vocab_size) + where T' is usually equal to ((T-7)//2 - 3)//2 + """ + model_input, model_output_names = self._build_model_input_output(x) + + out = self.model.run(model_output_names, model_input) + + self._update_states(out[1:]) + + return torch.from_numpy(out[0]) + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. 
In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + log_probs: torch.Tensor, +) -> List[int]: + """Greedy search for a single utterance. + Args: + log_probs: + A 3-D tensor of shape (1, T, vocab_size) + Returns: + Return the decoded result. + """ + assert log_probs.ndim == 3, log_probs.shape + assert log_probs.shape[0] == 1, log_probs.shape + + max_indexes = log_probs[0].argmax(dim=1) + unique_indexes = torch.unique_consecutive(max_indexes) + + blank_id = 0 + unique_indexes = unique_indexes[unique_indexes != blank_id] + return unique_indexes.tolist() + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + model = OnnxModel(model_filename=args.model_filename) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {args.sound_file}") + waves = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=sample_rate, + )[0] + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + wave_samples = torch.cat([waves, tail_padding]) + + num_processed_frames = 0 + segment = model.segment + offset = model.offset + + hyp = [] + + chunk = int(1 * sample_rate) # 1 second + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + frames = frames.unsqueeze(0) + log_probs = model(frames) + + hyp += greedy_search(log_probs) + + # To handle byte-level BPE, we convert string tokens to utf-8 encoded bytes + id2token = {} + with open(args.tokens, encoding="utf-8") as f: + for line in f: + token, idx = line.split() + if token[:3] == "<0x" and token[-1] == ">": + token = int(token[1:-1], base=16) + assert 0 <= token < 256, token + token = token.to_bytes(1, byteorder="little") + else: + token = token.encode(encoding="utf-8") + + id2token[int(idx)] = token + + text = b"" + for i in hyp: + text += id2token[i] + text = text.decode(encoding="utf-8") + text = text.replace("▁", " ").strip() + + logging.info(args.sound_file) + logging.info(text) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py index e62491444..e7c4f40ee 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py @@ -326,7 +326,7 @@ class OnnxModel: A 3-D tensor of shape (N, T, C) Returns: Return a 3-D tensor of shape (N, T', joiner_dim) where - T' is usually equal to ((T-7)//2+1)//2 + T' is usually equal to ((T-7)//2-3)//2 """ encoder_input, 
encoder_output_names = self._build_encoder_input_output(x) diff --git a/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming-ctc.py new file mode 120000 index 000000000..652346001 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/export-onnx-streaming-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming-ctc.py b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming-ctc.py new file mode 120000 index 000000000..d623a8462 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/onnx_pretrained-streaming-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py \ No newline at end of file From 10a234709cb6b8aa5e99b1c18140b49db7b6faca Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 14 Dec 2023 11:26:37 +0800 Subject: [PATCH 15/46] bugs fixed (#1416) --- egs/aishell4/ASR/prepare.sh | 2 +- egs/alimeeting/ASR/prepare.sh | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/aishell4/ASR/prepare.sh b/egs/aishell4/ASR/prepare.sh index 361cc26ab..e8d9eb7b9 100755 --- a/egs/aishell4/ASR/prepare.sh +++ b/egs/aishell4/ASR/prepare.sh @@ -79,7 +79,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Process aishell4" if [ ! -f data/fbank/aishell4/.fbank.done ]; then mkdir -p data/fbank/aishell4 - lhotse prepare aishell4 $dl_dir/aishell4 data/manifests/aishell4 + ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} touch data/fbank/aishell4/.fbank.done fi fi diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh index 1709733c7..c8fed658d 100755 --- a/egs/alimeeting/ASR/prepare.sh +++ b/egs/alimeeting/ASR/prepare.sh @@ -7,6 +7,7 @@ set -eou pipefail stage=-1 stop_stage=100 +perturb_speed=true # We assume dl_dir (download dir) contains the following # directories and files. If not, they will be downloaded @@ -68,7 +69,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Process alimeeting" if [ ! 
-f data/fbank/alimeeting/.fbank.done ]; then mkdir -p data/fbank/alimeeting - lhotse prepare ali-meeting $dl_dir/alimeeting data/manifests/alimeeting + ./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} fi fi From 702d4f59147a81ef7ba37c7863ae7bd258c743a9 Mon Sep 17 00:00:00 2001 From: TianHao Zhang <32243340+Zth9730@users.noreply.github.com> Date: Thu, 21 Dec 2023 14:42:33 +0800 Subject: [PATCH 16/46] Update prepare.sh (#1422) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix the bug in line 251: 1、 del the additional blank 2、correct the spell error of "new_vocab_size" --- egs/libriheavy/ASR/prepare.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/libriheavy/ASR/prepare.sh b/egs/libriheavy/ASR/prepare.sh index af7e3c5b0..b0736c98b 100755 --- a/egs/libriheavy/ASR/prepare.sh +++ b/egs/libriheavy/ASR/prepare.sh @@ -248,7 +248,7 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts fi for vocab_size in ${vocab_sizes[@]}; do - new_vacab_size = $(($vocab_size + 256)) + new_vocab_size=$(($vocab_size + 256)) lang_dir=data/lang_punc_bpe_${new_vocab_size} mkdir -p $lang_dir From 79a42148dbcd98c42586f8386d91f6f4bb8f9979 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 23 Dec 2023 00:38:36 +0800 Subject: [PATCH 17/46] Add CI test to cover zipformer/train.py (#1424) --- .github/scripts/docker/Dockerfile | 59 +++++++++++++++ .github/scripts/docker/run.sh | 60 +++++++++++++++ .github/workflows/build-cpu-docker.yml | 75 +++++++++++++++++++ .github/workflows/train-librispeech.yml | 56 ++++++++++++++ egs/gigaspeech/ASR/zipformer/my_profile.py | 1 + egs/gigaspeech/ASR/zipformer/profile.py | 1 - .../{profile.py => my_profile.py} | 2 +- .../{profile.py => my_profile.py} | 2 +- .../{profile.py => my_profile.py} | 2 +- .../zipformer/{profile.py => my_profile.py} | 2 +- egs/tedlium3/ASR/zipformer/my_profile.py | 1 + egs/tedlium3/ASR/zipformer/profile.py | 1 - 12 files changed, 256 insertions(+), 6 deletions(-) create mode 100644 .github/scripts/docker/Dockerfile create mode 100755 .github/scripts/docker/run.sh create mode 100644 .github/workflows/build-cpu-docker.yml create mode 100644 .github/workflows/train-librispeech.yml create mode 120000 egs/gigaspeech/ASR/zipformer/my_profile.py delete mode 120000 egs/gigaspeech/ASR/zipformer/profile.py rename egs/librispeech/ASR/pruned_transducer_stateless/{profile.py => my_profile.py} (98%) rename egs/librispeech/ASR/pruned_transducer_stateless4/{profile.py => my_profile.py} (98%) rename egs/librispeech/ASR/pruned_transducer_stateless7/{profile.py => my_profile.py} (98%) rename egs/librispeech/ASR/zipformer/{profile.py => my_profile.py} (99%) create mode 120000 egs/tedlium3/ASR/zipformer/my_profile.py delete mode 120000 egs/tedlium3/ASR/zipformer/profile.py diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile new file mode 100644 index 000000000..55c3aa1b9 --- /dev/null +++ b/.github/scripts/docker/Dockerfile @@ -0,0 +1,59 @@ +ARG PYTHON_VERSION=3.8 +FROM python:${PYTHON_VERSION} + +ARG TORCHAUDIO_VERSION="0.13.0" +ARG TORCH_VERSION="1.13.0" +ARG K2_VERSION="1.24.4.dev20231220" +ARG KALDIFEAT_VERSION="1.25.3.dev20231221" + +ARG _K2_VERSION="${K2_VERSION}+cpu.torch${TORCH_VERSION}" +ARG _KALDIFEAT_VERSION="${KALDIFEAT_VERSION}+cpu.torch${TORCH_VERSION}" + +RUN apt-get update -y && \ + apt-get install -qq -y \ + ffmpeg \ + git \ + git-lfs \ + less \ + vim \ + && \ + apt-get 
clean && \ + rm -rf /var/cache/apt/archives /var/lib/apt/lists + + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${_K2_VERSION} +LABEL kaldifeat_version=${_KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +# Install dependencies +RUN pip install --no-cache-dir \ + torch==${TORCH_VERSION} torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/cpu/torch_stable.html \ + k2==${_K2_VERSION} -f https://k2-fsa.github.io/k2/cpu.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${_KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cpu.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + six \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +# RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ +# cd /workspace/icefall && \ +# pip install --no-cache-dir -r requirements.txt +# +# ENV PYTHONPATH /workspace/icefall:$PYTHONPATH +# +# WORKDIR /workspace/icefall diff --git a/.github/scripts/docker/run.sh b/.github/scripts/docker/run.sh new file mode 100755 index 000000000..aeb80b330 --- /dev/null +++ b/.github/scripts/docker/run.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash +set -ex + +cd /icefall +export PYTHONPATH=/icefall:$PYTHONPATH +python3 -c "import torch; print(torch.__file__)" +python3 -c "import torchaudio; print(torchaudio.__version__)" +python3 -c "import icefall; print(icefall.__file__)" + +cd egs/librispeech/ASR + +# We don't download the LM file since it is so large that it will +# cause OOM error for CI later. +mkdir -p download/lm +pushd download/lm +wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt +wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt +wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz +ls -lh +gunzip librispeech-lm-norm.txt.gz + +ls -lh +popd + +pushd download/ +wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/LibriSpeech.tar.bz2 +tar xf LibriSpeech.tar.bz2 +rm LibriSpeech.tar.bz2 + +cd LibriSpeech +ln -s train-clean-100 train-clean-360 +ln -s train-other-500 train-other-500 +popd + +mkdir -p data/manifests + +lhotse prepare librispeech -j 2 -p dev-clean -p dev-other -p test-clean -p test-other -p train-clean-100 download/LibriSpeech data/manifests +ls -lh data/manifests + +./local/compute_fbank_librispeech.py --dataset "dev-clean dev-other test-clean test-other train-clean-100" --perturb-speed False +ls -lh data/fbank + +./prepare.sh --stage 5 --stop-stage 6 + +./zipformer/train.py \ + --world-size 1 \ + --num-epochs 1 \ + --start-epoch 1 \ + --use-fp16 0 \ + --exp-dir zipformer/exp-small \ + --causal 0 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 64,96,96,96,96,96 \ + --encoder-dim 32,64,64,64,64,64 \ + --encoder-unmasked-dim 32,32,32,32,32,32 \ + --base-lr 0.04 \ + --full-libri 0 \ + --enable-musan 0 \ + --max-duration 30 \ + --print-diagnostics 1 diff --git a/.github/workflows/build-cpu-docker.yml b/.github/workflows/build-cpu-docker.yml new file mode 100644 index 000000000..f931f7d09 --- /dev/null +++ b/.github/workflows/build-cpu-docker.yml @@ -0,0 +1,75 @@ +name: build-cpu-docker +on: + workflow_dispatch: + +concurrency: + group: build-cpu-docker-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-cpu-docker: + name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} + runs-on: ${{ matrix.os }} + 
strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: ["3.8", "3.9", "3.10"] + torch-version: ["1.13.0", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] + k2-version: ["1.24.4.dev20231220"] + kaldifeat-version: ["1.25.3.dev20231221"] + version: ["1.0"] + + steps: + # refer to https://github.com/actions/checkout + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Free space + shell: bash + run: | + df -h + rm -rf /opt/hostedtoolcache + df -h + + - name: 'Login to GitHub Container Registry' + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build docker Image + shell: bash + run: | + cd .github/scripts/docker + torch_version=${{ matrix.torch-version }} + if [[ $torch_version == 1.13.0 ]]; then + torchaudio_version=0.13.0 + elif [[ $torch_version == 2.0.0 ]]; then + torchaudio_version=2.0.1 + elif [[ $torch_version == 2.0.1 ]]; then + torchaudio_version=2.0.2 + else + torchaudio_version=$torch_version + fi + echo "torch_version: $torch_version" + echo "torchaudio_version: $torchaudio_version" + + version=${{ matrix.version }} + + tag=ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v$version + echo "tag: $tag" + + docker build \ + -t $tag \ + --build-arg PYTHON_VERSION=${{ matrix.python-version }} \ + --build-arg TORCH_VERSION=$torch_version \ + --build-arg TORCHAUDIO_VERSION=$torchaudio_version \ + --build-arg K2_VERSION=${{ matrix.k2-version }} \ + --build-arg KALDIFEAT_VERSION=${{ matrix.kaldifeat-version }} \ + . + + docker image ls + docker push $tag diff --git a/.github/workflows/train-librispeech.yml b/.github/workflows/train-librispeech.yml new file mode 100644 index 000000000..7c9a28f03 --- /dev/null +++ b/.github/workflows/train-librispeech.yml @@ -0,0 +1,56 @@ +name: train librispeech +on: + push: + branches: + - master + + pull_request: + branches: + - master + + workflow_dispatch: + +concurrency: + group: train-librispeech-${{ github.ref }} + cancel-in-progress: true + +jobs: + train-librispeech: + name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: ["3.8", "3.9", "3.10"] + torch-version: ["1.13.0", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] + k2-version: ["1.24.4.dev20231220"] + kaldifeat-version: ["1.25.3.dev20231221"] + version: ["1.0"] + + steps: + # refer to https://github.com/actions/checkout + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Free space + shell: bash + run: | + df -h + rm -rf /opt/hostedtoolcache + df -h + echo "pwd: $PWD" + echo "github.workspace ${{ github.workspace }}" + + - name: Run the build process with Docker + uses: addnab/docker-run-action@v3 + with: + image: ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} + options: | + --volume ${{ github.workspace }}/:/icefall + shell: bash + run: | + ls -lh /icefall + + /icefall/.github/scripts/docker/run.sh diff --git a/egs/gigaspeech/ASR/zipformer/my_profile.py b/egs/gigaspeech/ASR/zipformer/my_profile.py new file mode 120000 index 000000000..3a90b2628 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/my_profile.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/my_profile.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/profile.py b/egs/gigaspeech/ASR/zipformer/profile.py deleted file 
mode 120000 index c93adbd14..000000000 --- a/egs/gigaspeech/ASR/zipformer/profile.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/profile.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/profile.py b/egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py similarity index 98% rename from egs/librispeech/ASR/pruned_transducer_stateless/profile.py rename to egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py index 09e4a7af4..b844ba613 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/profile.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/my_profile.py @@ -17,7 +17,7 @@ # limitations under the License. """ -Usage: ./pruned_transducer_stateless/profile.py +Usage: ./pruned_transducer_stateless/my_profile.py """ import argparse diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/profile.py b/egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py similarity index 98% rename from egs/librispeech/ASR/pruned_transducer_stateless4/profile.py rename to egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py index 252bdf060..4bf773918 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/profile.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/my_profile.py @@ -17,7 +17,7 @@ # limitations under the License. """ -Usage: ./pruned_transducer_stateless4/profile.py +Usage: ./pruned_transducer_stateless4/my_profile.py """ import argparse diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/profile.py b/egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py similarity index 98% rename from egs/librispeech/ASR/pruned_transducer_stateless7/profile.py rename to egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py index 0d308e966..5a068b3b6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/profile.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/my_profile.py @@ -17,7 +17,7 @@ # limitations under the License. """ -Usage: ./pruned_transducer_stateless7/profile.py +Usage: ./pruned_transducer_stateless7/my_profile.py """ import argparse diff --git a/egs/librispeech/ASR/zipformer/profile.py b/egs/librispeech/ASR/zipformer/my_profile.py similarity index 99% rename from egs/librispeech/ASR/zipformer/profile.py rename to egs/librispeech/ASR/zipformer/my_profile.py index 57f44a90a..ca20956fb 100755 --- a/egs/librispeech/ASR/zipformer/profile.py +++ b/egs/librispeech/ASR/zipformer/my_profile.py @@ -17,7 +17,7 @@ # limitations under the License. 
""" -Usage: ./zipformer/profile.py +Usage: ./zipformer/my_profile.py """ import argparse diff --git a/egs/tedlium3/ASR/zipformer/my_profile.py b/egs/tedlium3/ASR/zipformer/my_profile.py new file mode 120000 index 000000000..3a90b2628 --- /dev/null +++ b/egs/tedlium3/ASR/zipformer/my_profile.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/my_profile.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/zipformer/profile.py b/egs/tedlium3/ASR/zipformer/profile.py deleted file mode 120000 index c93adbd14..000000000 --- a/egs/tedlium3/ASR/zipformer/profile.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/profile.py \ No newline at end of file From e5bb1ae86cc750a51626e9afcb973cc03fa72f86 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 24 Dec 2023 13:40:33 +0800 Subject: [PATCH 18/46] Use the CPU docker in CI to simplify the test code (#1427) --- .github/scripts/docker/Dockerfile | 21 ++-- .github/workflows/build-cpu-docker.yml | 8 +- .github/workflows/test.yml | 139 +++++++++--------------- .github/workflows/train-librispeech.yml | 8 +- 4 files changed, 72 insertions(+), 104 deletions(-) diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile index 55c3aa1b9..bbf978d26 100644 --- a/.github/scripts/docker/Dockerfile +++ b/.github/scripts/docker/Dockerfile @@ -14,6 +14,7 @@ RUN apt-get update -y && \ ffmpeg \ git \ git-lfs \ + graphviz \ less \ vim \ && \ @@ -32,23 +33,23 @@ RUN pip install --no-cache-dir \ k2==${_K2_VERSION} -f https://k2-fsa.github.io/k2/cpu.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${_KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cpu.html \ + dill \ + graphviz \ kaldi_native_io \ kaldialign \ kaldifst \ kaldilm \ - sentencepiece>=0.1.96 \ - tensorboard \ - typeguard \ - dill \ - onnx \ - onnxruntime \ - onnxmltools \ - six \ + matplotlib \ multi_quantization \ - typeguard \ numpy \ + onnx \ + onnxmltools \ + onnxruntime \ pytest \ - graphviz + sentencepiece>=0.1.96 \ + six \ + tensorboard \ + typeguard # RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ # cd /workspace/icefall && \ diff --git a/.github/workflows/build-cpu-docker.yml b/.github/workflows/build-cpu-docker.yml index f931f7d09..b26cd2095 100644 --- a/.github/workflows/build-cpu-docker.yml +++ b/.github/workflows/build-cpu-docker.yml @@ -15,10 +15,10 @@ jobs: matrix: os: [ubuntu-latest] python-version: ["3.8", "3.9", "3.10"] - torch-version: ["1.13.0", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] + torch-version: ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] k2-version: ["1.24.4.dev20231220"] kaldifeat-version: ["1.25.3.dev20231221"] - version: ["1.0"] + version: ["1.1"] steps: # refer to https://github.com/actions/checkout @@ -45,8 +45,12 @@ jobs: run: | cd .github/scripts/docker torch_version=${{ matrix.torch-version }} + + # see https://pytorch.org/audio/stable/installation.html#compatibility-matrix if [[ $torch_version == 1.13.0 ]]; then torchaudio_version=0.13.0 + elif [[ $torch_version == 1.13.1 ]]; then + torchaudio_version=0.13.1 elif [[ $torch_version == 2.0.0 ]]; then torchaudio_version=2.0.1 elif [[ $torch_version == 2.0.1 ]]; then diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 363556bb7..b3fd6f133 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,129 +1,94 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the 
Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - name: test on: push: branches: - master + pull_request: branches: - master + workflow_dispatch: + concurrency: group: test-${{ github.ref }} cancel-in-progress: true jobs: test: + name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: os: [ubuntu-latest] - python-version: ["3.8"] - torch: ["1.13.0"] - torchaudio: ["0.13.0"] - k2-version: ["1.24.3.dev20230719"] - - fail-fast: false + python-version: ["3.8", "3.9", "3.10"] + torch-version: ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] + version: ["1.1"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 - with: - python-version: ${{ matrix.python-version }} - - - name: Install libnsdfile and libsox - if: startsWith(matrix.os, 'ubuntu') - run: | - sudo apt update - sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg - sudo apt install -q -y --fix-missing libsox-dev libsox-fmt-all - - - name: Install Python dependencies - run: | - python3 -m pip install --upgrade pip pytest - # numpy 1.20.x does not support python 3.6 - pip install numpy==1.19 - pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html - pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html - - pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.github.io/k2/cpu.html - pip install git+https://github.com/lhotse-speech/lhotse - # icefall requirements - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - pip install kaldifst - pip install onnxruntime matplotlib - pip install -r requirements.txt - - - name: Install graphviz - if: startsWith(matrix.os, 'ubuntu') + - name: Free space shell: bash run: | - python3 -m pip install -qq graphviz - sudo apt-get -qq install graphviz + df -h + rm -rf /opt/hostedtoolcache + df -h + echo "pwd: $PWD" + echo "github.workspace ${{ github.workspace }}" - name: Run tests - if: startsWith(matrix.os, 'ubuntu') - run: | - ls -lh - export PYTHONPATH=$PWD:$PWD/lhotse:$PYTHONPATH - echo $PYTHONPATH - pytest -v -s ./test - # runt tests for conformer ctc - cd egs/librispeech/ASR/conformer_ctc - pytest -v -s + uses: addnab/docker-run-action@v3 + with: + image: ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} + options: | + --volume ${{ github.workspace }}/:/icefall + shell: bash + run: | + export PYTHONPATH=/icefall:$PYTHONPATH + cd /icefall + git config --global --add safe.directory /icefall - cd ../pruned_transducer_stateless - pytest -v -s + pytest -v -s ./test - cd ../pruned_transducer_stateless2 - pytest -v -s + # runt tests for conformer ctc + cd egs/librispeech/ASR/conformer_ctc + pytest -v -s - cd ../pruned_transducer_stateless3 - pytest -v -s + cd 
../pruned_transducer_stateless + pytest -v -s - cd ../pruned_transducer_stateless4 - pytest -v -s + cd ../pruned_transducer_stateless2 + pytest -v -s - echo $PYTHONPATH - cd ../pruned_transducer_stateless7 - pytest -v -s + cd ../pruned_transducer_stateless3 + pytest -v -s - cd ../transducer_stateless - pytest -v -s + cd ../pruned_transducer_stateless4 + pytest -v -s - # cd ../transducer - # pytest -v -s + echo $PYTHONPATH + cd ../pruned_transducer_stateless7 + pytest -v -s - cd ../transducer_stateless2 - pytest -v -s + cd ../transducer_stateless + pytest -v -s - cd ../transducer_lstm - pytest -v -s + # cd ../transducer + # pytest -v -s - cd ../zipformer - pytest -v -s + cd ../transducer_stateless2 + pytest -v -s + + cd ../transducer_lstm + pytest -v -s + + cd ../zipformer + pytest -v -s - uses: actions/upload-artifact@v2 with: diff --git a/.github/workflows/train-librispeech.yml b/.github/workflows/train-librispeech.yml index 7c9a28f03..53a2d5843 100644 --- a/.github/workflows/train-librispeech.yml +++ b/.github/workflows/train-librispeech.yml @@ -23,10 +23,8 @@ jobs: matrix: os: [ubuntu-latest] python-version: ["3.8", "3.9", "3.10"] - torch-version: ["1.13.0", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] - k2-version: ["1.24.4.dev20231220"] - kaldifeat-version: ["1.25.3.dev20231221"] - version: ["1.0"] + torch-version: ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] + version: ["1.1"] steps: # refer to https://github.com/actions/checkout @@ -43,7 +41,7 @@ jobs: echo "pwd: $PWD" echo "github.workspace ${{ github.workspace }}" - - name: Run the build process with Docker + - name: Test zipformer/train.py with LibriSpeech uses: addnab/docker-run-action@v3 with: image: ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} From c855a58cfd8628dfe4ef2ffc0ac169d84a8ac0c5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 25 Dec 2023 19:41:09 +0800 Subject: [PATCH 19/46] Generate the dependency matrix by code for GitHub Actions (#1431) --- .github/scripts/docker/Dockerfile | 2 + .../scripts/docker/generate_build_matrix.py | 79 ++++++++ .../{docker => librispeech/ASR}/run.sh | 11 +- .github/scripts/yesno/ASR/run.sh | 86 +++++++++ .github/workflows/build-cpu-docker.yml | 42 ++-- .github/workflows/run-yesno-recipe.yml | 182 ++++-------------- .github/workflows/test.yml | 27 ++- .github/workflows/train-librispeech.yml | 33 +++- 8 files changed, 279 insertions(+), 183 deletions(-) create mode 100755 .github/scripts/docker/generate_build_matrix.py rename .github/scripts/{docker => librispeech/ASR}/run.sh (87%) create mode 100755 .github/scripts/yesno/ASR/run.sh diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile index bbf978d26..f75d74854 100644 --- a/.github/scripts/docker/Dockerfile +++ b/.github/scripts/docker/Dockerfile @@ -31,10 +31,12 @@ LABEL github_repo="https://github.com/k2-fsa/icefall" RUN pip install --no-cache-dir \ torch==${TORCH_VERSION} torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/cpu/torch_stable.html \ k2==${_K2_VERSION} -f https://k2-fsa.github.io/k2/cpu.html \ + \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${_KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cpu.html \ dill \ graphviz \ + kaldi-decoder \ kaldi_native_io \ kaldialign \ kaldifst \ diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py new file mode 100755 index 000000000..4e494d810 --- /dev/null +++ 
b/.github/scripts/docker/generate_build_matrix.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+
+
+import json
+
+
+def version_gt(a, b):
+    a_major, a_minor = list(map(int, a.split(".")[:2]))  # compare as integers so that 3.8 < 3.10
+    b_major, b_minor = list(map(int, b.split(".")[:2]))
+    if a_major > b_major:
+        return True
+
+    if a_major == b_major and a_minor > b_minor:
+        return True
+
+    return False
+
+
+def version_ge(a, b):
+    a_major, a_minor = list(map(int, a.split(".")[:2]))
+    b_major, b_minor = list(map(int, b.split(".")[:2]))
+    if a_major > b_major:
+        return True
+
+    if a_major == b_major and a_minor >= b_minor:
+        return True
+
+    return False
+
+
+def get_torchaudio_version(torch_version):
+    if torch_version == "1.13.0":
+        return "0.13.0"
+    elif torch_version == "1.13.1":
+        return "0.13.1"
+    elif torch_version == "2.0.0":
+        return "2.0.1"
+    elif torch_version == "2.0.1":
+        return "2.0.2"
+    else:
+        return torch_version
+
+
+def get_matrix():
+    k2_version = "1.24.4.dev20231220"
+    kaldifeat_version = "1.25.3.dev20231221"
+    version = "1.1"
+    python_version = ["3.8", "3.9", "3.10", "3.11"]
+    torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
+
+    matrix = []
+    for p in python_version:
+        for t in torch_version:
+            # torchaudio <= 1.13.x supports only python <= 3.10
+
+            if version_gt(p, "3.10") and not version_gt(t, "2.0"):
+                continue
+
+            matrix.append(
+                {
+                    "k2-version": k2_version,
+                    "kaldifeat-version": kaldifeat_version,
+                    "version": version,
+                    "python-version": p,
+                    "torch-version": t,
+                    "torchaudio-version": get_torchaudio_version(t),
+                }
+            )
+    return matrix
+
+
+def main():
+    matrix = get_matrix()
+    print(json.dumps({"include": matrix}))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/.github/scripts/docker/run.sh b/.github/scripts/librispeech/ASR/run.sh
similarity index 87%
rename from .github/scripts/docker/run.sh
rename to .github/scripts/librispeech/ASR/run.sh
index aeb80b330..641d59458 100755
--- a/.github/scripts/docker/run.sh
+++ b/.github/scripts/librispeech/ASR/run.sh
@@ -1,11 +1,12 @@
 #!/usr/bin/env bash
+
 set -ex
 
-cd /icefall
-export PYTHONPATH=/icefall:$PYTHONPATH
-python3 -c "import torch; print(torch.__file__)"
-python3 -c "import torchaudio; print(torchaudio.__version__)"
-python3 -c "import icefall; print(icefall.__file__)"
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
 
 cd egs/librispeech/ASR
 
diff --git a/.github/scripts/yesno/ASR/run.sh b/.github/scripts/yesno/ASR/run.sh
new file mode 100755
index 000000000..05c8fbac9
--- /dev/null
+++ b/.github/scripts/yesno/ASR/run.sh
@@ -0,0 +1,86 @@
+#!/usr/bin/env bash
+
+set -ex
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+cd egs/yesno/ASR
+
+log "data preparation"
+./prepare.sh
+
+log "training"
+python3 ./tdnn/train.py
+
+log "decoding"
+python3 ./tdnn/decode.py
+
+log "export to pretrained.pt"
+
+python3 ./tdnn/export.py --epoch 14 --avg 2
+
+python3 ./tdnn/pretrained.py \
+  --checkpoint ./tdnn/exp/pretrained.pt \
+  --HLG ./data/lang_phone/HLG.pt \
+  --words-file ./data/lang_phone/words.txt \
+  download/waves_yesno/0_0_0_1_0_0_0_1.wav \
+  download/waves_yesno/0_0_1_0_0_0_1_0.wav
+
+log "Test exporting to torchscript"
+python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1
+
+python3 ./tdnn/jit_pretrained.py \
+  --nn-model ./tdnn/exp/cpu_jit.pt \
+  --HLG ./data/lang_phone/HLG.pt \
+  --words-file
./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +log "Test exporting to onnx" +python3 ./tdnn/export_onnx.py --epoch 14 --avg 2 + +log "Test float32 model" +python3 ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +log "Test int8 model" +python3 ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +log "Test decoding with H" +python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 + +python3 ./tdnn/jit_pretrained_decode_with_H.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --H ./data/lang_phone/H.fst \ + --tokens ./data/lang_phone/tokens.txt \ + ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ + ./download/waves_yesno/0_0_1_0_0_1_1_1.wav + +log "Test decoding with HL" +python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 + +python3 ./tdnn/jit_pretrained_decode_with_HL.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --HL ./data/lang_phone/HL.fst \ + --words ./data/lang_phone/words.txt \ + ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ + ./download/waves_yesno/0_0_1_0_0_1_1_1.wav + +log "Show generated files" +ls -lh tdnn/exp +ls -lh data/lang_phone diff --git a/.github/workflows/build-cpu-docker.yml b/.github/workflows/build-cpu-docker.yml index b26cd2095..c5d5aaeb6 100644 --- a/.github/workflows/build-cpu-docker.yml +++ b/.github/workflows/build-cpu-docker.yml @@ -7,18 +7,31 @@ concurrency: cancel-in-progress: true jobs: + generate_build_matrix: + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python ./.github/scripts/docker/generate_build_matrix.py + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" build-cpu-docker: + needs: generate_build_matrix name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} - runs-on: ${{ matrix.os }} + runs-on: ubuntu-latest strategy: fail-fast: false matrix: - os: [ubuntu-latest] - python-version: ["3.8", "3.9", "3.10"] - torch-version: ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] - k2-version: ["1.24.4.dev20231220"] - kaldifeat-version: ["1.25.3.dev20231221"] - version: ["1.1"] + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} steps: # refer to https://github.com/actions/checkout @@ -45,25 +58,14 @@ jobs: run: | cd .github/scripts/docker torch_version=${{ matrix.torch-version }} + torchaudio_version=${{ matrix.torchaudio-version }} - # see https://pytorch.org/audio/stable/installation.html#compatibility-matrix - if [[ $torch_version == 1.13.0 ]]; then - torchaudio_version=0.13.0 - elif [[ $torch_version == 1.13.1 ]]; then - torchaudio_version=0.13.1 - elif [[ $torch_version == 2.0.0 ]]; then - torchaudio_version=2.0.1 - elif [[ $torch_version == 2.0.1 ]]; 
then - torchaudio_version=2.0.2 - else - torchaudio_version=$torch_version - fi echo "torch_version: $torch_version" echo "torchaudio_version: $torchaudio_version" version=${{ matrix.version }} - tag=ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v$version + tag=ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v$version echo "tag: $tag" docker build \ diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index 9ac848535..a99811815 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -20,166 +20,60 @@ on: push: branches: - master + - refactor-ci + pull_request: branches: - master + workflow_dispatch: + concurrency: group: run-yesno-recipe-${{ github.ref }} cancel-in-progress: true jobs: + generate_build_matrix: + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python ./.github/scripts/docker/generate_build_matrix.py + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" run-yesno-recipe: - runs-on: ${{ matrix.os }} + needs: generate_build_matrix + name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} + runs-on: ubuntu-latest strategy: - matrix: - # os: [ubuntu-latest, macos-10.15] - # TODO: enable macOS for CPU testing - os: [ubuntu-latest] - python-version: [3.8] fail-fast: false + matrix: + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - name: Run the yesno recipe + uses: addnab/docker-run-action@v3 with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' + image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} + options: | + --volume ${{ github.workspace }}/:/icefall + shell: bash + run: | + export PYTHONPATH=/icefall:$PYTHONPATH + cd /icefall + git config --global --add safe.directory /icefall - - name: Install libnsdfile and libsox - if: startsWith(matrix.os, 'ubuntu') - run: | - sudo apt update - sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg - sudo apt install -q -y --fix-missing sox libsox-dev libsox-fmt-all - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - pip install --no-deps --force-reinstall k2==1.24.4.dev20231021+cpu.torch1.13.1 -f https://k2-fsa.github.io/k2/cpu.html - pip install kaldifeat==1.25.1.dev20231022+cpu.torch1.13.1 -f https://csukuangfj.github.io/kaldifeat/cpu.html - - - name: Run yesno recipe - shell: bash - working-directory: ${{github.workspace}} - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - echo $PYTHONPATH - - cd egs/yesno/ASR - ./prepare.sh - python3 ./tdnn/train.py - python3 ./tdnn/decode.py - - - name: Test exporting to pretrained.pt - shell: 
bash - working-directory: ${{github.workspace}} - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - echo $PYTHONPATH - - cd egs/yesno/ASR - python3 ./tdnn/export.py --epoch 14 --avg 2 - - python3 ./tdnn/pretrained.py \ - --checkpoint ./tdnn/exp/pretrained.pt \ - --HLG ./data/lang_phone/HLG.pt \ - --words-file ./data/lang_phone/words.txt \ - download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - download/waves_yesno/0_0_1_0_0_0_1_0.wav - - - name: Test exporting to torchscript - shell: bash - working-directory: ${{github.workspace}} - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - echo $PYTHONPATH - - cd egs/yesno/ASR - python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 - - python3 ./tdnn/jit_pretrained.py \ - --nn-model ./tdnn/exp/cpu_jit.pt \ - --HLG ./data/lang_phone/HLG.pt \ - --words-file ./data/lang_phone/words.txt \ - download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - download/waves_yesno/0_0_1_0_0_0_1_0.wav - - - name: Test exporting to onnx - shell: bash - working-directory: ${{github.workspace}} - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - echo $PYTHONPATH - - cd egs/yesno/ASR - python3 ./tdnn/export_onnx.py --epoch 14 --avg 2 - - echo "Test float32 model" - python3 ./tdnn/onnx_pretrained.py \ - --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ - --HLG ./data/lang_phone/HLG.pt \ - --words-file ./data/lang_phone/words.txt \ - download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - download/waves_yesno/0_0_1_0_0_0_1_0.wav - - - echo "Test int8 model" - python3 ./tdnn/onnx_pretrained.py \ - --nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \ - --HLG ./data/lang_phone/HLG.pt \ - --words-file ./data/lang_phone/words.txt \ - download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - download/waves_yesno/0_0_1_0_0_0_1_0.wav - - - name: Test decoding with H - shell: bash - working-directory: ${{github.workspace}} - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - echo $PYTHONPATH - - cd egs/yesno/ASR - python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 - - python3 ./tdnn/jit_pretrained_decode_with_H.py \ - --nn-model ./tdnn/exp/cpu_jit.pt \ - --H ./data/lang_phone/H.fst \ - --tokens ./data/lang_phone/tokens.txt \ - ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ - ./download/waves_yesno/0_0_1_0_0_1_1_1.wav - - - name: Test decoding with HL - shell: bash - working-directory: ${{github.workspace}} - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - echo $PYTHONPATH - - cd egs/yesno/ASR - python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 - - python3 ./tdnn/jit_pretrained_decode_with_HL.py \ - --nn-model ./tdnn/exp/cpu_jit.pt \ - --HL ./data/lang_phone/HL.fst \ - --words ./data/lang_phone/words.txt \ - ./download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - ./download/waves_yesno/0_0_1_0_0_0_1_0.wav \ - ./download/waves_yesno/0_0_1_0_0_1_1_1.wav - - - name: Show generated files - shell: bash - working-directory: ${{github.workspace}} - run: | - cd egs/yesno/ASR - ls -lh tdnn/exp - ls -lh data/lang_phone + .github/scripts/yesno/ASR/run.sh diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b3fd6f133..659681b37 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,16 +16,31 @@ concurrency: cancel-in-progress: true jobs: + generate_build_matrix: + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Generating build 
matrix + id: set-matrix + run: | + # outputting for debugging purposes + python ./.github/scripts/docker/generate_build_matrix.py + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" test: + needs: generate_build_matrix name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} - runs-on: ${{ matrix.os }} + runs-on: ubuntu-latest strategy: fail-fast: false matrix: - os: [ubuntu-latest] - python-version: ["3.8", "3.9", "3.10"] - torch-version: ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] - version: ["1.1"] + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} steps: - uses: actions/checkout@v4 @@ -44,7 +59,7 @@ jobs: - name: Run tests uses: addnab/docker-run-action@v3 with: - image: ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} + image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} options: | --volume ${{ github.workspace }}/:/icefall shell: bash diff --git a/.github/workflows/train-librispeech.yml b/.github/workflows/train-librispeech.yml index 53a2d5843..79002a881 100644 --- a/.github/workflows/train-librispeech.yml +++ b/.github/workflows/train-librispeech.yml @@ -15,16 +15,31 @@ concurrency: cancel-in-progress: true jobs: + generate_build_matrix: + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python ./.github/scripts/docker/generate_build_matrix.py + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" train-librispeech: + needs: generate_build_matrix name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} - runs-on: ${{ matrix.os }} + runs-on: ubuntu-latest strategy: fail-fast: false matrix: - os: [ubuntu-latest] - python-version: ["3.8", "3.9", "3.10"] - torch-version: ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] - version: ["1.1"] + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} steps: # refer to https://github.com/actions/checkout @@ -44,11 +59,13 @@ jobs: - name: Test zipformer/train.py with LibriSpeech uses: addnab/docker-run-action@v3 with: - image: ghcr.io/k2-fsa/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} + image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} options: | --volume ${{ github.workspace }}/:/icefall shell: bash run: | - ls -lh /icefall + export PYTHONPATH=/icefall:$PYTHONPATH + cd /icefall + git config --global --add safe.directory /icefall - /icefall/.github/scripts/docker/run.sh + .github/scripts/librispeech/ASR/run.sh From ddd71313179a1565ea4d9e2e37546c3ef6b98d90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ali=20Haznedaro=C4=9Flu?= <53865510+ahazned@users.noreply.github.com> Date: Mon, 25 Dec 2023 14:44:07 +0300 Subject: [PATCH 20/46] Update TTS export-onnx.py scripts for handling variable token counts (#1430) --- egs/ljspeech/TTS/vits/export-onnx.py | 6 +++++- 
egs/vctk/TTS/vits/export-onnx.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index 36a9de27f..f82f9dbe9 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -149,6 +149,7 @@ class OnnxModel(nn.Module): def export_model_onnx( model: nn.Module, model_filename: str, + vocab_size: int, opset_version: int = 11, ) -> None: """Export the given generator model to ONNX format. @@ -165,10 +166,12 @@ def export_model_onnx( The VITS generator. model_filename: The filename to save the exported ONNX model. + vocab_size: + Number of tokens used in training. opset_version: The opset version to use. """ - tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64) + tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64) tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) noise_scale = torch.tensor([1], dtype=torch.float32) noise_scale_dur = torch.tensor([1], dtype=torch.float32) @@ -244,6 +247,7 @@ def main(): export_model_onnx( model, model_filename, + params.vocab_size, opset_version=opset_version, ) logging.info(f"Exported generator to {model_filename}") diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py index 667ac284b..80d155626 100755 --- a/egs/vctk/TTS/vits/export-onnx.py +++ b/egs/vctk/TTS/vits/export-onnx.py @@ -159,6 +159,7 @@ class OnnxModel(nn.Module): def export_model_onnx( model: nn.Module, model_filename: str, + vocab_size: int, opset_version: int = 11, ) -> None: """Export the given generator model to ONNX format. @@ -175,10 +176,12 @@ def export_model_onnx( The VITS generator. model_filename: The filename to save the exported ONNX model. + vocab_size: + Number of tokens used in training. opset_version: The opset version to use. """ - tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64) + tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64) tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) noise_scale = torch.tensor([1], dtype=torch.float32) noise_scale_dur = torch.tensor([1], dtype=torch.float32) @@ -261,6 +264,7 @@ def main(): export_model_onnx( model, model_filename, + params.vocab_size, opset_version=opset_version, ) logging.info(f"Exported generator to {model_filename}") From 835a92eba51a939c6b4a069a53cc1e3ddeabd9a5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 25 Dec 2023 20:23:56 +0800 Subject: [PATCH 21/46] Add doc about how to use the CPU-only docker images (#1432) --- docs/source/docker/intro.rst | 46 ++++++++++++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/docs/source/docker/intro.rst b/docs/source/docker/intro.rst index 9ead0df00..cbd300d9b 100644 --- a/docs/source/docker/intro.rst +++ b/docs/source/docker/intro.rst @@ -20,7 +20,11 @@ We describe the following items in this section: View available tags =================== -You can use the following command to view available tags: +CUDA-enabled docker images +-------------------------- + +You can use the following command to view available tags for CUDA-enabled +docker images: .. code-block:: bash @@ -43,8 +47,25 @@ which will give you something like below: Please select an appropriate combination of `torch`_ and CUDA. -Download a docker image -======================= +CPU-only docker images +---------------------- + +To view CPU-only docker images, please visit ``_ +for available tags. 
+ +You can select different combinations of ``Python`` and ``torch``. For instance, +to select ``Python 3.8`` and ``torch 2.1.2``, you can use the following tag + +.. code-block:: bash + + cpu-py3.8-torch2.1.2-v1.1 + +where ``v1.1`` is the current version of the docker image. You may see +``ghcr.io/k2-fsa/icefall:cpu-py3.8-torch2.1.2-v1.2`` or some other versions. +We recommend that you always use the latest version. + +Download a docker image (CUDA) +============================== Suppose that you select the tag ``torch1.13.0-cuda11.6``, you can use the following command to download it: @@ -53,6 +74,16 @@ the following command to download it: sudo docker image pull k2fsa/icefall:torch1.13.0-cuda11.6 +Download a docker image (CPU) +============================== + +Suppose that you select the tag ``cpu-py3.8-torch2.1.2-v1.1``, you can use +the following command to download it: + +.. code-block:: bash + + sudo docker pull ghcr.io/k2-fsa/icefall:cpu-py3.8-torch2.1.2-v1.1 + Run a docker image with GPU =========================== @@ -65,7 +96,7 @@ Run a docker image with CPU .. code-block:: bash - sudo docker run --rm -it k2fsa/icefall:torch1.13.0-cuda11.6 /bin/bash + sudo docker run --rm -it ghcr.io/k2-fsa/icefall:cpu-py3.8-torch2.1.2-v1.1 /bin/bash Run yesno within a docker container =================================== @@ -74,8 +105,13 @@ After starting the container, the following interface is presented: .. code-block:: bash + # GPU-enabled docker root@60c947eac59c:/workspace/icefall# + # CPU-only docker + root@60c947eac59c:# mkdir /workspace; git clone https://github.com/k2-fsa/icefall + root@60c947eac59c:# export PYTHONPATH=/workspace/icefall:$PYTHONPATH + It shows the current user is ``root`` and the current working directory is ``/workspace/icefall``. @@ -107,7 +143,7 @@ to switch to the ``yesno`` recipe and run .. hint:: - If you are running without GPU, it may report the following error: + If you are running without GPU with a GPU-enabled docker, it may report the following error: .. 
code-block:: bash From db52fe2349df0e07e931accb0cf1e63fec389fb7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 26 Dec 2023 20:29:43 +0800 Subject: [PATCH 22/46] Refactor CI test for aishell (#1435) --- .github/scripts/aishell/ASR/run.sh | 274 ++++++++++++++++++ .github/scripts/docker/Dockerfile | 1 + .../scripts/docker/generate_build_matrix.py | 2 +- ...pruned-transducer-stateless3-2022-06-20.sh | 87 ------ .../run-aishell-zipformer-2023-10-24.sh | 103 ------- ...transducer-stateless-modified-2-aishell.sh | 48 --- ...d-transducer-stateless-modified-aishell.sh | 48 --- .github/workflows/aishell.yml | 81 ++++++ .github/workflows/run-aishell-2022-06-20.yml | 123 -------- .../run-aishell-zipformer-2023-10-24.yml | 95 ------ ...ransducer-stateless-modified-2-aishell.yml | 80 ----- ...-transducer-stateless-modified-aishell.yml | 80 ----- .../{run-yesno-recipe.yml => yesno.yml} | 23 +- 13 files changed, 360 insertions(+), 685 deletions(-) create mode 100755 .github/scripts/aishell/ASR/run.sh delete mode 100755 .github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh delete mode 100755 .github/scripts/run-aishell-zipformer-2023-10-24.sh delete mode 100755 .github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh delete mode 100755 .github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh create mode 100644 .github/workflows/aishell.yml delete mode 100644 .github/workflows/run-aishell-2022-06-20.yml delete mode 100644 .github/workflows/run-aishell-zipformer-2023-10-24.yml delete mode 100644 .github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml delete mode 100644 .github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml rename .github/workflows/{run-yesno-recipe.yml => yesno.yml} (69%) diff --git a/.github/scripts/aishell/ASR/run.sh b/.github/scripts/aishell/ASR/run.sh new file mode 100755 index 000000000..4d912fa76 --- /dev/null +++ b/.github/scripts/aishell/ASR/run.sh @@ -0,0 +1,274 @@ +#!/usr/bin/env bash + +set -ex + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/aishell/ASR + +function download_test_dev_manifests() { + git lfs install + + fbank_url=https://huggingface.co/csukuangfj/aishell-test-dev-manifests + log "Downloading pre-computed fbank from $fbank_url" + + git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests + ln -s $PWD/aishell-test-dev-manifests/data .
+} + +function test_transducer_stateless3_2022_06_20() { + repo_url=https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20 + log "Downloading pre-trained model from $repo_url" + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + ln -s pretrained-epoch-29-avg-5-torch-1.10.0.pt pretrained.pt + popd + + log "test greedy_search with pretrained.py" + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless3/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav + done + + log "test beam search with pretrained.py" + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless3/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav + done + + echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" + echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" + if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p pruned_transducer_stateless3/exp + ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_char data/ + + ls -lh data + ls -lh pruned_transducer_stateless3/exp + + log "Decoding test and dev" + + # use a small value for decoding with CPU + max_duration=100 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./pruned_transducer_stateless3/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --max-duration $max_duration \ + --exp-dir pruned_transducer_stateless3/exp + done + + rm pruned_transducer_stateless3/exp/*.pt + fi + + rm -rf $repo +} + +function test_zipformer_large_2023_10_24() { + log "CI testing large model" + repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-large-2023-10-24/ + log "Downloading pre-trained model from $repo_url" + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + for method in modified_beam_search greedy_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --context-size 1 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_char/tokens.txt \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav + done + rm -rf $repo +} + +function test_zipformer_2023_10_24() { + repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-2023-10-24/ + log "Downloading pre-trained model from $repo_url" + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + + for method in modified_beam_search greedy_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method 
$method \ + --context-size 1 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_char/tokens.txt \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav + done + rm -rf $repo +} + +function test_zipformer_small_2023_10_24() { + log "CI testing small model" + repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-small-2023-10-24/ + log "Downloading pre-trained model from $repo_url" + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + + for method in modified_beam_search greedy_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --context-size 1 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_char/tokens.txt \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav + done + rm -rf $repo +} + +function test_transducer_stateless_modified_2022_03_01() { + repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./transducer_stateless_modified/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav + done + + for method in modified_beam_search beam_search; do + log "$method" + + ./transducer_stateless_modified/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav + done + rm -rf $repo +} + +function test_transducer_stateless_modified_2_2022_03_01() { + repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2-2022-03-01 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./transducer_stateless_modified-2/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav + done + + for method in modified_beam_search beam_search; do + log "$method" + + ./transducer_stateless_modified-2/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav + done + rm -rf $repo +} + +download_test_dev_manifests +test_transducer_stateless3_2022_06_20 
+test_zipformer_large_2023_10_24 +test_zipformer_2023_10_24 +test_zipformer_small_2023_10_24 +test_transducer_stateless_modified_2022_03_01 +test_transducer_stateless_modified_2_2022_03_01 + +ls -lh diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile index f75d74854..f6a088af1 100644 --- a/.github/scripts/docker/Dockerfile +++ b/.github/scripts/docker/Dockerfile @@ -16,6 +16,7 @@ RUN apt-get update -y && \ git-lfs \ graphviz \ less \ + tree \ vim \ && \ apt-get clean && \ diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py index 4e494d810..bdde97647 100755 --- a/.github/scripts/docker/generate_build_matrix.py +++ b/.github/scripts/docker/generate_build_matrix.py @@ -45,7 +45,7 @@ def get_torchaudio_version(torch_version): def get_matrix(): k2_version = "1.24.4.dev20231220" kaldifeat_version = "1.25.3.dev20231221" - version = "1.1" + version = "1.2" python_version = ["3.8", "3.9", "3.10", "3.11"] torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"] diff --git a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh deleted file mode 100755 index c3640cfde..000000000 --- a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/aishell/ASR - -git lfs install - -fbank_url=https://huggingface.co/csukuangfj/aishell-test-dev-manifests -log "Downloading pre-commputed fbank from $fbank_url" - -git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests -ln -s $PWD/aishell-test-dev-manifests/data . 
- -repo_url=https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20 -log "Downloading pre-trained model from $repo_url" -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -ln -s pretrained-epoch-29-avg-5-torch-1.10.0.pt pretrained.pt -popd - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless3/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --lang-dir $repo/data/lang_char \ - $repo/test_wavs/BAC009S0764W0121.wav \ - $repo/test_wavs/BAC009S0764W0122.wav \ - $repo/test_wavs/BAC009S0764W0123.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless3/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --lang-dir $repo/data/lang_char \ - $repo/test_wavs/BAC009S0764W0121.wav \ - $repo/test_wavs/BAC009S0764W0122.wav \ - $repo/test_wavs/BAC009S0764W0123.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless3/exp - ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_char data/ - - ls -lh data - ls -lh pruned_transducer_stateless3/exp - - log "Decoding test and dev" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless3/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless3/exp - done - - rm pruned_transducer_stateless3/exp/*.pt -fi diff --git a/.github/scripts/run-aishell-zipformer-2023-10-24.sh b/.github/scripts/run-aishell-zipformer-2023-10-24.sh deleted file mode 100755 index 865e29799..000000000 --- a/.github/scripts/run-aishell-zipformer-2023-10-24.sh +++ /dev/null @@ -1,103 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/aishell/ASR - -git lfs install - -fbank_url=https://huggingface.co/csukuangfj/aishell-test-dev-manifests -log "Downloading pre-commputed fbank from $fbank_url" - -git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests -ln -s $PWD/aishell-test-dev-manifests/data . 
- -log "=======================" -log "CI testing large model" -repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-large-2023-10-24/ -log "Downloading pre-trained model from $repo_url" -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -for method in modified_beam_search greedy_search fast_beam_search; do - log "$method" - - ./zipformer/pretrained.py \ - --method $method \ - --context-size 1 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_char/tokens.txt \ - --num-encoder-layers 2,2,4,5,4,2 \ - --feedforward-dim 512,768,1536,2048,1536,768 \ - --encoder-dim 192,256,512,768,512,256 \ - --encoder-unmasked-dim 192,192,256,320,256,192 \ - $repo/test_wavs/BAC009S0764W0121.wav \ - $repo/test_wavs/BAC009S0764W0122.wav \ - $repo/test_wavs/BAC009S0764W0123.wav -done - -log "=======================" -log "CI testing medium model" -repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-2023-10-24/ -log "Downloading pre-trained model from $repo_url" -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - - -for method in modified_beam_search greedy_search fast_beam_search; do - log "$method" - - ./zipformer/pretrained.py \ - --method $method \ - --context-size 1 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_char/tokens.txt \ - $repo/test_wavs/BAC009S0764W0121.wav \ - $repo/test_wavs/BAC009S0764W0122.wav \ - $repo/test_wavs/BAC009S0764W0123.wav -done - - -log "=======================" -log "CI testing small model" -repo_url=https://huggingface.co/zrjin/icefall-asr-aishell-zipformer-small-2023-10-24/ -log "Downloading pre-trained model from $repo_url" -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - - -for method in modified_beam_search greedy_search fast_beam_search; do - log "$method" - - ./zipformer/pretrained.py \ - --method $method \ - --context-size 1 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_char/tokens.txt \ - --num-encoder-layers 2,2,2,2,2,2 \ - --feedforward-dim 512,768,768,768,768,768 \ - --encoder-dim 192,256,256,256,256,256 \ - --encoder-unmasked-dim 192,192,192,192,192,192 \ - $repo/test_wavs/BAC009S0764W0121.wav \ - $repo/test_wavs/BAC009S0764W0122.wav \ - $repo/test_wavs/BAC009S0764W0123.wav -done - diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh deleted file mode 100755 index 0644d9be0..000000000 --- a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/aishell/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2-2022-03-01 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./transducer_stateless_modified-2/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --lang-dir 
$repo/data/lang_char \ - $repo/test_wavs/BAC009S0764W0121.wav \ - $repo/test_wavs/BAC009S0764W0122.wav \ - $repo/test_wavs/BAC009S0764W0123.wav -done - -for method in modified_beam_search beam_search; do - log "$method" - - ./transducer_stateless_modified-2/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --lang-dir $repo/data/lang_char \ - $repo/test_wavs/BAC009S0764W0121.wav \ - $repo/test_wavs/BAC009S0764W0122.wav \ - $repo/test_wavs/BAC009S0764W0123.wav -done diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh deleted file mode 100755 index 79fb64311..000000000 --- a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/aishell/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./transducer_stateless_modified/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --lang-dir $repo/data/lang_char \ - $repo/test_wavs/BAC009S0764W0121.wav \ - $repo/test_wavs/BAC009S0764W0122.wav \ - $repo/test_wavs/BAC009S0764W0123.wav -done - -for method in modified_beam_search beam_search; do - log "$method" - - ./transducer_stateless_modified/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --lang-dir $repo/data/lang_char \ - $repo/test_wavs/BAC009S0764W0121.wav \ - $repo/test_wavs/BAC009S0764W0122.wav \ - $repo/test_wavs/BAC009S0764W0123.wav -done diff --git a/.github/workflows/aishell.yml b/.github/workflows/aishell.yml new file mode 100644 index 000000000..136e117bd --- /dev/null +++ b/.github/workflows/aishell.yml @@ -0,0 +1,81 @@ +name: aishell + +on: + push: + branches: + - master + + pull_request: + branches: + - master + + workflow_dispatch: + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +concurrency: + group: aishell-${{ github.ref }} + cancel-in-progress: true + +jobs: + generate_build_matrix: + if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'aishell' || github.event_name == 'schedule') + + # see https://github.com/pytorch/pytorch/pull/50633 + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Generating build matrix + id: set-matrix + run: | + # outputting for debugging purposes + python ./.github/scripts/docker/generate_build_matrix.py + MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) + echo "::set-output name=matrix::${MATRIX}" + aishell: + needs: generate_build_matrix + name: py${{ matrix.python-version }} torch${{ 
matrix.torch-version }} v${{ matrix.version }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + ${{ fromJson(needs.generate_build_matrix.outputs.matrix) }} + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Free space + shell: bash + run: | + df -h + rm -rf /opt/hostedtoolcache + df -h + echo "pwd: $PWD" + echo "github.workspace ${{ github.workspace }}" + + - name: Run aishell tests + uses: addnab/docker-run-action@v3 + with: + image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }} + options: | + --volume ${{ github.workspace }}/:/icefall + shell: bash + run: | + export PYTHONPATH=/icefall:$PYTHONPATH + cd /icefall + git config --global --add safe.directory /icefall + + .github/scripts/aishell/ASR/run.sh diff --git a/.github/workflows/run-aishell-2022-06-20.yml b/.github/workflows/run-aishell-2022-06-20.yml deleted file mode 100644 index 53fcb2c03..000000000 --- a/.github/workflows/run-aishell-2022-06-20.yml +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: run-aishell-2022-06-20 -# pruned RNN-T + reworked model with random combiner -# https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20 - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_aishell_2022_06_20-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_aishell_2022_06_20: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - 
sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh - - - name: Display decoding results for aishell pruned_transducer_stateless3 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/aishell/ASR/ - tree ./pruned_transducer_stateless3/exp - - cd pruned_transducer_stateless3 - echo "results for pruned_transducer_stateless3" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2 - - - name: Upload decoding results for aishell pruned_transducer_stateless3 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: aishell-torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-06-20 - path: egs/aishell/ASR/pruned_transducer_stateless3/exp/ diff --git a/.github/workflows/run-aishell-zipformer-2023-10-24.yml b/.github/workflows/run-aishell-zipformer-2023-10-24.yml deleted file mode 100644 index f2fb44a5f..000000000 --- a/.github/workflows/run-aishell-zipformer-2023-10-24.yml +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2023 Zengrui Jin (Xiaomi Corp.) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-aishell-zipformer-2023-10-24 - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_aishell_zipformer_2023_10_24-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_aishell_zipformer_2023_10_24: - if: github.event.label.name == 'ready' || github.event.label.name == 'zipformer' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-aishell-zipformer-2023-10-24.sh - - \ No newline at end of file diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml deleted file mode 100644 index ce6d6f92d..000000000 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-pre-trained-trandsucer-stateless-modified-2-aishell - -on: - push: - branches: - - master - pull_request: - types: [labeled] - -concurrency: - group: run_pre_trained_transducer_stateless_modified_2_aishell-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_pre_trained_transducer_stateless_modified_2_aishell: - if: github.event.label.name == 'ready' || github.event_name == 'push' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Inference with pre-trained model - shell: bash - run: | - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - .github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml deleted file mode 100644 index f0cebd94a..000000000 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-pre-trained-trandsucer-stateless-modified-aishell - -on: - push: - branches: - - master - pull_request: - types: [labeled] - -concurrency: - group: run_pre_trained_transducer_stateless_modified_aishell-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_pre_trained_transducer_stateless_modified_aishell: - if: github.event.label.name == 'ready' || github.event_name == 'push' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Inference with pre-trained model - shell: bash - run: | - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - .github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/yesno.yml similarity index 69% rename from .github/workflows/run-yesno-recipe.yml rename to .github/workflows/yesno.yml index a99811815..182300dfa 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/yesno.yml @@ -1,26 +1,9 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-yesno-recipe +name: yesno on: push: branches: - master - - refactor-ci pull_request: branches: @@ -29,7 +12,7 @@ on: workflow_dispatch: concurrency: - group: run-yesno-recipe-${{ github.ref }} + group: yesno-${{ github.ref }} cancel-in-progress: true jobs: @@ -50,7 +33,7 @@ jobs: python ./.github/scripts/docker/generate_build_matrix.py MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) echo "::set-output name=matrix::${MATRIX}" - run-yesno-recipe: + yesno: needs: generate_build_matrix name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} runs-on: ubuntu-latest From 140e6381ad4699ce919705f59240fad58c0b1bb6 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 27 Dec 2023 13:21:14 +0800 Subject: [PATCH 23/46] Refactor CI tests for librispeech (#1436) --- .github/scripts/aishell/ASR/run.sh | 73 +- .github/scripts/librispeech/ASR/run.sh | 1624 ++++++++++++++++- ...n-librispeech-conformer-ctc3-2022-11-28.sh | 122 -- ...-pruned-transducer-stateless-2022-03-12.sh | 77 - ...pruned-transducer-stateless2-2022-04-29.sh | 86 - ...pruned-transducer-stateless3-2022-04-29.sh | 85 - ...pruned-transducer-stateless3-2022-05-13.sh | 123 -- ...pruned-transducer-stateless5-2022-05-13.sh | 100 - ...pruned-transducer-stateless7-2022-11-11.sh | 106 -- ...ed-transducer-stateless7-ctc-2022-12-01.sh | 150 -- ...transducer-stateless7-ctc-bs-2023-01-29.sh | 147 -- ...nsducer-stateless7-streaming-2022-12-29.sh | 148 -- ...pruned-transducer-stateless8-2022-11-14.sh | 115 -- ...pruned-transducer-stateless2-2022-06-26.sh | 101 - ...rispeech-streaming-zipformer-2023-05-18.sh | 116 -- ...speech-transducer-stateless2-2022-04-19.sh | 77 - .../run-librispeech-zipformer-2023-05-18.sh | 94 - ...un-librispeech-zipformer-ctc-2023-06-14.sh | 117 -- ...un-librispeech-zipformer-mmi-2022-12-08.sh | 102 -- .github/scripts/run-pre-trained-ctc.sh | 240 --- ...d-transducer-stateless-librispeech-100h.sh | 77 - ...d-transducer-stateless-librispeech-960h.sh | 77 - .../run-pre-trained-transducer-stateless.sh | 77 - .github/scripts/run-pre-trained-transducer.sh | 33 - .github/workflows/aishell.yml | 11 +- ...{train-librispeech.yml => librispeech.yml} | 6 +- .../workflows/run-librispeech-2022-03-12.yml | 159 -- .../workflows/run-librispeech-2022-04-29.yml | 185 -- .../workflows/run-librispeech-2022-05-13.yml | 159 -- .../run-librispeech-2022-11-11-stateless7.yml | 159 -- .../run-librispeech-2022-11-14-stateless8.yml | 159 -- ...-librispeech-2022-12-01-stateless7-ctc.yml | 163 -- ...n-librispeech-2022-12-08-zipformer-mmi.yml | 167 -- ...speech-2022-12-29-stateless7-streaming.yml | 172 -- ...brispeech-2023-01-29-stateless7-ctc-bs.yml | 163 -- ...-librispeech-conformer-ctc3-2022-11-28.yml | 155 -- ...runed-transducer-stateless3-2022-05-13.yml | 157 -- ...aming-transducer-stateless2-2022-06-26.yml | 159 -- ...ispeech-streaming-zipformer-2023-05-18.yml | 174 -- ...peech-transducer-stateless2-2022-04-19.yml | 159 -- .../run-librispeech-zipformer-2023-05-18.yml | 159 -- ...n-librispeech-zipformer-ctc-2023-06-14.yml | 155 -- .github/workflows/run-pretrained-ctc.yml | 87 - ...-transducer-stateless-librispeech-100h.yml | 158 -- ...r-stateless-librispeech-multi-datasets.yml | 158 -- .../run-pretrained-transducer-stateless.yml | 158 -- .../workflows/run-pretrained-transducer.yml | 80 - 47 files changed, 1658 insertions(+), 5671 deletions(-) delete mode 100755 .github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh delete mode 100755 
.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh delete mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh delete mode 100755 .github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh delete mode 100755 .github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh delete mode 100755 .github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh delete mode 100755 .github/scripts/run-librispeech-zipformer-2023-05-18.sh delete mode 100755 .github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh delete mode 100755 .github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh delete mode 100755 .github/scripts/run-pre-trained-ctc.sh delete mode 100755 .github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh delete mode 100755 .github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh delete mode 100755 .github/scripts/run-pre-trained-transducer-stateless.sh delete mode 100755 .github/scripts/run-pre-trained-transducer.sh rename .github/workflows/{train-librispeech.yml => librispeech.yml} (95%) delete mode 100644 .github/workflows/run-librispeech-2022-03-12.yml delete mode 100644 .github/workflows/run-librispeech-2022-04-29.yml delete mode 100644 .github/workflows/run-librispeech-2022-05-13.yml delete mode 100644 .github/workflows/run-librispeech-2022-11-11-stateless7.yml delete mode 100644 .github/workflows/run-librispeech-2022-11-14-stateless8.yml delete mode 100644 .github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml delete mode 100644 .github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml delete mode 100644 .github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml delete mode 100644 .github/workflows/run-librispeech-2023-01-29-stateless7-ctc-bs.yml delete mode 100644 .github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml delete mode 100644 .github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml delete mode 100644 .github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml delete mode 100644 .github/workflows/run-librispeech-streaming-zipformer-2023-05-18.yml delete mode 100644 .github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml delete mode 100644 .github/workflows/run-librispeech-zipformer-2023-05-18.yml delete mode 100644 .github/workflows/run-librispeech-zipformer-ctc-2023-06-14.yml delete mode 100644 .github/workflows/run-pretrained-ctc.yml delete mode 100644 .github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml delete mode 100644 .github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml delete mode 100644 
.github/workflows/run-pretrained-transducer-stateless.yml delete mode 100644 .github/workflows/run-pretrained-transducer.yml diff --git a/.github/scripts/aishell/ASR/run.sh b/.github/scripts/aishell/ASR/run.sh index 4d912fa76..f150b6337 100755 --- a/.github/scripts/aishell/ASR/run.sh +++ b/.github/scripts/aishell/ASR/run.sh @@ -263,6 +263,76 @@ function test_transducer_stateless_modified_2_2022_03_01() { rm -rf $repo } +function test_conformer_ctc() { + repo_url=https://huggingface.co/csukuangfj/icefall_asr_aishell_conformer_ctc + log "Downloading pre-trained model from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + pushd $repo + + git lfs pull --include "exp/pretrained.pt" + git lfs pull --include "data/lang_char/H.fst" + git lfs pull --include "data/lang_char/HL.fst" + git lfs pull --include "data/lang_char/HLG.fst" + + popd + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + log "CTC decoding" + + log "Exporting model with torchscript" + + pushd $repo/exp + ln -s pretrained.pt epoch-99.pt + popd + + ./conformer_ctc/export.py \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_char/tokens.txt \ + --jit 1 + + ls -lh $repo/exp + + ls -lh $repo/data/lang_char + + log "Decoding with H on CPU with OpenFst" + + ./conformer_ctc/jit_pretrained_decode_with_H.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --H $repo/data/lang_char/H.fst \ + --tokens $repo/data/lang_char/tokens.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + + log "Decoding with HL on CPU with OpenFst" + + ./conformer_ctc/jit_pretrained_decode_with_HL.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --HL $repo/data/lang_char/HL.fst \ + --words $repo/data/lang_char/words.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + + log "Decoding with HLG on CPU with OpenFst" + + ./conformer_ctc/jit_pretrained_decode_with_HLG.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --HLG $repo/data/lang_char/HLG.fst \ + --words $repo/data/lang_char/words.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + + rm -rf $repo +} + download_test_dev_manifests test_transducer_stateless3_2022_06_20 test_zipformer_large_2023_10_24 @@ -270,5 +340,4 @@ test_zipformer_2023_10_24 test_zipformer_small_2023_10_24 test_transducer_stateless_modified_2022_03_01 test_transducer_stateless_modified_2_2022_03_01 - -ls -lh +# test_conformer_ctc # fails for torch 1.13.x and torch 2.0.x diff --git a/.github/scripts/librispeech/ASR/run.sh b/.github/scripts/librispeech/ASR/run.sh index 641d59458..7e9bd8a47 100755 --- a/.github/scripts/librispeech/ASR/run.sh +++ b/.github/scripts/librispeech/ASR/run.sh @@ -10,52 +10,1594 @@ log() { cd egs/librispeech/ASR -# We don't download the LM file since it is so large that it will -# cause OOM error for CI later. -mkdir -p download/lm -pushd download/lm -wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt -wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt -wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz -ls -lh -gunzip librispeech-lm-norm.txt.gz +function prepare_data() { + # We don't download the LM file since it is so large that it will + # cause OOM error for CI later. 
+ mkdir -p download/lm + pushd download/lm + wget -q http://www.openslr.org/resources/11/librispeech-vocab.txt + wget -q http://www.openslr.org/resources/11/librispeech-lexicon.txt + wget -q http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz + ls -lh + gunzip librispeech-lm-norm.txt.gz -ls -lh -popd + ls -lh + popd -pushd download/ -wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/LibriSpeech.tar.bz2 -tar xf LibriSpeech.tar.bz2 -rm LibriSpeech.tar.bz2 + pushd download/ + wget -q https://huggingface.co/csukuangfj/librispeech-for-ci/resolve/main/LibriSpeech.tar.bz2 + tar xf LibriSpeech.tar.bz2 + rm LibriSpeech.tar.bz2 -cd LibriSpeech -ln -s train-clean-100 train-clean-360 -ln -s train-other-500 train-other-500 -popd + cd LibriSpeech + ln -s train-clean-100 train-clean-360 + ln -s train-other-500 train-other-500 + popd -mkdir -p data/manifests + mkdir -p data/manifests -lhotse prepare librispeech -j 2 -p dev-clean -p dev-other -p test-clean -p test-other -p train-clean-100 download/LibriSpeech data/manifests -ls -lh data/manifests + lhotse prepare librispeech -j 2 -p dev-clean -p dev-other -p test-clean -p test-other -p train-clean-100 download/LibriSpeech data/manifests + ls -lh data/manifests -./local/compute_fbank_librispeech.py --dataset "dev-clean dev-other test-clean test-other train-clean-100" --perturb-speed False -ls -lh data/fbank + ./local/compute_fbank_librispeech.py --dataset "dev-clean dev-other test-clean test-other train-clean-100" --perturb-speed False + ls -lh data/fbank -./prepare.sh --stage 5 --stop-stage 6 + ./prepare.sh --stage 5 --stop-stage 6 +} -./zipformer/train.py \ - --world-size 1 \ - --num-epochs 1 \ - --start-epoch 1 \ - --use-fp16 0 \ - --exp-dir zipformer/exp-small \ - --causal 0 \ - --num-encoder-layers 1,1,1,1,1,1 \ - --feedforward-dim 64,96,96,96,96,96 \ - --encoder-dim 32,64,64,64,64,64 \ - --encoder-unmasked-dim 32,32,32,32,32,32 \ - --base-lr 0.04 \ - --full-libri 0 \ - --enable-musan 0 \ - --max-duration 30 \ - --print-diagnostics 1 +function run_diagnostics() { + ./zipformer/train.py \ + --world-size 1 \ + --num-epochs 1 \ + --start-epoch 1 \ + --use-fp16 0 \ + --exp-dir zipformer/exp-small \ + --causal 0 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 64,96,96,96,96,96 \ + --encoder-dim 32,64,64,64,64,64 \ + --encoder-unmasked-dim 32,32,32,32,32,32 \ + --base-lr 0.04 \ + --full-libri 0 \ + --enable-musan 0 \ + --max-duration 30 \ + --print-diagnostics 1 +} + +function test_pruned_transducer_stateless_2022_03_12() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in fast_beam_search modified_beam_search beam_search; do + log "$method" + + ./pruned_transducer_stateless/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + 
$repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_pruned_transducer_stateless2_2022_04_29() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29 + + log "Downloading pre-trained model from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + pushd $repo + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "exp/pretrained-epoch-38-avg-10.pt" + popd + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + ln -s pretrained-epoch-38-avg-10.pt pretrained.pt + popd + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless2/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless2/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_pruned_transducer_stateless3_2022_04_29() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-04-29 + + log "Downloading pre-trained model from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + pushd $repo + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "exp/pretrained-epoch-25-avg-6.pt" + popd + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + ln -s pretrained-epoch-25-avg-6.pt pretrained.pt + popd + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless3/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless3/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_pruned_transducer_stateless5_2022_05_13() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + ln -s pretrained-epoch-39-avg-7.pt pretrained.pt + popd + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless5/pretrained.py \ + --method greedy_search \ + 
--max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --num-encoder-layers 18 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless5/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav \ + --num-encoder-layers 18 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 + done + rm -rf $repo +} + +function test_pruned_transducer_stateless7_2022_11_11() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + + log "Downloading pre-trained model from $repo_url" + git lfs install + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "exp/cpu_jit.pt" + git lfs pull --include "exp/pretrained.pt" + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Export to torchscript model" + ./pruned_transducer_stateless7/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.script()" + + ./pruned_transducer_stateless7/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless7/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless7/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_pruned_transducer_stateless8_2022_11_14() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14 + + log "Downloading pre-trained model from $repo_url" + git lfs install + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "exp/cpu_jit.pt" + git lfs pull --include "exp/pretrained.pt" + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Decode with 
models exported by torch.jit.script()" + + ./pruned_transducer_stateless8/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + log "Export to torchscript model" + ./pruned_transducer_stateless8/export.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model false \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.script()" + + ./pruned_transducer_stateless8/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless8/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless8/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_pruned_transducer_stateless7_ctc_2022_12_01() { + repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01 + + log "Downloading pre-trained model from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + git lfs pull --include "data/lang_bpe_500/HLG.pt" + git lfs pull --include "data/lang_bpe_500/L.pt" + git lfs pull --include "data/lang_bpe_500/LG.pt" + git lfs pull --include "data/lang_bpe_500/Linv.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "data/lm/G_4_gram.pt" + git lfs pull --include "exp/cpu_jit.pt" + git lfs pull --include "exp/pretrained.pt" + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Export to torchscript model" + ./pruned_transducer_stateless7_ctc/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.script()" + + ./pruned_transducer_stateless7_ctc/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --model-filename $repo/exp/cpu_jit.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + 
$repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless7_ctc/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless7_ctc/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ + --checkpoint $repo/exp/pretrained.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_zipformer_mmi_2022_12_08() { + repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08 + + log "Downloading pre-trained model from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + git lfs pull --include "data/lang_bpe_500/3gram.pt" + git lfs pull --include "data/lang_bpe_500/4gram.pt" + git lfs pull --include "data/lang_bpe_500/L.pt" + git lfs pull --include "data/lang_bpe_500/LG.pt" + git lfs pull --include "data/lang_bpe_500/Linv.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "exp/cpu_jit.pt" + git lfs pull --include "exp/pretrained.pt" + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Export to torchscript model" + ./zipformer_mmi/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.script()" + + ./zipformer_mmi/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + --lang-dir $repo/data/lang_bpe_500 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do + log "$method" + + ./zipformer_mmi/pretrained.py \ + --method $method \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_bpe_500 \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_pruned_transducer_stateless7_streaming_2022_12_29() { + repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 + + log "Downloading pre-trained model from $repo_url" + git lfs install + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + 
repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "exp/cpu_jit.pt" + git lfs pull --include "exp/pretrained.pt" + git lfs pull --include "exp/encoder_jit_trace.pt" + git lfs pull --include "exp/decoder_jit_trace.pt" + git lfs pull --include "exp/joiner_jit_trace.pt" + cd exp + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Export to torchscript model" + ./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --decode-chunk-len 32 \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.script()" + + ./pruned_transducer_stateless7_streaming/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + log "Export to torchscript model by torch.jit.trace()" + ./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --decode-chunk-len 32 \ + --epoch 99 \ + --avg 1 + + log "Decode with models exported by torch.jit.trace()" + + ./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ + --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ + --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless7_streaming/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless7_streaming/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --decode-chunk-len 32 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + rm -rf $repo +} + +function test_pruned_transducer_stateless7_ctc_bs_2023_01_29() { + repo_url=https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 + + log "Downloading pre-trained model from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + git lfs pull --include "data/lang_bpe_500/L.pt" + git lfs pull --include "data/lang_bpe_500/LG.pt" + git lfs pull --include "data/lang_bpe_500/HLG.pt" + git lfs pull --include "data/lang_bpe_500/Linv.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "exp/cpu_jit.pt" + git lfs pull --include "exp/pretrained.pt" + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Export to 
torchscript model" + ./pruned_transducer_stateless7_ctc_bs/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.script()" + + ./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ + --model-filename $repo/exp/cpu_jit.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ + --checkpoint $repo/exp/pretrained.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_conformer_ctc3_2022_11_27() { + repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conformer-ctc3-2022-11-27 + + log "Downloading pre-trained model from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + git lfs pull --include "data/lang_bpe_500/HLG.pt" + git lfs pull --include "data/lang_bpe_500/L.pt" + git lfs pull --include "data/lang_bpe_500/LG.pt" + git lfs pull --include "data/lang_bpe_500/Linv.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "data/lm/G_4_gram.pt" + git lfs pull --include "exp/jit_trace.pt" + git lfs pull --include "exp/pretrained.pt" + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Decode with models exported by torch.jit.trace()" + + for m in ctc-decoding 1best; do + ./conformer_ctc3/jit_pretrained.py \ + --model-filename $repo/exp/jit_trace.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + 
$repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + log "Export to torchscript model" + + ./conformer_ctc3/export.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --jit-trace 1 \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.trace()" + + for m in ctc-decoding 1best; do + ./conformer_ctc3/jit_pretrained.py \ + --model-filename $repo/exp/jit_trace.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for m in ctc-decoding 1best; do + ./conformer_ctc3/pretrained.py \ + --checkpoint $repo/exp/pretrained.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_lstm_transducer_stateless2_2022_09_03() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + abs_repo=$(realpath $repo) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + ln -s pretrained-iter-468000-avg-16.pt pretrained.pt + ln -s pretrained-iter-468000-avg-16.pt epoch-99.pt + popd + + log "Test exporting with torch.jit.trace()" + + ./lstm_transducer_stateless2/export.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --jit-trace 1 + + log "Decode with models exported by torch.jit.trace()" + + ./lstm_transducer_stateless2/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ + --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ + --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./lstm_transducer_stateless2/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./lstm_transducer_stateless2/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_pruned_transducer_stateless3_2022_05_13() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 + + log "Downloading pre-trained 
model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt + ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt + popd + + + log "Export to torchscript model" + ./pruned_transducer_stateless3/export.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ./pruned_transducer_stateless3/export.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit-trace 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.trace()" + + ./pruned_transducer_stateless3/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ + --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ + --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + log "Decode with models exported by torch.jit.script()" + + ./pruned_transducer_stateless3/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo/exp/encoder_jit_script.pt \ + --decoder-model-filename $repo/exp/decoder_jit_script.pt \ + --joiner-model-filename $repo/exp/joiner_jit_script.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless3/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless3/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + rm -rf $repo +} + +function test_streaming_pruned_transducer_stateless2_20220625() { + repo_url=https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless2_20220625 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + ln -s pretrained-epoch-24-avg-10.pt pretrained.pt + popd + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless2/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless2/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint 
$repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_streaming_zipformer_2023_05_17() { + repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 + + log "Downloading pre-trained model from $repo_url" + git lfs install + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "data/lang_bpe_500/tokens.txt" + git lfs pull --include "exp/jit_script_chunk_16_left_128.pt" + git lfs pull --include "exp/pretrained.pt" + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Export to torchscript model" + ./zipformer/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.script()" + + ./zipformer/jit_pretrained_streaming.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --nn-model-filename $repo/exp/jit_script_chunk_16_left_128.pt \ + $repo/test_wavs/1089-134686-0001.wav + + for method in greedy_search modified_beam_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_zipformer_2023_05_18() { + repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 + + log "Downloading pre-trained model from $repo_url" + git lfs install + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "data/lang_bpe_500/tokens.txt" + git lfs pull --include "exp/jit_script.pt" + git lfs pull --include "exp/pretrained.pt" + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Export to torchscript model" + ./zipformer/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.script()" + + ./zipformer/jit_pretrained.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --nn-model-filename $repo/exp/jit_script.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + for method in greedy_search modified_beam_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function 
test_transducer_stateless2_torchaudio_2022_04_19() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless2-torchaudio-2022-04-19 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./transducer_stateless2/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in fast_beam_search modified_beam_search beam_search; do + log "$method" + + ./transducer_stateless2/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_zipformer_transducer_ctc_2023_06_13() { + repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 + + log "Downloading pre-trained model from $repo_url" + git lfs install + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + pushd $repo/exp + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "data/lang_bpe_500/tokens.txt" + git lfs pull --include "data/lang_bpe_500/HLG.pt" + git lfs pull --include "data/lang_bpe_500/L.pt" + git lfs pull --include "data/lang_bpe_500/LG.pt" + git lfs pull --include "data/lang_bpe_500/Linv.pt" + git lfs pull --include "data/lm/G_4_gram.pt" + git lfs pull --include "exp/jit_script.pt" + git lfs pull --include "exp/pretrained.pt" + ln -s pretrained.pt epoch-99.pt + ls -lh *.pt + popd + + log "Export to torchscript model" + ./zipformer/export.py \ + --exp-dir $repo/exp \ + --use-transducer 1 \ + --use-ctc 1 \ + --use-averaged-model false \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + + ls -lh $repo/exp/*.pt + + log "Decode with models exported by torch.jit.script()" + + for method in ctc-decoding 1best; do + ./zipformer/jit_pretrained_ctc.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --model-filename $repo/exp/jit_script.pt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --G $repo/data/lm/G_4_gram.pt \ + --method $method \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in ctc-decoding 1best; do + log "$method" + + ./zipformer/pretrained_ctc.py \ + --use-transducer 1 \ + --use-ctc 1 \ + --method $method \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --G $repo/data/lm/G_4_gram.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_100h_transducer_stateless_multi_datasets_bpe_500_2022_02_21() { + 
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./transducer_stateless_multi_datasets/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./transducer_stateless_multi_datasets/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_transducer_stateless_multi_datasets_bpe_500_2022_03_01() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./transducer_stateless_multi_datasets/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./transducer_stateless_multi_datasets/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + +function test_transducer_stateless_bpe_500_2022_02_07() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07 + + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./transducer_stateless/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + + for method in fast_beam_search modified_beam_search beam_search; do + log "$method" + + ./transducer_stateless/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + done + rm -rf $repo +} + 
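+# Descriptive note (comment only): the next test exercises the ONNX zipformer
+# CTC model (sherpa-onnx-zipformer-ctc-en-2023-10-02) with CTC greedy search
+# and with FST-based decoding using the H.fst, HL.fst and HLG.fst graphs
+# shipped in the model repository.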
+function test_zipformer_ctc_en_2023_10_02() { + repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 + log "Downloading pre-trained model from $repo_url" + git lfs install + git clone $repo_url + repo=$(basename $repo_url) + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + log "CTC greedy search" + + ./zipformer/onnx_pretrained_ctc.py \ + --nn-model $repo/model.onnx \ + --tokens $repo/tokens.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + + log "CTC H decoding" + + ./zipformer/onnx_pretrained_ctc_H.py \ + --nn-model $repo/model.onnx \ + --tokens $repo/tokens.txt \ + --H $repo/H.fst \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + + log "CTC HL decoding" + + ./zipformer/onnx_pretrained_ctc_HL.py \ + --nn-model $repo/model.onnx \ + --words $repo/words.txt \ + --HL $repo/HL.fst \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + + log "CTC HLG decoding" + + ./zipformer/onnx_pretrained_ctc_HLG.py \ + --nn-model $repo/model.onnx \ + --words $repo/words.txt \ + --HLG $repo/HLG.fst \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + + rm -rf $repo +} + +function test_conformer_ctc_jit_bpe_500_2021_11_09() { + repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09 + log "Downloading pre-trained model from $repo_url" + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + repo=$(basename $repo_url) + pushd $repo + + git lfs pull --include "exp/pretrained.pt" + git lfs pull --include "data/lang_bpe_500/HLG.pt" + git lfs pull --include "data/lang_bpe_500/L.pt" + git lfs pull --include "data/lang_bpe_500/L_disambig.pt" + git lfs pull --include "data/lang_bpe_500/Linv.pt" + git lfs pull --include "data/lang_bpe_500/bpe.model" + git lfs pull --include "data/lang_bpe_500/lexicon.txt" + git lfs pull --include "data/lang_bpe_500/lexicon_disambig.txt" + git lfs pull --include "data/lang_bpe_500/tokens.txt" + git lfs pull --include "data/lang_bpe_500/words.txt" + git lfs pull --include "data/lm/G_3_gram.fst.txt" + + popd + + log "Display test files" + tree $repo/ + ls -lh $repo/test_wavs/*.wav + + log "CTC decoding" + + ./conformer_ctc/pretrained.py \ + --method ctc-decoding \ + --num-classes 500 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + log "HLG decoding" + + ./conformer_ctc/pretrained.py \ + --method 1best \ + --num-classes 500 \ + --checkpoint $repo/exp/pretrained.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + log "CTC decoding on CPU with kaldi decoders using OpenFst" + + log "Exporting model with torchscript" + + pushd $repo/exp + ln -s pretrained.pt epoch-99.pt + popd + + ./conformer_ctc/export.py \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --jit 1 + + ls -lh $repo/exp + + + log "Generating H.fst, HL.fst" + + ./local/prepare_lang_fst.py --lang-dir $repo/data/lang_bpe_500 --ngram-G $repo/data/lm/G_3_gram.fst.txt + + ls -lh $repo/data/lang_bpe_500 + + log "Decoding with H on CPU with OpenFst" + + 
./conformer_ctc/jit_pretrained_decode_with_H.py \
+    --nn-model $repo/exp/cpu_jit.pt \
+    --H $repo/data/lang_bpe_500/H.fst \
+    --tokens $repo/data/lang_bpe_500/tokens.txt \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+
+  log "Decoding with HL on CPU with OpenFst"
+
+  ./conformer_ctc/jit_pretrained_decode_with_HL.py \
+    --nn-model $repo/exp/cpu_jit.pt \
+    --HL $repo/data/lang_bpe_500/HL.fst \
+    --words $repo/data/lang_bpe_500/words.txt \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+
+  log "Decoding with HLG on CPU with OpenFst"
+
+  ./conformer_ctc/jit_pretrained_decode_with_HLG.py \
+    --nn-model $repo/exp/cpu_jit.pt \
+    --HLG $repo/data/lang_bpe_500/HLG.fst \
+    --words $repo/data/lang_bpe_500/words.txt \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+
+  rm -rf $repo
+}
+
+function test_transducer_bpe_500_2021_12_23() {
+  repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23
+
+  log "Downloading pre-trained model from $repo_url"
+  git lfs install
+  git clone $repo_url
+  repo=$(basename $repo_url)
+
+  log "Display test files"
+  tree $repo/
+  ls -lh $repo/test_wavs/*.wav
+
+  log "Beam search decoding"
+
+  ./transducer/pretrained.py \
+    --method beam_search \
+    --beam-size 4 \
+    --checkpoint $repo/exp/pretrained.pt \
+    --tokens $repo/data/lang_bpe_500/tokens.txt \
+    $repo/test_wavs/1089-134686-0001.wav \
+    $repo/test_wavs/1221-135766-0001.wav \
+    $repo/test_wavs/1221-135766-0002.wav
+
+  rm -rf $repo
+}
+
+prepare_data
+run_diagnostics
+test_pruned_transducer_stateless_2022_03_12
+test_pruned_transducer_stateless2_2022_04_29
+test_pruned_transducer_stateless3_2022_04_29
+test_pruned_transducer_stateless5_2022_05_13
+test_pruned_transducer_stateless7_2022_11_11
+test_pruned_transducer_stateless8_2022_11_14
+test_pruned_transducer_stateless7_ctc_2022_12_01
+test_zipformer_mmi_2022_12_08
+test_pruned_transducer_stateless7_streaming_2022_12_29
+test_pruned_transducer_stateless7_ctc_bs_2023_01_29
+test_conformer_ctc3_2022_11_27
+test_lstm_transducer_stateless2_2022_09_03
+test_pruned_transducer_stateless3_2022_05_13
+test_streaming_pruned_transducer_stateless2_20220625
+test_streaming_zipformer_2023_05_17
+test_zipformer_2023_05_18
+test_transducer_stateless2_torchaudio_2022_04_19
+test_zipformer_transducer_ctc_2023_06_13
+test_100h_transducer_stateless_multi_datasets_bpe_500_2022_02_21
+test_transducer_stateless_multi_datasets_bpe_500_2022_03_01
+test_transducer_stateless_bpe_500_2022_02_07
+test_zipformer_ctc_en_2023_10_02
+# test_conformer_ctc_jit_bpe_500_2021_11_09 # fails for torch != 1.13.x and torch != 2.0.x
+test_transducer_bpe_500_2021_12_23
diff --git a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
deleted file mode 100755
index f6fe8c9b2..000000000
--- a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
+++ /dev/null
@@ -1,122 +0,0 @@
-#!/usr/bin/env bash
-
-set -e
-
-log() {
-  # This function is from espnet
-  local fname=${BASH_SOURCE[1]##*/}
-  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conformer-ctc3-2022-11-27
-
-log "Downloading pre-trained model from $repo_url"
-GIT_LFS_SKIP_SMUDGE=1 git clone 
$repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/HLG.pt" -git lfs pull --include "data/lang_bpe_500/L.pt" -git lfs pull --include "data/lang_bpe_500/LG.pt" -git lfs pull --include "data/lang_bpe_500/Linv.pt" -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "data/lm/G_4_gram.pt" -git lfs pull --include "exp/jit_trace.pt" -git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt -popd - -log "Decode with models exported by torch.jit.trace()" - -for m in ctc-decoding 1best; do - ./conformer_ctc3/jit_pretrained.py \ - --model-filename $repo/exp/jit_trace.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --G $repo/data/lm/G_4_gram.pt \ - --method $m \ - --sample-rate 16000 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -log "Export to torchscript model" - -./conformer_ctc3/export.py \ - --exp-dir $repo/exp \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --jit-trace 1 \ - --epoch 99 \ - --avg 1 \ - --use-averaged-model 0 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.trace()" - -for m in ctc-decoding 1best; do - ./conformer_ctc3/jit_pretrained.py \ - --model-filename $repo/exp/jit_trace.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --G $repo/data/lm/G_4_gram.pt \ - --method $m \ - --sample-rate 16000 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for m in ctc-decoding 1best; do - ./conformer_ctc3/pretrained.py \ - --checkpoint $repo/exp/pretrained.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --G $repo/data/lm/G_4_gram.pt \ - --method $m \ - --sample-rate 16000 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p conformer_ctc3/exp - ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh conformer_ctc3/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in ctc-decoding 1best; do - log "Decoding with $method" - ./conformer_ctc3/decode.py \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --exp-dir conformer_ctc3/exp/ \ - --max-duration $max_duration \ - --decoding-method $method \ - --lm-dir data/lm - done - - rm conformer_ctc3/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh deleted file mode 100755 index 412e3ad56..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local 
fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in fast_beam_search modified_beam_search beam_search; do - log "$method" - - ./pruned_transducer_stateless/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless/exp - ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless/exp - done - - rm pruned_transducer_stateless/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh deleted file mode 100755 index 243b669ed..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29 - -log "Downloading pre-trained model from $repo_url" -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/pretrained-epoch-38-avg-10.pt" -popd - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -ln -s pretrained-epoch-38-avg-10.pt pretrained.pt -popd - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless2/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - 
$repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless2/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless2/exp - ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless2/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless2/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless2/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless2/exp - done - - rm pruned_transducer_stateless2/exp/*.pt - rm -r data/lang_bpe_500 -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh deleted file mode 100755 index 2d0f80304..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh +++ /dev/null @@ -1,85 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-04-29 - -log "Downloading pre-trained model from $repo_url" -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) -pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/pretrained-epoch-25-avg-6.pt" -popd - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -ln -s pretrained-epoch-25-avg-6.pt pretrained.pt -popd - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless3/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless3/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless3/exp - ln -s $PWD/$repo/exp/pretrained.pt 
pruned_transducer_stateless3/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless3/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless3/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless3/exp - done - - rm pruned_transducer_stateless3/exp/*.pt - rm -r data/lang_bpe_500 -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh deleted file mode 100755 index 3d5814c48..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh +++ /dev/null @@ -1,123 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt -ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt -popd - - -log "Export to torchscript model" -./pruned_transducer_stateless3/export.py \ - --exp-dir $repo/exp \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -./pruned_transducer_stateless3/export.py \ - --exp-dir $repo/exp \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 99 \ - --avg 1 \ - --jit-trace 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.trace()" - -./pruned_transducer_stateless3/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ - --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ - --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -log "Decode with models exported by torch.jit.script()" - -./pruned_transducer_stateless3/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --encoder-model-filename $repo/exp/encoder_jit_script.pt \ - --decoder-model-filename $repo/exp/decoder_jit_script.pt \ - --joiner-model-filename $repo/exp/joiner_jit_script.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless3/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless3/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens 
$repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless3/exp - ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless3/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless3/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless3/exp - done - - rm pruned_transducer_stateless3/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh deleted file mode 100755 index 3d2442d54..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -ln -s pretrained-epoch-39-avg-7.pt pretrained.pt -popd - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless5/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --num-encoder-layers 18 \ - --dim-feedforward 2048 \ - --nhead 8 \ - --encoder-dim 512 \ - --decoder-dim 512 \ - --joiner-dim 512 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless5/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav \ - --num-encoder-layers 18 \ - --dim-feedforward 2048 \ - --nhead 8 \ - --encoder-dim 512 \ - --decoder-dim 512 \ - --joiner-dim 512 -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless5/exp - ln -s $PWD/$repo/exp/pretrained-epoch-39-avg-7.pt pruned_transducer_stateless5/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless5/exp - - log "Decoding test-clean and 
test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless5/decode.py \ - --decoding-method $method \ - --use-averaged-model 0 \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless5/exp \ - --num-encoder-layers 18 \ - --dim-feedforward 2048 \ - --nhead 8 \ - --encoder-dim 512 \ - --decoder-dim 512 \ - --joiner-dim 512 - done - - rm pruned_transducer_stateless5/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh deleted file mode 100755 index 961dde4f4..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh +++ /dev/null @@ -1,106 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 - -log "Downloading pre-trained model from $repo_url" -git lfs install -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/cpu_jit.pt" -git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt -popd - -log "Export to torchscript model" -./pruned_transducer_stateless7/export.py \ - --exp-dir $repo/exp \ - --use-averaged-model false \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.script()" - -./pruned_transducer_stateless7/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --nn-model-filename $repo/exp/cpu_jit.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless7/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless7/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless7/exp - ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless7/exp - - log "Decoding test-clean and test-other" - - # use a small value for 
decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless7/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless7/exp - done - - rm pruned_transducer_stateless7/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh deleted file mode 100755 index ba7139efb..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh +++ /dev/null @@ -1,150 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01 - -log "Downloading pre-trained model from $repo_url" -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/HLG.pt" -git lfs pull --include "data/lang_bpe_500/L.pt" -git lfs pull --include "data/lang_bpe_500/LG.pt" -git lfs pull --include "data/lang_bpe_500/Linv.pt" -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "data/lm/G_4_gram.pt" -git lfs pull --include "exp/cpu_jit.pt" -git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt -popd - -log "Export to torchscript model" -./pruned_transducer_stateless7_ctc/export.py \ - --exp-dir $repo/exp \ - --use-averaged-model false \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.script()" - -./pruned_transducer_stateless7_ctc/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --nn-model-filename $repo/exp/cpu_jit.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -for m in ctc-decoding 1best; do - ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ - --model-filename $repo/exp/cpu_jit.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --G $repo/data/lm/G_4_gram.pt \ - --method $m \ - --sample-rate 16000 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless7_ctc/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless7_ctc/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - 
$repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for m in ctc-decoding 1best; do - ./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ - --checkpoint $repo/exp/pretrained.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --G $repo/data/lm/G_4_gram.pt \ - --method $m \ - --sample-rate 16000 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless7_ctc/exp - ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_ctc/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless7_ctc/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless7_ctc/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless7_ctc/exp - done - - for m in ctc-decoding 1best; do - ./pruned_transducer_stateless7_ctc/ctc_decode.py \ - --epoch 999 \ - --avg 1 \ - --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --max-duration $max_duration \ - --use-averaged-model 0 \ - --decoding-method $m \ - --hlg-scale 0.6 \ - --lm-dir data/lm - done - - rm pruned_transducer_stateless7_ctc/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh deleted file mode 100755 index 1ecbc4798..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh +++ /dev/null @@ -1,147 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29 - -log "Downloading pre-trained model from $repo_url" -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/L.pt" -git lfs pull --include "data/lang_bpe_500/LG.pt" -git lfs pull --include "data/lang_bpe_500/HLG.pt" -git lfs pull --include "data/lang_bpe_500/Linv.pt" -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/cpu_jit.pt" -git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt -popd - -log "Export to torchscript model" -./pruned_transducer_stateless7_ctc_bs/export.py \ - --exp-dir $repo/exp \ - --use-averaged-model false \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.script()" - -./pruned_transducer_stateless7_ctc_bs/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --nn-model-filename 
$repo/exp/cpu_jit.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -for m in ctc-decoding 1best; do - ./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ - --model-filename $repo/exp/cpu_jit.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --method $m \ - --sample-rate 16000 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for m in ctc-decoding 1best; do - ./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ - --checkpoint $repo/exp/pretrained.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --method $m \ - --sample-rate 16000 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" - -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless7_ctc_bs/exp - ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_ctc_bs/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless7_ctc_bs/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless7_ctc_bs/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless7_ctc_bs/exp - done - - for m in ctc-decoding 1best; do - ./pruned_transducer_stateless7_ctc_bs/ctc_decode.py \ - --epoch 999 \ - --avg 1 \ - --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --max-duration $max_duration \ - --use-averaged-model 0 \ - --decoding-method $m \ - --hlg-scale 0.6 - done - - rm pruned_transducer_stateless7_ctc_bs/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh deleted file mode 100755 index 37b192a57..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local 
fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 - -log "Downloading pre-trained model from $repo_url" -git lfs install -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/cpu_jit.pt" -git lfs pull --include "exp/pretrained.pt" -git lfs pull --include "exp/encoder_jit_trace.pt" -git lfs pull --include "exp/decoder_jit_trace.pt" -git lfs pull --include "exp/joiner_jit_trace.pt" -cd exp -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt -popd - -log "Export to torchscript model" -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir $repo/exp \ - --use-averaged-model false \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --decode-chunk-len 32 \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.script()" - -./pruned_transducer_stateless7_streaming/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --nn-model-filename $repo/exp/cpu_jit.pt \ - --decode-chunk-len 32 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -log "Export to torchscript model by torch.jit.trace()" -./pruned_transducer_stateless7_streaming/jit_trace_export.py \ - --exp-dir $repo/exp \ - --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --decode-chunk-len 32 \ - --epoch 99 \ - --avg 1 - -log "Decode with models exported by torch.jit.trace()" - -./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --encoder-model-filename $repo/exp/encoder_jit_trace.pt \ - --decoder-model-filename $repo/exp/decoder_jit_trace.pt \ - --joiner-model-filename $repo/exp/joiner_jit_trace.pt \ - --decode-chunk-len 32 \ - $repo/test_wavs/1089-134686-0001.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless7_streaming/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --decode-chunk-len 32 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless7_streaming/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --decode-chunk-len 32 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless7_streaming/exp - ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_streaming/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless7_streaming/exp - - log "Decoding test-clean 
and test-other" - - # use a small value for decoding with CPU - max_duration=100 - num_decode_stream=200 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "decoding with $method" - - ./pruned_transducer_stateless7_streaming/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --decode-chunk-len 32 \ - --exp-dir pruned_transducer_stateless7_streaming/exp - done - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless7_streaming/streaming_decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --decode-chunk-len 32 \ - --num-decode-streams $num_decode_stream - --exp-dir pruned_transducer_stateless7_streaming/exp - done - - rm pruned_transducer_stateless7_streaming/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh deleted file mode 100755 index 4f2bfac24..000000000 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh +++ /dev/null @@ -1,115 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless8-2022-11-14 - -log "Downloading pre-trained model from $repo_url" -git lfs install -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/cpu_jit.pt" -git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt -popd - -log "Decode with models exported by torch.jit.script()" - -./pruned_transducer_stateless8/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --nn-model-filename $repo/exp/cpu_jit.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -log "Export to torchscript model" -./pruned_transducer_stateless8/export.py \ - --exp-dir $repo/exp \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model false \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.script()" - -./pruned_transducer_stateless8/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --nn-model-filename $repo/exp/cpu_jit.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless8/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless8/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens 
$repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless8/exp - ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless8/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless8/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./pruned_transducer_stateless8/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless8/exp - done - - rm pruned_transducer_stateless8/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh deleted file mode 100755 index 5cbdad16d..000000000 --- a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless2_20220625 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -ln -s pretrained-epoch-24-avg-10.pt pretrained.pt -popd - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./pruned_transducer_stateless2/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --simulate-streaming 1 \ - --causal-convolution 1 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./pruned_transducer_stateless2/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --simulate-streaming 1 \ - --causal-convolution 1 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p pruned_transducer_stateless2/exp - ln -s $PWD/$repo/exp/pretrained-epoch-24-avg-10.pt pruned_transducer_stateless2/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh pruned_transducer_stateless2/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search 
fast_beam_search modified_beam_search; do - log "Simulate streaming decoding with $method" - - ./pruned_transducer_stateless2/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir pruned_transducer_stateless2/exp \ - --simulate-streaming 1 \ - --causal-convolution 1 - done - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Real streaming decoding with $method" - - ./pruned_transducer_stateless2/streaming_decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --num-decode-streams 100 \ - --exp-dir pruned_transducer_stateless2/exp \ - --left-context 32 \ - --decode-chunk-size 8 \ - --right-context 0 - done - - rm pruned_transducer_stateless2/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh b/.github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh deleted file mode 100755 index f4e2124b1..000000000 --- a/.github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh +++ /dev/null @@ -1,116 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 - -log "Downloading pre-trained model from $repo_url" -git lfs install -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "data/lang_bpe_500/tokens.txt" -git lfs pull --include "exp/jit_script_chunk_16_left_128.pt" -git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt -popd - -log "Export to torchscript model" -./zipformer/export.py \ - --exp-dir $repo/exp \ - --use-averaged-model false \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.script()" - -./zipformer/jit_pretrained_streaming.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --nn-model-filename $repo/exp/jit_script_chunk_16_left_128.pt \ - $repo/test_wavs/1089-134686-0001.wav - -for method in greedy_search modified_beam_search fast_beam_search; do - log "$method" - - ./zipformer/pretrained.py \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p zipformer/exp - ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh zipformer/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Simulated streaming decoding with $method" - - ./zipformer/decode.py \ - --causal 1 \ - 
--chunk-size 16 \ - --left-context-frames 128 \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --exp-dir zipformer/exp - done - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Chunk-wise streaming decoding with $method" - - ./zipformer/streaming_decode.py \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --exp-dir zipformer/exp - done - - rm zipformer/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh deleted file mode 100755 index ff77855a2..000000000 --- a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless2-torchaudio-2022-04-19 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./transducer_stateless2/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in fast_beam_search modified_beam_search beam_search; do - log "$method" - - ./transducer_stateless2/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p transducer_stateless2/exp - ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless2/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh transducer_stateless2/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./transducer_stateless2/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir transducer_stateless2/exp - done - - rm transducer_stateless2/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-zipformer-2023-05-18.sh b/.github/scripts/run-librispeech-zipformer-2023-05-18.sh deleted file mode 100755 index fb1a0149d..000000000 --- a/.github/scripts/run-librispeech-zipformer-2023-05-18.sh +++ /dev/null @@ -1,94 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} 
- -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 - -log "Downloading pre-trained model from $repo_url" -git lfs install -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "data/lang_bpe_500/tokens.txt" -git lfs pull --include "exp/jit_script.pt" -git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt -popd - -log "Export to torchscript model" -./zipformer/export.py \ - --exp-dir $repo/exp \ - --use-averaged-model false \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.script()" - -./zipformer/jit_pretrained.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --nn-model-filename $repo/exp/jit_script.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -for method in greedy_search modified_beam_search fast_beam_search; do - log "$method" - - ./zipformer/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p zipformer/exp - ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh zipformer/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./zipformer/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --exp-dir zipformer/exp - done - - rm zipformer/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh b/.github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh deleted file mode 100755 index 0026d2109..000000000 --- a/.github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh +++ /dev/null @@ -1,117 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 - -log "Downloading pre-trained model from $repo_url" -git lfs install -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "data/lang_bpe_500/tokens.txt" -git lfs pull --include "data/lang_bpe_500/HLG.pt" -git lfs pull --include "data/lang_bpe_500/L.pt" -git lfs pull --include "data/lang_bpe_500/LG.pt" -git lfs pull --include "data/lang_bpe_500/Linv.pt" -git lfs pull --include "data/lm/G_4_gram.pt" -git lfs pull --include 
"exp/jit_script.pt" -git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt -popd - -log "Export to torchscript model" -./zipformer/export.py \ - --exp-dir $repo/exp \ - --use-transducer 1 \ - --use-ctc 1 \ - --use-averaged-model false \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.script()" - -for method in ctc-decoding 1best; do - ./zipformer/jit_pretrained_ctc.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --model-filename $repo/exp/jit_script.pt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --G $repo/data/lm/G_4_gram.pt \ - --method $method \ - --sample-rate 16000 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in ctc-decoding 1best; do - log "$method" - - ./zipformer/pretrained_ctc.py \ - --use-transducer 1 \ - --use-ctc 1 \ - --method $method \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - --G $repo/data/lm/G_4_gram.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --sample-rate 16000 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p zipformer/exp - ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh zipformer/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in ctc-decoding 1best; do - log "Decoding with $method" - - ./zipformer/ctc_decode.py \ - --use-transducer 1 \ - --use-ctc 1 \ - --decoding-method $method \ - --nbest-scale 1.0 \ - --hlg-scale 0.6 \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --max-duration $max_duration \ - --exp-dir zipformer/exp - done - - rm zipformer/exp/*.pt -fi diff --git a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh deleted file mode 100755 index c59921055..000000000 --- a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh +++ /dev/null @@ -1,102 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-mmi-2022-12-08 - -log "Downloading pre-trained model from $repo_url" -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -pushd $repo/exp -git lfs pull --include "data/lang_bpe_500/3gram.pt" -git lfs pull --include "data/lang_bpe_500/4gram.pt" -git lfs pull --include "data/lang_bpe_500/L.pt" -git lfs pull --include "data/lang_bpe_500/LG.pt" -git lfs pull --include "data/lang_bpe_500/Linv.pt" -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/cpu_jit.pt" -git lfs pull --include "exp/pretrained.pt" -ln -s pretrained.pt epoch-99.pt -ls -lh *.pt 
-popd - -log "Export to torchscript model" -./zipformer_mmi/export.py \ - --exp-dir $repo/exp \ - --use-averaged-model false \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 99 \ - --avg 1 \ - --jit 1 - -ls -lh $repo/exp/*.pt - -log "Decode with models exported by torch.jit.script()" - -./zipformer_mmi/jit_pretrained.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --nn-model-filename $repo/exp/cpu_jit.pt \ - --lang-dir $repo/data/lang_bpe_500 \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do - log "$method" - - ./zipformer_mmi/pretrained.py \ - --method $method \ - --checkpoint $repo/exp/pretrained.pt \ - --lang-dir $repo/data/lang_bpe_500 \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p zipformer_mmi/exp - ln -s $PWD/$repo/exp/pretrained.pt zipformer_mmi/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh zipformer_mmi/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do - log "Decoding with $method" - - ./zipformer_mmi/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --use-averaged-model 0 \ - --nbest-scale 1.2 \ - --hp-scale 1.0 \ - --max-duration $max_duration \ - --lang-dir $repo/data/lang_bpe_500 \ - --exp-dir zipformer_mmi/exp - done - - rm zipformer_mmi/exp/*.pt -fi diff --git a/.github/scripts/run-pre-trained-ctc.sh b/.github/scripts/run-pre-trained-ctc.sh deleted file mode 100755 index 7d6449c9a..000000000 --- a/.github/scripts/run-pre-trained-ctc.sh +++ /dev/null @@ -1,240 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -pushd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -log "CTC greedy search" - -./zipformer/onnx_pretrained_ctc.py \ - --nn-model $repo/model.onnx \ - --tokens $repo/tokens.txt \ - $repo/test_wavs/0.wav \ - $repo/test_wavs/1.wav \ - $repo/test_wavs/2.wav - -log "CTC H decoding" - -./zipformer/onnx_pretrained_ctc_H.py \ - --nn-model $repo/model.onnx \ - --tokens $repo/tokens.txt \ - --H $repo/H.fst \ - $repo/test_wavs/0.wav \ - $repo/test_wavs/1.wav \ - $repo/test_wavs/2.wav - -log "CTC HL decoding" - -./zipformer/onnx_pretrained_ctc_HL.py \ - --nn-model $repo/model.onnx \ - --words $repo/words.txt \ - --HL $repo/HL.fst \ - $repo/test_wavs/0.wav \ - $repo/test_wavs/1.wav \ - $repo/test_wavs/2.wav - -log "CTC HLG decoding" - -./zipformer/onnx_pretrained_ctc_HLG.py \ - --nn-model $repo/model.onnx \ - --words $repo/words.txt \ - --HLG $repo/HLG.fst \ - $repo/test_wavs/0.wav \ - $repo/test_wavs/1.wav \ - 
$repo/test_wavs/2.wav - -rm -rf $repo - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09 -log "Downloading pre-trained model from $repo_url" -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) -pushd $repo - -git lfs pull --include "exp/pretrained.pt" -git lfs pull --include "data/lang_bpe_500/HLG.pt" -git lfs pull --include "data/lang_bpe_500/L.pt" -git lfs pull --include "data/lang_bpe_500/L_disambig.pt" -git lfs pull --include "data/lang_bpe_500/Linv.pt" -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "data/lang_bpe_500/lexicon.txt" -git lfs pull --include "data/lang_bpe_500/lexicon_disambig.txt" -git lfs pull --include "data/lang_bpe_500/tokens.txt" -git lfs pull --include "data/lang_bpe_500/words.txt" -git lfs pull --include "data/lm/G_3_gram.fst.txt" - -popd - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -log "CTC decoding" - -./conformer_ctc/pretrained.py \ - --method ctc-decoding \ - --num-classes 500 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -log "HLG decoding" - -./conformer_ctc/pretrained.py \ - --method 1best \ - --num-classes 500 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --words-file $repo/data/lang_bpe_500/words.txt \ - --HLG $repo/data/lang_bpe_500/HLG.pt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -log "CTC decoding on CPU with kaldi decoders using OpenFst" - -log "Exporting model with torchscript" - -pushd $repo/exp -ln -s pretrained.pt epoch-99.pt -popd - -./conformer_ctc/export.py \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --jit 1 - -ls -lh $repo/exp - - -log "Generating H.fst, HL.fst" - -./local/prepare_lang_fst.py --lang-dir $repo/data/lang_bpe_500 --ngram-G $repo/data/lm/G_3_gram.fst.txt - -ls -lh $repo/data/lang_bpe_500 - -log "Decoding with H on CPU with OpenFst" - -./conformer_ctc/jit_pretrained_decode_with_H.py \ - --nn-model $repo/exp/cpu_jit.pt \ - --H $repo/data/lang_bpe_500/H.fst \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -log "Decoding with HL on CPU with OpenFst" - -./conformer_ctc/jit_pretrained_decode_with_HL.py \ - --nn-model $repo/exp/cpu_jit.pt \ - --HL $repo/data/lang_bpe_500/HL.fst \ - --words $repo/data/lang_bpe_500/words.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -log "Decoding with HLG on CPU with OpenFst" - -./conformer_ctc/jit_pretrained_decode_with_HLG.py \ - --nn-model $repo/exp/cpu_jit.pt \ - --HLG $repo/data/lang_bpe_500/HLG.fst \ - --words $repo/data/lang_bpe_500/words.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav - -rm -rf $repo - -popd - -log "Test aishell" - -pushd egs/aishell/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall_asr_aishell_conformer_ctc -log "Downloading pre-trained model from $repo_url" -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) -pushd $repo - -git lfs pull --include "exp/pretrained.pt" -git lfs pull --include 
"data/lang_char/H.fst" -git lfs pull --include "data/lang_char/HL.fst" -git lfs pull --include "data/lang_char/HLG.fst" - -popd - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -log "CTC decoding" - -log "Exporting model with torchscript" - -pushd $repo/exp -ln -s pretrained.pt epoch-99.pt -popd - -./conformer_ctc/export.py \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --tokens $repo/data/lang_char/tokens.txt \ - --jit 1 - -ls -lh $repo/exp - -ls -lh $repo/data/lang_char - -log "Decoding with H on CPU with OpenFst" - -./conformer_ctc/jit_pretrained_decode_with_H.py \ - --nn-model $repo/exp/cpu_jit.pt \ - --H $repo/data/lang_char/H.fst \ - --tokens $repo/data/lang_char/tokens.txt \ - $repo/test_wavs/0.wav \ - $repo/test_wavs/1.wav \ - $repo/test_wavs/2.wav - -log "Decoding with HL on CPU with OpenFst" - -./conformer_ctc/jit_pretrained_decode_with_HL.py \ - --nn-model $repo/exp/cpu_jit.pt \ - --HL $repo/data/lang_char/HL.fst \ - --words $repo/data/lang_char/words.txt \ - $repo/test_wavs/0.wav \ - $repo/test_wavs/1.wav \ - $repo/test_wavs/2.wav - -log "Decoding with HLG on CPU with OpenFst" - -./conformer_ctc/jit_pretrained_decode_with_HLG.py \ - --nn-model $repo/exp/cpu_jit.pt \ - --HLG $repo/data/lang_char/HLG.fst \ - --words $repo/data/lang_char/words.txt \ - $repo/test_wavs/0.wav \ - $repo/test_wavs/1.wav \ - $repo/test_wavs/2.wav - -rm -rf $repo diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh deleted file mode 100755 index 7b686328d..000000000 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./transducer_stateless_multi_datasets/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./transducer_stateless_multi_datasets/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p transducer_stateless_multi_datasets/exp - ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless_multi_datasets/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh transducer_stateless_multi_datasets/exp - - log 
"Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./transducer_stateless_multi_datasets/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir transducer_stateless_multi_datasets/exp - done - - rm transducer_stateless_multi_datasets/exp/*.pt -fi diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh deleted file mode 100755 index a8eeeb514..000000000 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./transducer_stateless_multi_datasets/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in modified_beam_search beam_search fast_beam_search; do - log "$method" - - ./transducer_stateless_multi_datasets/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p transducer_stateless_multi_datasets/exp - ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless_multi_datasets/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh transducer_stateless_multi_datasets/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./transducer_stateless_multi_datasets/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir transducer_stateless_multi_datasets/exp - done - - rm transducer_stateless_multi_datasets/exp/*.pt -fi diff --git a/.github/scripts/run-pre-trained-transducer-stateless.sh b/.github/scripts/run-pre-trained-transducer-stateless.sh deleted file mode 100755 index 2e2360435..000000000 --- a/.github/scripts/run-pre-trained-transducer-stateless.sh +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - 
-cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -for sym in 1 2 3; do - log "Greedy search with --max-sym-per-frame $sym" - - ./transducer_stateless/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame $sym \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -for method in fast_beam_search modified_beam_search beam_search; do - log "$method" - - ./transducer_stateless/pretrained.py \ - --method $method \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -done - -echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" -echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then - mkdir -p transducer_stateless/exp - ln -s $PWD/$repo/exp/pretrained.pt transducer_stateless/exp/epoch-999.pt - ln -s $PWD/$repo/data/lang_bpe_500 data/ - - ls -lh data - ls -lh transducer_stateless/exp - - log "Decoding test-clean and test-other" - - # use a small value for decoding with CPU - max_duration=100 - - for method in greedy_search fast_beam_search modified_beam_search; do - log "Decoding with $method" - - ./transducer_stateless/decode.py \ - --decoding-method $method \ - --epoch 999 \ - --avg 1 \ - --max-duration $max_duration \ - --exp-dir transducer_stateless/exp - done - - rm transducer_stateless/exp/*.pt -fi diff --git a/.github/scripts/run-pre-trained-transducer.sh b/.github/scripts/run-pre-trained-transducer.sh deleted file mode 100755 index b865f8d13..000000000 --- a/.github/scripts/run-pre-trained-transducer.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/usr/bin/env bash - -set -e - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23 - -log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url -repo=$(basename $repo_url) - -log "Display test files" -tree $repo/ -ls -lh $repo/test_wavs/*.wav - -log "Beam search decoding" - -./transducer/pretrained.py \ - --method beam_search \ - --beam-size 4 \ - --checkpoint $repo/exp/pretrained.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/workflows/aishell.yml b/.github/workflows/aishell.yml index 136e117bd..8b0599fca 100644 --- a/.github/workflows/aishell.yml +++ b/.github/workflows/aishell.yml @@ -11,22 +11,13 @@ on: workflow_dispatch: - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - concurrency: group: aishell-${{ github.ref }} cancel-in-progress: true jobs: generate_build_matrix: - if: 
(github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'aishell' || github.event_name == 'schedule') + if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event.label.name == 'ready' || github.event_name == 'push' || github.event_name == 'aishell') # see https://github.com/pytorch/pytorch/pull/50633 runs-on: ubuntu-latest diff --git a/.github/workflows/train-librispeech.yml b/.github/workflows/librispeech.yml similarity index 95% rename from .github/workflows/train-librispeech.yml rename to .github/workflows/librispeech.yml index 79002a881..6e087b10a 100644 --- a/.github/workflows/train-librispeech.yml +++ b/.github/workflows/librispeech.yml @@ -1,4 +1,4 @@ -name: train librispeech +name: librispeech on: push: branches: @@ -11,7 +11,7 @@ on: workflow_dispatch: concurrency: - group: train-librispeech-${{ github.ref }} + group: librispeech-${{ github.ref }} cancel-in-progress: true jobs: @@ -32,7 +32,7 @@ jobs: python ./.github/scripts/docker/generate_build_matrix.py MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py) echo "::set-output name=matrix::${MATRIX}" - train-librispeech: + librispeech: needs: generate_build_matrix name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }} runs-on: ubuntu-latest diff --git a/.github/workflows/run-librispeech-2022-03-12.yml b/.github/workflows/run-librispeech-2022-03-12.yml deleted file mode 100644 index f092e3c80..000000000 --- a/.github/workflows/run-librispeech-2022-03-12.yml +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2022-03-12 -# stateless transducer + k2 pruned rnnt-loss - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2022_03_12-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2022_03_12: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh - - - name: Display decoding results for pruned_transducer_stateless - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./pruned_transducer_stateless/exp - - cd pruned_transducer_stateless - echo "results for pruned_transducer_stateless" - echo "===greedy search===" 
- find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for pruned_transducer_stateless - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless-2022-03-12 - path: egs/librispeech/ASR/pruned_transducer_stateless/exp/ diff --git a/.github/workflows/run-librispeech-2022-04-29.yml b/.github/workflows/run-librispeech-2022-04-29.yml deleted file mode 100644 index f8f4d9977..000000000 --- a/.github/workflows/run-librispeech-2022-04-29.yml +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2022-04-29 -# stateless pruned transducer (reworked model) + giga speech - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2022_04_29-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2022_04_29: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh - - .github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh - - - name: Display decoding results for pruned_transducer_stateless2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR - tree pruned_transducer_stateless2/exp - cd 
pruned_transducer_stateless2/exp - echo "===greedy search===" - find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Display decoding results for pruned_transducer_stateless3 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR - tree pruned_transducer_stateless3/exp - cd pruned_transducer_stateless3/exp - echo "===greedy search===" - find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for pruned_transducer_stateless2 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless2-2022-04-29 - path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/ - - - name: Upload decoding results for pruned_transducer_stateless3 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-04-29 - path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/ diff --git a/.github/workflows/run-librispeech-2022-05-13.yml b/.github/workflows/run-librispeech-2022-05-13.yml deleted file mode 100644 index dc20185da..000000000 --- a/.github/workflows/run-librispeech-2022-05-13.yml +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2022-05-13 -# stateless transducer + k2 pruned rnnt-loss + deeper model - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2022_05_13-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2022_05_13: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh - - - name: Display decoding results for librispeech pruned_transducer_stateless5 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./pruned_transducer_stateless5/exp - - cd pruned_transducer_stateless5 - echo "results for 
pruned_transducer_stateless5" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for librispeech pruned_transducer_stateless5 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless5-2022-05-13 - path: egs/librispeech/ASR/pruned_transducer_stateless5/exp/ diff --git a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml deleted file mode 100644 index 7e378c9a1..000000000 --- a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2022-11-11-stateless7 -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2022_11_11_zipformer-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2022_11_11_zipformer: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh - - - name: Display decoding results for librispeech pruned_transducer_stateless7 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./pruned_transducer_stateless7/exp - - cd pruned_transducer_stateless7 - echo "results for pruned_transducer_stateless7" - echo 
"===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for librispeech pruned_transducer_stateless7 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-2022-11-11 - path: egs/librispeech/ASR/pruned_transducer_stateless7/exp/ diff --git a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml deleted file mode 100644 index a2c1a0ad6..000000000 --- a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2022-11-14-stateless8 -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2022_11_14_zipformer_stateless8-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2022_11_14_zipformer_stateless8: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh - - - name: Display decoding results for librispeech pruned_transducer_stateless8 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./pruned_transducer_stateless8/exp - - cd pruned_transducer_stateless8 - echo "results for 
pruned_transducer_stateless8" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for librispeech pruned_transducer_stateless8 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless8-2022-11-14 - path: egs/librispeech/ASR/pruned_transducer_stateless8/exp/ diff --git a/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml deleted file mode 100644 index 500ab1736..000000000 --- a/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2022-12-01-stateless7-ctc -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -jobs: - run_librispeech_2022_11_11_zipformer: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh - - - name: Display decoding results for librispeech pruned_transducer_stateless7_ctc - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./pruned_transducer_stateless7_ctc/exp - - cd pruned_transducer_stateless7_ctc - echo "results for pruned_transducer_stateless7_ctc" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best 
for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===ctc decoding===" - find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===1best===" - find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for librispeech pruned_transducer_stateless7_ctc - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-ctc-2022-12-01 - path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc/exp/ diff --git a/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml b/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml deleted file mode 100644 index 1a7f9f594..000000000 --- a/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright 2022 Zengwei Yao - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2022-12-08-zipformer-mmi -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2022_12_08_zipformer-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2022_12_08_zipformer: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh - - - name: Display decoding results for librispeech zipformer-mmi - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./zipformer-mmi/exp - - cd zipformer-mmi - echo "results for zipformer-mmi" - echo "===1best===" - find exp/1best -name "log-*" -exec grep -n --color "best for 
test-clean" {} + | sort -n -k2 - find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===nbest===" - find exp/nbest -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/nbest -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===nbest-rescoring-LG===" - find exp/nbest-rescoring-LG -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/nbest-rescoring-LG -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===nbest-rescoring-3-gram===" - find exp/nbest-rescoring-3-gram -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/nbest-rescoring-3-gram -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===nbest-rescoring-4-gram===" - find exp/nbest-rescoring-4-gram -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/nbest-rescoring-4-gram -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for librispeech zipformer-mmi - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer_mmi-2022-12-08 - path: egs/librispeech/ASR/zipformer_mmi/exp/ diff --git a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml deleted file mode 100644 index 68014e20c..000000000 --- a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2022-12-29-stateless7-streaming -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2022_12_29_zipformer_streaming-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2022_12_29_zipformer_streaming: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh - - - name: Display decoding results for librispeech pruned_transducer_stateless7_streaming - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree 
./pruned_transducer_stateless7_streaming/exp - - cd pruned_transducer_stateless7_streaming - echo "results for pruned_transducer_stateless7_streaming" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===streaming greedy search===" - find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===streaming fast_beam_search===" - find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===streaming modified beam search===" - find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - - name: Upload decoding results for librispeech pruned_transducer_stateless7_streaming - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-streaming-2022-12-29 - path: egs/librispeech/ASR/pruned_transducer_stateless7_streaming/exp/ diff --git a/.github/workflows/run-librispeech-2023-01-29-stateless7-ctc-bs.yml b/.github/workflows/run-librispeech-2023-01-29-stateless7-ctc-bs.yml deleted file mode 100644 index 821abc25d..000000000 --- a/.github/workflows/run-librispeech-2023-01-29-stateless7-ctc-bs.yml +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2023-01-29-stateless7-ctc-bs -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -jobs: - run_librispeech_2023_01_29_zipformer_ctc_bs: - if: github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh - - - name: Display decoding results for librispeech pruned_transducer_stateless7_ctc_bs - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./pruned_transducer_stateless7_ctc_bs/exp - - cd pruned_transducer_stateless7_ctc_bs - echo "results for pruned_transducer_stateless7_ctc_bs" - echo "===greedy search===" - find exp/greedy_search -name "log-*" 
-exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===ctc decoding===" - find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===1best===" - find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for librispeech pruned_transducer_stateless7_ctc_bs - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-ctc-bs-2023-01-29 - path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/exp/ diff --git a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml deleted file mode 100644 index 905515dc4..000000000 --- a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-conformer-ctc3-2022-11-28 -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2022_11_28_conformer_ctc3-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2022_11_28_conformer_ctc3: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh - - - name: Display decoding results for librispeech conformer_ctc3 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./conformer_ctc3/exp - - cd conformer_ctc3 - echo "results for conformer_ctc3" - echo "===ctc-decoding===" - find exp/ctc-decoding -name "log-*" 
-exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===1best===" - find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for librispeech conformer_ctc3 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-conformer_ctc3-2022-11-28 - path: egs/librispeech/ASR/conformer_ctc3/exp/ diff --git a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml deleted file mode 100644 index 3fb0920bc..000000000 --- a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: run-librispeech-pruned-transducer-stateless3-2022-05-13 -# stateless pruned transducer (reworked model) + giga speech - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_pruned_transducer_stateless3_2022_05_13-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_pruned_transducer_stateless3_2022_05_13: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: 
cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh - - - name: Display decoding results for pruned_transducer_stateless3 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR - tree pruned_transducer_stateless3/exp - cd pruned_transducer_stateless3/exp - echo "===greedy search===" - find greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for pruned_transducer_stateless3 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-04-29 - path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/ diff --git a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml deleted file mode 100644 index 67a6f6fc4..000000000 --- a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: run-librispeech-streaming-2022-06-26 -# streaming conformer stateless transducer2 - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_streaming_2022_06_26-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_streaming_2022_06_26: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export 
PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh - - - name: Display decoding results - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./pruned_transducer_stateless2/exp - - cd pruned_transducer_stateless2 - echo "results for pruned_transducer_stateless2" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified_beam_search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for pruned_transducer_stateless2 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless2-2022-06-26 - path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/ diff --git a/.github/workflows/run-librispeech-streaming-zipformer-2023-05-18.yml b/.github/workflows/run-librispeech-streaming-zipformer-2023-05-18.yml deleted file mode 100644 index 5145fb43c..000000000 --- a/.github/workflows/run-librispeech-streaming-zipformer-2023-05-18.yml +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-streaming-zipformer-2023-05-18 -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2023_05_18_streaming_zipformer-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2023_05_18_streaming_zipformer: - if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-streaming-zipformer-2023-05-18.sh - - - name: Display decoding results for librispeech zipformer - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./zipformer/exp - - cd zipformer - - echo "results for zipformer, simulated streaming 
decoding" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "results for zipformer, chunk-wise streaming decoding" - echo "===greedy search===" - find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/streaming/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/streaming/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/streaming/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - - name: Upload decoding results for librispeech zipformer - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11 - path: egs/librispeech/ASR/zipformer/exp/ diff --git a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml deleted file mode 100644 index 35ca08a31..000000000 --- a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-2022-04-19 -# stateless transducer + torchaudio rnn-t loss - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2022_04_19-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2022_04_19: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh - - - name: Display decoding results - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./transducer_stateless2/exp - - cd transducer_stateless2 - echo "results for transducer_stateless2" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n 
--color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified_beam_search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for transducer_stateless2 - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless2-2022-04-19 - path: egs/librispeech/ASR/transducer_stateless2/exp/ diff --git a/.github/workflows/run-librispeech-zipformer-2023-05-18.yml b/.github/workflows/run-librispeech-zipformer-2023-05-18.yml deleted file mode 100644 index e9d235ad1..000000000 --- a/.github/workflows/run-librispeech-zipformer-2023-05-18.yml +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-zipformer-2023-05-18 -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2023_05_18_zipformer-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2023_05_18_zipformer: - if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-zipformer-2023-05-18.sh - - - name: Display decoding results for librispeech zipformer - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./zipformer/exp - - cd zipformer - echo "results for zipformer" - echo "===greedy search===" - find exp/greedy_search -name "log-*" 
-exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for librispeech zipformer - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11 - path: egs/librispeech/ASR/zipformer/exp/ diff --git a/.github/workflows/run-librispeech-zipformer-ctc-2023-06-14.yml b/.github/workflows/run-librispeech-zipformer-ctc-2023-06-14.yml deleted file mode 100644 index 48f0b1532..000000000 --- a/.github/workflows/run-librispeech-zipformer-ctc-2023-06-14.yml +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-librispeech-zipformer-ctc-2023-06-14 -# zipformer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_librispeech_2023_06_14_zipformer-ctc-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_librispeech_2023_06_14_zipformer_ctc: - if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh - - - name: Display decoding results for librispeech zipformer - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./zipformer/exp - - cd zipformer - echo "results for zipformer" - echo "===ctc-decoding===" - find exp/ctc-decoding 
-name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===1best===" - find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for librispeech zipformer - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11 - path: egs/librispeech/ASR/zipformer/exp/ diff --git a/.github/workflows/run-pretrained-ctc.yml b/.github/workflows/run-pretrained-ctc.yml deleted file mode 100644 index 074a63dfc..000000000 --- a/.github/workflows/run-pretrained-ctc.yml +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: run-pre-trained-ctc - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - workflow_dispatch: - inputs: - test-run: - description: 'Test (y/n)?' 
- required: true - default: 'y' - -concurrency: - group: run_pre_trained_ctc-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_pre_trained_ctc: - if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.inputs.test-run == 'y' || github.event.label.name == 'ctc' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Inference with pre-trained model - shell: bash - run: | - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - .github/scripts/run-pre-trained-ctc.sh diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml deleted file mode 100644 index f8caee8e5..000000000 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-pre-trained-trandsucer-stateless-multi-datasets-librispeech-100h - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh - - - name: Display decoding results for transducer_stateless_multi_datasets - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./transducer_stateless_multi_datasets/exp - - cd 
transducer_stateless_multi_datasets - echo "results for transducer_stateless_multi_datasets" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for transducer_stateless_multi_datasets - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless_multi_datasets-100h-2022-02-21 - path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/ diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml deleted file mode 100644 index 7c3910eb8..000000000 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-pre-trained-trandsucer-stateless-multi-datasets-librispeech-960h - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh - - - name: Display decoding results for transducer_stateless_multi_datasets - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./transducer_stateless_multi_datasets/exp - - cd 
transducer_stateless_multi_datasets - echo "results for transducer_stateless_multi_datasets" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for transducer_stateless_multi_datasets - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless_multi_datasets-100h-2022-03-01 - path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/ diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml deleted file mode 100644 index 1b69b97bf..000000000 --- a/.github/workflows/run-pretrained-transducer-stateless.yml +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-pre-trained-transducer-stateless - -on: - push: - branches: - - master - pull_request: - types: [labeled] - - schedule: - # minute (0-59) - # hour (0-23) - # day of the month (1-31) - # month (1-12) - # day of the week (0-6) - # nightly build at 15:50 UTC time every day - - cron: "50 15 * * *" - -concurrency: - group: run_pre_trained_transducer_stateless-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_pre_trained_transducer_stateless: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/install-kaldifeat.sh - - - name: Cache LibriSpeech test-clean and test-other datasets - id: libri-test-clean-and-test-other-data - uses: actions/cache@v2 - with: - path: | - ~/tmp/download - key: cache-libri-test-clean-and-test-other - - - name: Download LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh - - - name: Prepare manifests for LibriSpeech test-clean and test-other - shell: bash - run: | - .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh - - - name: Cache LibriSpeech test-clean and test-other fbank features - id: libri-test-clean-and-test-other-fbank - uses: actions/cache@v2 - with: - path: | - ~/tmp/fbank-libri - key: cache-libri-fbank-test-clean-and-test-other-v2 - - - name: Compute fbank for LibriSpeech test-clean and test-other - if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' - shell: bash - run: | - .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh - - - name: Inference with pre-trained model - shell: bash - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} - run: | - mkdir -p egs/librispeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank - ls -lh egs/librispeech/ASR/data/* - - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - - .github/scripts/run-pre-trained-transducer-stateless.sh - - - name: Display decoding results for transducer_stateless - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - shell: bash - run: | - cd egs/librispeech/ASR/ - tree ./transducer_stateless/exp - - cd transducer_stateless - echo "results for transducer_stateless" - echo "===greedy search===" - find exp/greedy_search -name "log-*" -exec grep -n --color 
"best for test-clean" {} + | sort -n -k2 - find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - - name: Upload decoding results for transducer_stateless - uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless-2022-02-07 - path: egs/librispeech/ASR/transducer_stateless/exp/ diff --git a/.github/workflows/run-pretrained-transducer.yml b/.github/workflows/run-pretrained-transducer.yml deleted file mode 100644 index 91d87f1c9..000000000 --- a/.github/workflows/run-pretrained-transducer.yml +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) - -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: run-pre-trained-transducer - -on: - push: - branches: - - master - pull_request: - types: [labeled] - -concurrency: - group: run_pre_trained_transducer-${{ github.ref }} - cancel-in-progress: true - -jobs: - run_pre_trained_transducer: - if: github.event.label.name == 'ready' || github.event_name == 'push' - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ubuntu-latest] - python-version: [3.8] - - fail-fast: false - - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/requirements-ci.txt' - - - name: Install Python dependencies - run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install - pip uninstall -y protobuf - pip install --no-binary protobuf protobuf==3.20.* - - - name: Cache kaldifeat - id: my-cache - uses: actions/cache@v2 - with: - path: | - ~/tmp/kaldifeat - key: cache-tmp-${{ matrix.python-version }}-2023-05-22 - - - name: Install kaldifeat - if: steps.my-cache.outputs.cache-hit != 'true' - shell: bash - run: | - make -j2 _kaldifeat - - - name: Inference with pre-trained model - shell: bash - run: | - sudo apt-get -qq install git-lfs tree - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - .github/scripts/run-pre-trained-transducer.sh From f42258caf8a1c4d19428d98b808986522f630843 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sat, 30 Dec 2023 13:03:26 +0800 Subject: [PATCH 24/46] Update compute_fbank_commonvoice_splits.py (#1437) --- egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py index 0564f6ec6..f31b45aa5 100755 --- a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py +++ b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py @@ -109,10 +109,10 @@ def compute_fbank_commonvoice_splits(args): extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) logging.info(f"device: {device}") - set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance + set_audio_duration_mismatch_tolerance(0.05) # 50ms tolerance set_caching_enabled(False) for i in range(start, stop): - idx = f"{i + 1}".zfill(num_digits) + idx = f"{i}".zfill(num_digits) logging.info(f"Processing {idx}/{num_splits}") cuts_path = output_dir / f"cv-{language}_cuts_{subset}.{idx}.jsonl.gz" From 8136ad775b6cd02bf2ecc60d65e8641b709c2d41 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 4 Jan 2024 13:59:32 +0800 Subject: [PATCH 25/46] Use high_freq -400 in computing fbank features. 
(#1447) See also https://github.com/k2-fsa/sherpa-onnx/issues/514 --- .../ASR/pruned_transducer_stateless2/pretrained.py | 1 + egs/aishell/ASR/conformer_ctc/pretrained.py | 1 + egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py | 1 + egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py | 1 + egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py | 1 + egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py | 1 + .../ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py | 1 + egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py | 1 + .../pruned_transducer_stateless7_streaming/streaming_decode.py | 1 + egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py | 1 + egs/aishell/ASR/transducer_stateless/pretrained.py | 1 + egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py | 1 + egs/aishell/ASR/transducer_stateless_modified/pretrained.py | 1 + egs/aishell/ASR/zipformer/streaming_decode.py | 1 + egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py | 1 + egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py | 1 + egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py | 1 + .../ASR/pruned_transducer_stateless7/onnx_pretrained.py | 1 + egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py | 1 + .../pruned_transducer_stateless7_streaming/streaming_decode.py | 1 + .../jit_trace_pretrained.py | 1 + egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py | 1 + .../pruned_transducer_stateless7_streaming/streaming_decode.py | 1 + egs/gigaspeech/ASR/zipformer/streaming_decode.py | 1 + egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py | 1 + .../ASR/conformer_ctc/jit_pretrained_decode_with_H.py | 1 + .../ASR/conformer_ctc/jit_pretrained_decode_with_HL.py | 1 + .../ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py | 1 + egs/librispeech/ASR/conformer_ctc/pretrained.py | 1 + egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py | 1 + egs/librispeech/ASR/conformer_ctc3/pretrained.py | 1 + .../ASR/conv_emformer_transducer_stateless/streaming_decode.py | 1 + .../ASR/conv_emformer_transducer_stateless2/jit_pretrained.py | 1 + .../ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py | 1 + .../conv_emformer_transducer_stateless2/streaming-ncnn-decode.py | 1 + .../ASR/conv_emformer_transducer_stateless2/streaming_decode.py | 1 + egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py | 1 + egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py | 1 + .../ASR/lstm_transducer_stateless/streaming_decode.py | 1 + egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py | 1 + egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py | 1 + .../ASR/lstm_transducer_stateless2/onnx_pretrained.py | 1 + egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py | 1 + .../ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py | 1 + .../ASR/lstm_transducer_stateless2/streaming-onnx-decode.py | 1 + egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py | 1 + egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py | 1 + .../ASR/lstm_transducer_stateless3/streaming_decode.py | 1 + egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py | 1 + .../ASR/pruned_transducer_stateless/streaming_decode.py | 1 + egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py | 1 + .../ASR/pruned_transducer_stateless2/streaming_decode.py | 1 + .../ASR/pruned_transducer_stateless3/jit_pretrained.py | 1 + .../ASR/pruned_transducer_stateless3/onnx_pretrained.py | 1 + 
egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py | 1 + .../ASR/pruned_transducer_stateless3/streaming_decode.py | 1 + .../ASR/pruned_transducer_stateless4/streaming_decode.py | 1 + .../pruned_transducer_stateless5/onnx_pretrained-streaming.py | 1 + egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py | 1 + .../ASR/pruned_transducer_stateless5/streaming_decode.py | 1 + .../ASR/pruned_transducer_stateless7/jit_pretrained.py | 1 + egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py | 1 + .../ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py | 1 + .../ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py | 1 + .../ASR/pruned_transducer_stateless7_ctc/pretrained.py | 1 + .../ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py | 1 + .../ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py | 1 + .../pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py | 1 + .../ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py | 1 + .../ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py | 1 + .../ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py | 1 + .../ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py | 1 + .../jit_trace_pretrained.py | 1 + .../pruned_transducer_stateless7_streaming/onnx_pretrained.py | 1 + .../ASR/pruned_transducer_stateless7_streaming/pretrained.py | 1 + .../streaming-ncnn-decode.py | 1 + .../pruned_transducer_stateless7_streaming/streaming_decode.py | 1 + .../ASR/pruned_transducer_stateless8/jit_pretrained.py | 1 + egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py | 1 + egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py | 1 + egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py | 1 + egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py | 1 + egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py | 1 + egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py | 1 + egs/librispeech/ASR/transducer/pretrained.py | 1 + egs/librispeech/ASR/transducer_stateless/pretrained.py | 1 + egs/librispeech/ASR/transducer_stateless2/pretrained.py | 1 + .../ASR/transducer_stateless_multi_datasets/pretrained.py | 1 + egs/librispeech/ASR/zipformer/jit_pretrained.py | 1 + egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py | 1 + egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py | 1 + egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py | 1 + egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py | 1 + egs/librispeech/ASR/zipformer/onnx_pretrained.py | 1 + egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py | 1 + egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py | 1 + egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py | 1 + egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py | 1 + egs/librispeech/ASR/zipformer/pretrained.py | 1 + egs/librispeech/ASR/zipformer/pretrained_ctc.py | 1 + egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py | 1 + egs/librispeech/ASR/zipformer_mmi/pretrained.py | 1 + egs/mgb2/ASR/conformer_ctc/pretrained.py | 1 + egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py | 1 + egs/multi_zh-hans/ASR/zipformer/pretrained.py | 1 + egs/multi_zh_en/ASR/zipformer/pretrained.py | 1 + egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py | 1 + .../ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py | 1 + .../ASR/pruned_transducer_stateless7_bbpe/pretrained.py | 1 + egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py | 1 + egs/tedlium3/ASR/transducer_stateless/pretrained.py | 1 + egs/timit/ASR/tdnn_ligru_ctc/pretrained.py | 1 + 
egs/timit/ASR/tdnn_lstm_ctc/pretrained.py | 1 + .../ASR/pruned_transducer_stateless2/jit_pretrained.py | 1 + egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py | 1 + .../pruned_transducer_stateless5/onnx_pretrained-streaming.py | 1 + .../ASR/pruned_transducer_stateless5/onnx_pretrained.py | 1 + egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py | 1 + .../ASR/pruned_transducer_stateless5/streaming_decode.py | 1 + egs/wenetspeech/ASR/zipformer/streaming_decode.py | 1 + egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py | 1 + egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py | 1 + egs/yesno/ASR/tdnn/jit_pretrained.py | 1 + egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py | 1 + egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py | 1 + egs/yesno/ASR/tdnn/onnx_pretrained.py | 1 + egs/yesno/ASR/tdnn/pretrained.py | 1 + 127 files changed, 127 insertions(+) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py index 75c316eaf..17729e02e 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py @@ -242,6 +242,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py index 66d583396..af1171a6f 100755 --- a/egs/aishell/ASR/conformer_ctc/pretrained.py +++ b/egs/aishell/ASR/conformer_ctc/pretrained.py @@ -261,6 +261,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py index 82c10f129..c4aa98358 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/pretrained.py @@ -240,6 +240,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py index ead393e6e..69fe3a40b 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py @@ -241,6 +241,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py index e61190649..5143f2cae 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -230,6 +230,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py 
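Note: the 127 hunks in this patch all make the same one-line addition. As a minimal sketch (assuming the kaldifeat Python API exactly as these pretrained/streaming scripts already use it; the waveform call at the end is illustrative only), the resulting fbank configuration looks like this:

```python
import torch
import kaldifeat

opts = kaldifeat.FbankOptions()
opts.device = torch.device("cpu")
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000   # the scripts use params.sample_rate / args.sample_rate
opts.mel_opts.num_bins = 80         # or params.feature_dim
# Kaldi mel-bank convention: a non-positive high_freq is taken relative to the
# Nyquist frequency, so -400 puts the top filter edge at 8000 - 400 = 7600 Hz
# for 16 kHz input.
opts.mel_opts.high_freq = -400

fbank = kaldifeat.Fbank(opts)
# features = fbank(waveform)  # waveform: 1-D float32 tensor of samples in [-1, 1]
```

Setting high_freq to -400 presumably keeps the hand-rolled feature extraction in these decoding scripts consistent with the fbank features used during training (see the sherpa-onnx issue referenced above).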
b/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py index a92182e8d..8e8e971eb 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/onnx_pretrained.py @@ -369,6 +369,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py index 0c43bf74b..8fb7ac278 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py @@ -227,6 +227,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py index ea5bda4db..12004315b 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/pretrained.py @@ -250,6 +250,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index 6b4f183cf..aa0e07c83 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -317,6 +317,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 50 diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py index 7e7213501..9754b4939 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py @@ -158,6 +158,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py index 40f430e13..540e7b61b 100755 --- a/egs/aishell/ASR/transducer_stateless/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless/pretrained.py @@ -258,6 +258,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py index 5d8ca2e11..4a4e9237c 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py @@ -238,6 +238,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git 
a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py index 9e4459247..66a91709e 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py @@ -238,6 +238,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell/ASR/zipformer/streaming_decode.py b/egs/aishell/ASR/zipformer/streaming_decode.py index c3820447a..f54ffbd3c 100755 --- a/egs/aishell/ASR/zipformer/streaming_decode.py +++ b/egs/aishell/ASR/zipformer/streaming_decode.py @@ -572,6 +572,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 100 diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py index bc3ae7abf..f04632388 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/pretrained.py @@ -239,6 +239,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py index ee898c303..e8b7f71b7 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/pretrained.py @@ -251,6 +251,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py index f5a0dd8c8..a738bb3fb 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/pretrained.py @@ -242,6 +242,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py index cf6ddfa36..52fed7331 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/onnx_pretrained.py @@ -370,6 +370,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py index a22d1b4ba..b6e2451e8 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/pretrained.py @@ -260,6 +260,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = 
-400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index dbe65d0a7..018736d26 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -320,6 +320,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 50 diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py index d84cf04a3..58ee99e6a 100644 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py @@ -177,6 +177,7 @@ def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py index 932026868..66fbae378 100644 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -252,6 +252,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index 9700dd89e..7252665a7 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -337,6 +337,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 50 diff --git a/egs/gigaspeech/ASR/zipformer/streaming_decode.py b/egs/gigaspeech/ASR/zipformer/streaming_decode.py index a76788859..09df2935c 100755 --- a/egs/gigaspeech/ASR/zipformer/streaming_decode.py +++ b/egs/gigaspeech/ASR/zipformer/streaming_decode.py @@ -553,6 +553,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 100 diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py b/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py index 48fd2612a..458109a3f 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/pretrained.py @@ -264,6 +264,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py index 4bdec9e11..e9acf7e0b 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py +++ 
b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py @@ -195,6 +195,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py index d5a1dba3c..5753aa5d3 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py @@ -192,6 +192,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py index 216677a23..b6e3333ce 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py @@ -191,6 +191,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index df3e4d819..38b60fcb9 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -283,6 +283,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py index 76db46cc8..19b26361e 100755 --- a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py @@ -271,6 +271,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/conformer_ctc3/pretrained.py b/egs/librispeech/ASR/conformer_ctc3/pretrained.py index c37b99cce..a0cdfcf03 100755 --- a/egs/librispeech/ASR/conformer_ctc3/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc3/pretrained.py @@ -302,6 +302,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index e5a7c7116..9b8b4cce2 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -623,6 +623,7 @@ def create_streaming_feature_extractor() -> Fbank: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return Fbank(opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py index 1fe358c79..58f587c91 100755 --- 
a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py @@ -184,6 +184,7 @@ def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py index a6c69d54f..c8aae04e8 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py @@ -326,6 +326,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py index 74da9e6c8..1047100fc 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py @@ -276,6 +276,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index f5d894a7b..aaed7d31f 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -623,6 +623,7 @@ def create_streaming_feature_extractor() -> Fbank: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return Fbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py index c07956243..5350a54da 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py @@ -266,6 +266,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py index 119fcf1fd..42c3a5d7f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -251,6 +251,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py index f989d9bc0..03472e2c3 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ 
b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -615,6 +615,7 @@ def create_streaming_feature_extractor() -> Fbank: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return Fbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py index 728b09104..f4ec17221 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py @@ -267,6 +267,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py index 3eeaa5397..5bab70fb0 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -255,6 +255,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py index 06159e56a..06397965d 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/onnx_pretrained.py @@ -298,6 +298,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index 5d6d97320..dcff088e2 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -254,6 +254,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py index cbbc77928..6166049ae 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -217,6 +217,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py index 487fc2114..df9f6cf3f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -344,6 +344,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return 
OnlineFbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py index 237591a36..d9e7f3578 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py @@ -266,6 +266,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index 29a0d4d1a..e39637bd8 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -252,6 +252,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index c737e3611..c425b1f46 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -615,6 +615,7 @@ def create_streaming_feature_extractor() -> Fbank: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index 02f9f1b03..e06404619 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -277,6 +277,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index f4b01fd06..8586c66d6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -334,6 +334,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 100 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index 029f55ba0..6923f4d40 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -278,6 +278,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index 9c4a13606..d17c3467a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -336,6 
+336,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py index 0669284b3..6d09de6bd 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py @@ -285,6 +285,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index de3e03da6..8d12eae28 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -368,6 +368,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index abda4e2d4..05e6a6fba 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -287,6 +287,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index e7c1affc2..5e1acd735 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -337,6 +337,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index e966aa4b1..229b52e5b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -353,6 +353,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py index 6e290e799..2432c6010 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py @@ -326,6 +326,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py 
b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index 304fa8693..a9ce75a7b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -251,6 +251,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index f65f47fc2..8478a65fb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -353,6 +353,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py index 5af6dae25..88a05e09d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -225,6 +225,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index 86c922cda..4bf11ac24 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -260,6 +260,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py index 280b95984..83dc29324 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py @@ -224,6 +224,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py index d50d231d5..d1b7eec65 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py @@ -280,6 +280,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py index 78e0fa778..323ba2642 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py @@ -260,6 +260,7 @@ def main(): opts.frame_opts.snip_edges = False 
opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py index 904c1deae..1e638aa7d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py @@ -298,6 +298,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py index da2c6a39a..a39fdee54 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained.py @@ -224,6 +224,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py index 653c25e06..80604ef4a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py @@ -280,6 +280,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py index 494a34d97..0ff110370 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py @@ -381,6 +381,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py index 5d240cf30..a82f3562b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py @@ -260,6 +260,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py index 914107526..b98756a54 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py @@ -298,6 +298,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = 
kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py index c8301b2da..7116b10fb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py @@ -231,6 +231,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py index f2ac1914d..d714670cf 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py @@ -186,6 +186,7 @@ def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py index 04861ea37..298d1889b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py @@ -382,6 +382,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py index bc42e8d05..aa2dd17fb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -260,6 +260,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py index 883fdcbdd..999f7e0b4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py @@ -335,6 +335,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index a0f54b6e1..e27fb4e63 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -320,6 +320,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + 
opts.mel_opts.high_freq = -400 log_interval = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py index 129497d5a..3ce2953c3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py @@ -225,6 +225,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index 64b38c9d5..c29b8d8c9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -260,6 +260,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py index fde724866..b3dfab64a 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py @@ -196,6 +196,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py index 3888d3544..0cd876551 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py @@ -224,6 +224,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py index 6f2cbaabd..92dea3aa1 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py @@ -280,6 +280,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py index 981039b8f..5c6956324 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py @@ -262,6 +262,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py index a06d6d684..7698ada79 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py @@ -298,6 +298,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + 
opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py index c2413f5de..4d9bbf4b1 100755 --- a/egs/librispeech/ASR/transducer/pretrained.py +++ b/egs/librispeech/ASR/transducer/pretrained.py @@ -235,6 +235,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py index 5898dd0f5..3b86e319e 100755 --- a/egs/librispeech/ASR/transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -247,6 +247,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py index b69b347ef..2de4182f1 100755 --- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py @@ -247,6 +247,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py index 4f29d6f1f..83094ea51 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py @@ -247,6 +247,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained.py b/egs/librispeech/ASR/zipformer/jit_pretrained.py index a41fbc1c9..52dfd3fb6 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained.py @@ -222,6 +222,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py index 660a4bfc6..fcd07ae34 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py @@ -285,6 +285,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py b/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py index d4ceacefd..eade5a854 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py @@ -167,6 +167,7 @@ def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return 
OnlineFbank(opts) diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py index 44546cae5..dd47c0eb6 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py @@ -318,6 +318,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py index e7c4f40ee..e011c4b24 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py @@ -413,6 +413,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained.py b/egs/librispeech/ASR/zipformer/onnx_pretrained.py index 334376093..662392b5f 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained.py @@ -369,6 +369,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py index eb5cee9cd..ecca758f2 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc.py @@ -161,6 +161,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py index 683a7dc20..a77c3bf2a 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py @@ -225,6 +225,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 logging.info(f"Loading H from {args.H}") H = kaldifst.StdVectorFst.read(args.H) diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py index 0b94bfa65..6ef944514 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py @@ -223,6 +223,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 logging.info(f"Loading HL from {args.HL}") HL = kaldifst.StdVectorFst.read(args.HL) diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py index 93569142a..ccb3107ea 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py @@ -223,6 +223,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + 
opts.mel_opts.high_freq = -400 logging.info(f"Loading HLG from {args.HLG}") HLG = kaldifst.StdVectorFst.read(args.HLG) diff --git a/egs/librispeech/ASR/zipformer/pretrained.py b/egs/librispeech/ASR/zipformer/pretrained.py index 3104b6084..de0652893 100755 --- a/egs/librispeech/ASR/zipformer/pretrained.py +++ b/egs/librispeech/ASR/zipformer/pretrained.py @@ -303,6 +303,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/zipformer/pretrained_ctc.py b/egs/librispeech/ASR/zipformer/pretrained_ctc.py index 9dff2e6fc..408d13576 100755 --- a/egs/librispeech/ASR/zipformer/pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/pretrained_ctc.py @@ -304,6 +304,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py index c9ef16ffa..6990c90a0 100755 --- a/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py +++ b/egs/librispeech/ASR/zipformer_mmi/jit_pretrained.py @@ -259,6 +259,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/librispeech/ASR/zipformer_mmi/pretrained.py b/egs/librispeech/ASR/zipformer_mmi/pretrained.py index 3ba4da5dd..1e7afc777 100755 --- a/egs/librispeech/ASR/zipformer_mmi/pretrained.py +++ b/egs/librispeech/ASR/zipformer_mmi/pretrained.py @@ -282,6 +282,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/mgb2/ASR/conformer_ctc/pretrained.py b/egs/mgb2/ASR/conformer_ctc/pretrained.py index d30ca98d8..0ab2af527 100755 --- a/egs/mgb2/ASR/conformer_ctc/pretrained.py +++ b/egs/mgb2/ASR/conformer_ctc/pretrained.py @@ -287,6 +287,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py index 77ba0873b..81a16f0ff 100755 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py @@ -249,6 +249,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/multi_zh-hans/ASR/zipformer/pretrained.py b/egs/multi_zh-hans/ASR/zipformer/pretrained.py index 69ff382da..c15db11f7 100755 --- a/egs/multi_zh-hans/ASR/zipformer/pretrained.py +++ b/egs/multi_zh-hans/ASR/zipformer/pretrained.py @@ -303,6 +303,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/multi_zh_en/ASR/zipformer/pretrained.py b/egs/multi_zh_en/ASR/zipformer/pretrained.py index 676272e1f..2fcde550b 100755 --- 
a/egs/multi_zh_en/ASR/zipformer/pretrained.py +++ b/egs/multi_zh_en/ASR/zipformer/pretrained.py @@ -306,6 +306,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py index 3305f5bd3..8a74ee745 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py @@ -248,6 +248,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py index a23e2a04f..8c966a2f6 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/jit_pretrained.py @@ -226,6 +226,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py index f365986f6..6e07b5949 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/pretrained.py @@ -261,6 +261,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py index 8a89c3578..9e58fed00 100644 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py @@ -256,6 +256,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py index 81afd6a4e..5300fe764 100644 --- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py +++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py @@ -270,6 +270,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py index 3fdf3b855..0d77bc512 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py @@ -196,6 +196,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py index 98c746ce5..f06c8c211 100644 --- 
a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py @@ -196,6 +196,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py index f90dd2b43..aee1a2175 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py @@ -285,6 +285,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py index c3d67ad92..642de72d7 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -238,6 +238,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py index c31db6859..cca26feb0 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained-streaming.py @@ -327,6 +327,7 @@ def create_streaming_feature_extractor() -> OnlineFeature: opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 return OnlineFbank(opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py index c784853ee..4b4ddd332 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/onnx_pretrained.py @@ -376,6 +376,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = args.sample_rate opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py index 1cac20435..17428e19d 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -238,6 +238,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 3a4dc3cb8..27a9b1714 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -378,6 +378,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + 
opts.mel_opts.high_freq = -400 log_interval = 100 diff --git a/egs/wenetspeech/ASR/zipformer/streaming_decode.py b/egs/wenetspeech/ASR/zipformer/streaming_decode.py index 94c5fae5f..96f339b07 100755 --- a/egs/wenetspeech/ASR/zipformer/streaming_decode.py +++ b/egs/wenetspeech/ASR/zipformer/streaming_decode.py @@ -572,6 +572,7 @@ def decode_dataset( opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 log_interval = 100 diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py index 74a2210c3..2c106c4cb 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py @@ -249,6 +249,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py index d05bafcfb..6995ff2ff 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py @@ -260,6 +260,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/yesno/ASR/tdnn/jit_pretrained.py b/egs/yesno/ASR/tdnn/jit_pretrained.py index 7581ecb83..e29415ffb 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained.py +++ b/egs/yesno/ASR/tdnn/jit_pretrained.py @@ -142,6 +142,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py index ff8c742af..72127aebd 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py +++ b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py @@ -164,6 +164,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 23 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py index 05ba74f9a..f8a057336 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py +++ b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py @@ -163,6 +163,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = sample_rate opts.mel_opts.num_bins = 23 + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/yesno/ASR/tdnn/onnx_pretrained.py b/egs/yesno/ASR/tdnn/onnx_pretrained.py index 72a1d69c8..968a9e9a8 100755 --- a/egs/yesno/ASR/tdnn/onnx_pretrained.py +++ b/egs/yesno/ASR/tdnn/onnx_pretrained.py @@ -186,6 +186,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py index 987c49de6..bea520998 100755 --- a/egs/yesno/ASR/tdnn/pretrained.py +++ b/egs/yesno/ASR/tdnn/pretrained.py 
@@ -164,6 +164,7 @@ def main(): opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = params.sample_rate opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 fbank = kaldifeat.Fbank(opts) From 716b82cc3ada9e39254acc93465ed85e53d05670 Mon Sep 17 00:00:00 2001 From: Karel Vesely Date: Fri, 5 Jan 2024 03:21:27 +0100 Subject: [PATCH 26/46] streaming_decode.py, relax the audio range from [-1,+1] to [-10,+10] (#1448) - some AudioTransform classes produce audio signals out of range [-1,+1] - Resample produced 1.0079 - The range [-10,+10] was chosen to still be able to reliably distinguish from the [-32k,+32k] signal... - this is related to : https://github.com/lhotse-speech/lhotse/issues/1254 --- .../streaming_decode.py | 7 ++++++- egs/aishell/ASR/zipformer/streaming_decode.py | 12 ++++++------ .../streaming_decode.py | 7 ++++++- egs/gigaspeech/ASR/zipformer/streaming_decode.py | 7 ++++++- .../streaming_decode.py | 8 +++++++- .../streaming_decode.py | 8 +++++++- .../lstm_transducer_stateless/streaming_decode.py | 8 +++++++- .../lstm_transducer_stateless3/streaming_decode.py | 8 +++++++- .../pruned_transducer_stateless/streaming_decode.py | 7 ++++++- .../pruned_transducer_stateless2/streaming_decode.py | 7 ++++++- .../pruned_transducer_stateless3/streaming_decode.py | 7 ++++++- .../pruned_transducer_stateless4/streaming_decode.py | 7 ++++++- .../pruned_transducer_stateless5/streaming_decode.py | 7 ++++++- .../streaming_decode.py | 7 ++++++- .../streaming_decode.py | 7 ++++++- egs/librispeech/ASR/zipformer/streaming_decode.py | 7 ++++++- .../pruned_transducer_stateless5/streaming_decode.py | 8 ++++++++ egs/wenetspeech/ASR/zipformer/streaming_decode.py | 12 ++++++------ 18 files changed, 114 insertions(+), 27 deletions(-) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index aa0e07c83..a4b5cd588 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -342,7 +342,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/aishell/ASR/zipformer/streaming_decode.py b/egs/aishell/ASR/zipformer/streaming_decode.py index f54ffbd3c..6a7ef2750 100755 --- a/egs/aishell/ASR/zipformer/streaming_decode.py +++ b/egs/aishell/ASR/zipformer/streaming_decode.py @@ -597,12 +597,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - if audio.max() > 1: - logging.warning( - f"The audio should be normalized to [-1, 1], audio.max : {audio.max()}." - f"Clipping to [-1, 1]." - ) - audio = np.clip(audio, -1, 1) + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." 
samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index 7252665a7..6a249dd3f 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -362,7 +362,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/gigaspeech/ASR/zipformer/streaming_decode.py b/egs/gigaspeech/ASR/zipformer/streaming_decode.py index 09df2935c..7cada8c9d 100755 --- a/egs/gigaspeech/ASR/zipformer/streaming_decode.py +++ b/egs/gigaspeech/ASR/zipformer/streaming_decode.py @@ -578,7 +578,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index 9b8b4cce2..12953c74c 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -681,8 +681,14 @@ def decode_dataset( assert len(audio.shape) == 2 assert audio.shape[0] == 1, "Should be single channel" assert audio.dtype == np.float32, audio.dtype + # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) feature = fbank(samples) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index aaed7d31f..ddc7dbef1 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -681,8 +681,14 @@ def decode_dataset( assert len(audio.shape) == 2 assert audio.shape[0] == 1, "Should be single channel" assert audio.dtype == np.float32, audio.dtype + # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... 
+ # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) feature = fbank(samples) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py index 03472e2c3..14cb0fdfe 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -673,8 +673,14 @@ def decode_dataset( assert len(audio.shape) == 2 assert audio.shape[0] == 1, "Should be single channel" assert audio.dtype == np.float32, audio.dtype + # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) feature = fbank(samples) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index c425b1f46..f57bdea67 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -673,8 +673,14 @@ def decode_dataset( assert len(audio.shape) == 2 assert audio.shape[0] == 1, "Should be single channel" assert audio.dtype == np.float32, audio.dtype + # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) feature = fbank(samples) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index 8586c66d6..4726d9fad 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -359,7 +359,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." 
samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index d17c3467a..381561359 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -361,7 +361,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 5e1acd735..9113cfaa9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -362,7 +362,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index 229b52e5b..f205ad42f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -378,7 +378,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 8478a65fb..1d980f10e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -378,7 +378,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." 
samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py index e27fb4e63..0961e0d7b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -345,7 +345,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py index 2904f086c..cc2787d76 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/streaming_decode.py @@ -345,7 +345,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py index 904caf8af..8087c1460 100755 --- a/egs/librispeech/ASR/zipformer/streaming_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_decode.py @@ -577,7 +577,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 27a9b1714..b396aa9b8 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -402,6 +402,14 @@ def decode_dataset( assert audio.shape[0] == 1, "Should be single channel" assert audio.dtype == np.float32, audio.dtype + # The trained model is using normalized samples + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." 
+ samples = torch.from_numpy(audio).squeeze(0) fbank = Fbank(opts) diff --git a/egs/wenetspeech/ASR/zipformer/streaming_decode.py b/egs/wenetspeech/ASR/zipformer/streaming_decode.py index 96f339b07..cb2cf7d35 100755 --- a/egs/wenetspeech/ASR/zipformer/streaming_decode.py +++ b/egs/wenetspeech/ASR/zipformer/streaming_decode.py @@ -597,12 +597,12 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - if audio.max() > 1: - logging.warning( - f"The audio should be normalized to [-1, 1], audio.max : {audio.max()}." - f"Clipping to [-1, 1]." - ) - audio = np.clip(audio, -1, 1) + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) From b9b56eb879e694684156b6ba441a1c665ff26e19 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 8 Jan 2024 14:28:07 +0800 Subject: [PATCH 27/46] Minor fixes to the VCTK data prep scripts (#1441) * Update prepare.sh --- egs/vctk/TTS/prepare.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/vctk/TTS/prepare.sh b/egs/vctk/TTS/prepare.sh index 87150ad31..152c7b168 100755 --- a/egs/vctk/TTS/prepare.sh +++ b/egs/vctk/TTS/prepare.sh @@ -7,6 +7,7 @@ set -eou pipefail stage=0 stop_stage=100 +use_edinburgh_vctk_url=true dl_dir=$PWD/download @@ -44,7 +45,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then # ln -sfv /path/to/VCTK $dl_dir/VCTK # if [ ! -d $dl_dir/VCTK ]; then - lhotse download vctk $dl_dir + lhotse download vctk --use-edinburgh-vctk-url ${use_edinburgh_vctk_url} $dl_dir fi fi @@ -54,7 +55,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # to $dl_dir/VCTK mkdir -p data/manifests if [ ! 
-e data/manifests/.vctk.done ]; then - lhotse prepare vctk --use-edinburgh-vctk-url true $dl_dir/VCTK data/manifests + lhotse prepare vctk --use-edinburgh-vctk-url ${use_edinburgh_vctk_url} $dl_dir/VCTK data/manifests touch data/manifests/.vctk.done fi fi From 5445ea6df6e250f8f1e9b2df3bb9e54afe104f97 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 8 Jan 2024 15:09:21 +0800 Subject: [PATCH 28/46] Use shuffled LibriSpeech cuts instead (#1450) * use shuffled LibriSpeech cuts instead * leave the old code in comments for reference --- egs/librispeech/ASR/conformer_ctc3/train.py | 15 ++++++++++++--- egs/librispeech/ASR/conformer_mmi/train.py | 16 +++++++++++++--- .../ASR/lstm_transducer_stateless3/train.py | 15 ++++++++++++--- .../ASR/pruned2_knowledge/train.py | 15 ++++++++++++--- .../train.py | 19 ++++++++++++++++--- .../train.py | 11 ++++++++--- egs/librispeech/ASR/zipformer/train.py | 15 ++++++++++++--- egs/librispeech/ASR/zipformer_mmi/train.py | 8 +++++--- 8 files changed, 90 insertions(+), 24 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py index 2cd223945..a2f1125ca 100755 --- a/egs/librispeech/ASR/conformer_ctc3/train.py +++ b/egs/librispeech/ASR/conformer_ctc3/train.py @@ -952,10 +952,19 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts + # strictly speaking, shuffled training cuts should be used instead + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index f9f80632e..fe8c85f61 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ b/egs/librispeech/ASR/conformer_mmi/train.py @@ -771,10 +771,20 @@ def run(rank, world_size, args): valid_ali = None librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index 6ef4c9860..2c1cef3a3 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ 
b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -989,10 +989,19 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index a4899f7bd..931341cc4 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -817,10 +817,19 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index 2d915ff87..e1bdce49d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1038,13 +1038,26 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) + assert not ( + params.mini_libri and params.full_libri + ), f"Cannot set both mini-libri and full-libri flags to True, now mini-libri {params.mini_libri} and full-libri {params.full_libri}" + if params.mini_libri: train_cuts = librispeech.train_clean_5_cuts() else: - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep 
only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py index 565dc7a16..1642ef4b7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py @@ -1150,10 +1150,15 @@ def run(rank, world_size, args): librispeech = LibriSpeech(manifest_dir=args.manifest_dir) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + else: + train_cuts = librispeech.train_clean_100_cuts() train_cuts = filter_short_and_long_utterances(train_cuts, sp) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 7009f3346..3ccf7d2f1 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -1174,10 +1174,19 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py index 4b50acdde..dd8949523 100755 --- a/egs/librispeech/ASR/zipformer_mmi/train.py +++ b/egs/librispeech/ASR/zipformer_mmi/train.py @@ -990,11 +990,13 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - # train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - # train_cuts += librispeech.train_clean_360_cuts() - # train_cuts += librispeech.train_other_500_cuts() train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets else: train_cuts = librispeech.train_clean_100_cuts() From e2fcb42f5f176d9e39eb38506ab99d0a3adaf202 Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Tue, 9 Jan 2024 15:41:37 +0800 Subject: [PATCH 29/46] fix typo (#1455) --- .../RNN-LM/librispeech/lm-training.rst | 27 +++++++------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/docs/source/recipes/RNN-LM/librispeech/lm-training.rst 
b/docs/source/recipes/RNN-LM/librispeech/lm-training.rst index 46499a374..e0c90f2a6 100644 --- a/docs/source/recipes/RNN-LM/librispeech/lm-training.rst +++ b/docs/source/recipes/RNN-LM/librispeech/lm-training.rst @@ -4,7 +4,7 @@ Train an RNN language model ====================================== If you have enough text data, you can train a neural network language model (NNLM) to improve -the WER of your E2E ASR system. This tutorial shows you how to train an RNNLM from +the WER of your E2E ASR system. This tutorial shows you how to train an RNNLM from scratch. .. HINT:: @@ -15,23 +15,23 @@ scratch. .. note:: This tutorial is based on the LibriSpeech recipe. Please check it out for the necessary - python scripts for this tutorial. We use the LibriSpeech LM-corpus as the LM training set + python scripts for this tutorial. We use the LibriSpeech LM-corpus as the LM training set for illustration purpose. You can also collect your own data. The data format is quite simple: each line should contain a complete sentence, and words should be separated by space. -First, let's download the training data for the RNNLM. This can be done via the +First, let's download the training data for the RNNLM. This can be done via the following command: .. code-block:: bash - $ wget https://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz + $ wget https://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz $ gzip -d librispeech-lm-norm.txt.gz As we are training a BPE-level RNNLM, we need to tokenize the training text, which requires a BPE tokenizer. This can be achieved by executing the following command: .. code-block:: bash - + $ # if you don't have the BPE $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 $ cd icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500 @@ -56,11 +56,11 @@ sentence length. --out-statistics data/lang_bpe_500/lm_data_stats.txt -The aforementioned steps can be repeated to create a a validation set for you RNNLM. Let's say -you have a validation set in ``valid.txt``, you can just set ``--lm-data valid.txt`` +The aforementioned steps can be repeated to create a a validation set for you RNNLM. Let's say +you have a validation set in ``valid.txt``, you can just set ``--lm-data valid.txt`` and ``--lm-archive data/lang_bpe_500/lm-data-valid.pt`` when calling ``./local/prepare_lm_training_data.py``. -After completing the previous steps, the training and testing sets for training RNNLM are ready. +After completing the previous steps, the training and testing sets for training RNNLM are ready. The next step is to train the RNNLM model. The training command is as follows: .. code-block:: bash @@ -77,7 +77,7 @@ The next step is to train the RNNLM model. The training command is as follows: --use-fp16 0 \ --tie-weights 1 \ --embedding-dim 2048 \ - --hidden_dim 2048 \ + --hidden-dim 2048 \ --num-layers 3 \ --batch-size 300 \ --lm-data rnn_lm/data/lang_bpe_500/sorted_lm_data.pt \ @@ -93,12 +93,3 @@ The next step is to train the RNNLM model. The training command is as follows: .. note:: The training of RNNLM can take a long time (usually a couple of days). 
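For intuition about the model-size flags above (``--embedding-dim 2048``, ``--hidden-dim 2048``, ``--num-layers 3``, ``--tie-weights 1``), the sketch below shows roughly what such an RNNLM looks like in PyTorch. It is a simplified illustration only, not the actual ``rnn_lm`` implementation; the class name and the default ``vocab_size=500`` are placeholders chosen to match the BPE vocabulary used in this tutorial.

.. code-block:: python

    # Rough, simplified sketch of the RNNLM implied by the flags above.
    # Not the real icefall rnn_lm model; for illustration only.
    import torch
    import torch.nn as nn

    class SketchRnnLm(nn.Module):
        def __init__(self, vocab_size=500, embedding_dim=2048,
                     hidden_dim=2048, num_layers=3, tie_weights=True):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, embedding_dim)
            self.rnn = nn.LSTM(embedding_dim, hidden_dim,
                               num_layers=num_layers, batch_first=True)
            self.output = nn.Linear(hidden_dim, vocab_size)
            if tie_weights:
                # --tie-weights 1: share the input embedding with the output
                # projection; this requires embedding_dim == hidden_dim.
                assert embedding_dim == hidden_dim
                self.output.weight = self.embedding.weight

        def forward(self, tokens: torch.Tensor) -> torch.Tensor:
            # tokens: (batch, seq_len) of BPE ids ->
            # (batch, seq_len, vocab_size) logits over the next token
            out, _ = self.rnn(self.embedding(tokens))
            return self.output(out)

Tying the input and output embeddings reduces the parameter count and usually improves perplexity, which is why it is enabled in the training command above.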
- - - - - - - - - From 398401ed277d4f895f624a95919c57edbbde4cba Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 14 Jan 2024 14:38:41 +0800 Subject: [PATCH 30/46] Update kaldifeat installation doc (#1460) --- docs/source/for-dummies/environment-setup.rst | 4 ++-- docs/source/for-dummies/model-export.rst | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/for-dummies/environment-setup.rst b/docs/source/for-dummies/environment-setup.rst index 0cb8ecc1d..a68e9d3ed 100644 --- a/docs/source/for-dummies/environment-setup.rst +++ b/docs/source/for-dummies/environment-setup.rst @@ -66,13 +66,13 @@ to install dependencies of `icefall`_: pip install torch==2.0.0+cpu torchaudio==2.0.0+cpu -f https://download.pytorch.org/whl/torch_stable.html - # If you are using macOS or Windows, please use the following command to install torch and torchaudio + # If you are using macOS, please use the following command to install torch and torchaudio # pip install torch==2.0.0 torchaudio==2.0.0 -f https://download.pytorch.org/whl/torch_stable.html # Now install k2 # Please refer to https://k2-fsa.github.io/k2/installation/from_wheels.html#linux-cpu-example - pip install k2==1.24.3.dev20230726+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu.html + pip install k2==1.24.4.dev20231220+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu.html # Install the latest version of lhotse diff --git a/docs/source/for-dummies/model-export.rst b/docs/source/for-dummies/model-export.rst index 079ebc712..352a0dc90 100644 --- a/docs/source/for-dummies/model-export.rst +++ b/docs/source/for-dummies/model-export.rst @@ -85,7 +85,7 @@ We can also use it to decode files with the following command: # Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html # for how to install kaldifeat - pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html + pip install kaldifeat==1.25.3.dev20231221+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html ./tdnn/pretrained.py \ --checkpoint ./tdnn/exp/pretrained.pt \ @@ -162,7 +162,7 @@ To use ``tdnn/exp/cpu_jit.pt`` with `icefall`_ to decode files, we can use: # Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html # for how to install kaldifeat - pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html + pip install kaldifeat==1.25.3.dev20231221+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html ./tdnn/jit_pretrained.py \ @@ -249,7 +249,7 @@ To use the generated ONNX model files for decoding with `onnxruntime`_, we can u # Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html # for how to install kaldifeat - pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html + pip install kaldifeat==1.25.3.dev20231221+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html ./tdnn/onnx_pretrained.py \ --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ From 7bdde9174c7c95a32a10d6dcbc3764ecb4873b1d Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 16 Jan 2024 21:08:35 +0800 Subject: [PATCH 31/46] A Zipformer recipe with Byte-level BPE for Aishell-1 (#1464) * init commit * Update train.py * Update decode.py * Update RESULTS.md * added `vocab_size` * removed unused softlinks * added scripts for testing pretrained models * set `bpe_model` as required * re-org the bbpe recipe for aishell --- 
egs/aishell/ASR/RESULTS.md | 56 +- egs/aishell/ASR/zipformer/decode_bbpe.py | 840 ++++++++++++++++ .../ASR/zipformer/jit_pretrained_bbpe.py | 279 ++++++ egs/aishell/ASR/zipformer/pretrained_bbpe.py | 403 ++++++++ egs/aishell/ASR/zipformer/train_bbpe.py | 942 ++++++++++++++++++ 5 files changed, 2518 insertions(+), 2 deletions(-) create mode 100755 egs/aishell/ASR/zipformer/decode_bbpe.py create mode 100755 egs/aishell/ASR/zipformer/jit_pretrained_bbpe.py create mode 100755 egs/aishell/ASR/zipformer/pretrained_bbpe.py create mode 100755 egs/aishell/ASR/zipformer/train_bbpe.py diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index 0b22f41a1..ff9504274 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -2,9 +2,61 @@ ### Aishell training result (Stateless Transducer) +#### Zipformer (Byte-level BPE) + +[./zipformer](./zipformer/) + +It's reworked Zipformer with Pruned RNNT loss, trained with Byte-level BPE, `vocab_size` set to 500. + +##### normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M + +| | test | dev | comment | +|------------------------|------|------|-----------------------------------------| +| greedy search | 4.54 | 4.31 | --epoch 40 --avg 10 | +| modified beam search | 4.37 | 4.11 | --epoch 40 --avg 10 | +| fast beam search | 4.43 | 4.17 | --epoch 40 --avg 10 | + +```bash +./prepare.sh + +export CUDA_VISIBLE_DEVICES="0,1" + +./zipformer/train_bbpe.py \ + --world-size 2 \ + --num-epochs 40 \ + --start-epoch 1 \ + --use-fp16 1 \ + --context-size 2 \ + --enable-musan 0 \ + --exp-dir zipformer/exp_bbpe \ + --max-duration 1000 \ + --enable-musan 0 \ + --base-lr 0.045 \ + --lr-batches 7500 \ + --lr-epochs 10 \ + --spec-aug-time-warp-factor 20 +``` + +Command for decoding is: +```bash +for m in greedy_search modified_beam_search fast_beam_search ; do + ./zipformer/decode_bbpe.py \ + --epoch 40 \ + --avg 10 \ + --exp-dir ./zipformer_bbpe/exp \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --context-size 2 \ + --decoding-method $m +done +``` +Pretrained models, training logs, decoding logs, tensorboard and decoding results +are available at + + + #### Zipformer (Non-streaming) -[./zipformer](./zipformer) +[./zipformer](./zipformer/) It's reworked Zipformer with Pruned RNNT loss. **Caution**: It uses `--context-size=1`. @@ -260,7 +312,7 @@ done Pretrained models, training logs, decoding logs, and decoding results are available at -#### Pruned transducer stateless 7 (zipformer) +#### Pruned transducer stateless 7 (Byte-level BPE) See diff --git a/egs/aishell/ASR/zipformer/decode_bbpe.py b/egs/aishell/ASR/zipformer/decode_bbpe.py new file mode 100755 index 000000000..1ec10b059 --- /dev/null +++ b/egs/aishell/ASR/zipformer/decode_bbpe.py @@ -0,0 +1,840 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Mingshuang Luo, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./zipformer/decode_bbpe.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp_bbpe \ + --lang-dir data/lang_bbpe_500 \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) modified beam search +./zipformer/decode_bbpe.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp_bbpe \ + --lang-dir data/lang_bbpe_500 \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(3) fast beam search (trivial_graph) +./zipformer/decode_bbpe.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp_bbpe \ + --lang-dir data/lang_bbpe_500 \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(4) fast beam search (LG) +./zipformer/decode_bbpe.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp_bbpe \ + --lang-dir data/lang_bbpe_500 \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --max-duration 600 \ + --decoding-method fast_beam_search_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest oracle WER) +./zipformer/decode_bbpe.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp_bbpe \ + --lang-dir data/lang_bbpe_500 \ + --bpe-model data/lang_bbpe_500/bbpe.model \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from train import add_model_arguments, get_model, get_params + +from icefall import byte_encode, smart_byte_decode, tokenize_by_CJK_char +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. 
If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer_bbpe/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_500/bbpe.model", + help="Path to the byte BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bbpe_500/", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_LG + - fast_beam_search_nbest_oracle + If you use fast_beam_search_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--ilme-scale", + type=float, + default=0.2, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for the internal language model estimation. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). 
+ """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + lexicon: Lexicon, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + x, x_lens = model.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "fast_beam_search_LG": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ilme_scale=params.ilme_scale, + ) + for hyp in hyp_tokens: + hyps.append([lexicon.word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + ref_texts = [] + for tx in supervisions["text"]: + ref_texts.append(byte_encode(tokenize_by_CJK_char(tx))) + + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(ref_texts), + nbest_scale=params.nbest_scale, + 
blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + blank_penalty=params.blank_penalty, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(smart_byte_decode(sp.decode(hyp)).split()) + + key = f"blank_penalty_{params.blank_penalty}" + if params.decoding_method == "greedy_search": + return {"greedy_search_" + key: hyps} + elif "fast_beam_search" in params.decoding_method: + key += f"_beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ilme_scale_{params.ilme_scale}" + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}_" + key: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + lexicon: + directory containing the lexicon. + sp: + SentencePiece model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
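# [Editor's note, not part of the diff] len(dl) can raise TypeError when the
# batch sampler (e.g. a dynamic bucketing sampler) cannot know the number of
# batches in advance; "?" is then shown in the progress logs below.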
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + lexicon=lexicon, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = "".join(ref_text.split()) + + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + + results_char = [] + for res in results: + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results_char, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + "fast_beam_search", + "fast_beam_search_LG", + "fast_beam_search_nbest_oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
+ params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ilme_scale_{params.ilme_scale}" + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bbpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + lexicon = Lexicon(params.lang_dir) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + 
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if "LG" in params.decoding_method: + lexicon = Lexicon(params.lang_dir) + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + aishell = AishellAsrDataModule(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." + ) + return T > 0 + + dev_cuts = aishell.valid_cuts() + dev_cuts = dev_cuts.filter(remove_short_utt) + dev_dl = aishell.valid_dataloaders(dev_cuts) + + test_cuts = aishell.test_cuts() + test_cuts = test_cuts.filter(remove_short_utt) + test_dl = aishell.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + lexicon=lexicon, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/zipformer/jit_pretrained_bbpe.py b/egs/aishell/ASR/zipformer/jit_pretrained_bbpe.py new file mode 100755 index 000000000..cd16284f7 --- /dev/null +++ b/egs/aishell/ASR/zipformer/jit_pretrained_bbpe.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. 
+You can use the following command to get the exported models: + +./zipformer/export.py \ + --exp-dir ./zipformer_bbpe/exp \ + --bpe ./data/lang_bbpe_500/bbpe.model \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +Usage of this script: + +./zipformer/jit_pretrained.py \ + --nn-model-filename ./zipformer_bbpe/exp/cpu_jit.pt \ + --bpe ./data/lang_bbpe_500/bbpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + +from icefall import smart_byte_decode + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + required=True, + help="""Path to the bbpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. 
+ """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = model.decoder.blank_id + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + opts.mel_opts.high_freq = -400 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + features=features, + feature_lengths=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = smart_byte_decode(sp.decode(hyp)) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if 
__name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/zipformer/pretrained_bbpe.py b/egs/aishell/ASR/zipformer/pretrained_bbpe.py new file mode 100755 index 000000000..387bef98a --- /dev/null +++ b/egs/aishell/ASR/zipformer/pretrained_bbpe.py @@ -0,0 +1,403 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp_bbpe \ + --tokens ./data/lang_bbpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp_bbpe \ + --causal 1 \ + --tokens ./data/lang_bbpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +Usage of this script: + +- For non-streaming model: + +(1) greedy search +./zipformer/pretrained_bbpe.py \ + --checkpoint ./zipformer/exp_bbpe/pretrained.pt \ + --bpe ./data/lang_bbpe_500/bbpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zipformer/pretrained_bbpe.py \ + --checkpoint ./zipformer/exp_bbpe/pretrained.pt \ + --bpe ./data/lang_bbpe_500/bbpe.model \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zipformer/pretrained_bbpe.py \ + --checkpoint ./zipformer/exp_bbpe/pretrained.pt \ + --bpe ./data/lang_bbpe_500/bbpe.model \ + --method fast_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +- For streaming model: + +(1) greedy search +./zipformer/pretrained_bbpe.py \ + --checkpoint ./zipformer/exp_bbpe/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --bpe ./data/lang_bbpe_500/bbpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) modified beam search +./zipformer/pretrained_bbpe.py \ + --checkpoint ./zipformer/exp_bbpe/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --bpe ./data/lang_bbpe_500/bbpe.model \ + --method modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) fast beam search +./zipformer/pretrained_bbpe.py \ + --checkpoint ./zipformer/exp_bbpe/pretrained.pt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --bpe ./data/lang_bbpe_500/bbpe.model \ + --method fast_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav + + +You can also use `./zipformer/exp_bbpe/epoch-xx.pt`. 
+ +Note: ./zipformer/exp_bbpe/pretrained.pt is generated by ./zipformer/export_bbpe.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall import smart_byte_decode + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + required=True, + help="""Path to the bbpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + + logging.info("Creating model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + # model forward + encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(smart_byte_decode(hyp).split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( 
+ model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(smart_byte_decode(sp.decode(hyp)).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/zipformer/train_bbpe.py b/egs/aishell/ASR/zipformer/train_bbpe.py new file mode 100755 index 000000000..a2bf96b29 --- /dev/null +++ b/egs/aishell/ASR/zipformer/train_bbpe.py @@ -0,0 +1,942 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +./zipformer/train_bbpe.py \ + --world-size 8 \ + --num-epochs 12 \ + --start-epoch 1 \ + --exp-dir zipformer/exp_bbpe \ + --max-duration 350 + +# For mix precision training: + +./zipformer/train_bbpe.py \ + --world-size 8 \ + --num-epochs 12 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp_bbpe \ + --max-duration 750 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from typing import Optional, Tuple, Union + +import k2 +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from train import ( + LRSchedulerType, + add_model_arguments, + get_adjusted_batch_count, + get_model, + get_params, + load_checkpoint_if_available, + save_checkpoint, + set_batch_count, +) + +from icefall import byte_encode, diagnostics +from icefall.checkpoint import remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, + tokenize_by_CJK_char, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + 
"--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer_bbpe/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bbpe_500/bbpe.model", + help="Path to the Byte BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="""The context size in the decoder. 1 means bigram; 2 means tri-gram""", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="""The prune range for rnnt loss, it means how many symbols(context) + we are using to compute the loss""", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="""The scale to smooth the loss with lm + (output of prediction network) part.""", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="""The scale to smooth the loss with am (output of encoder network) part.""", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="""To get pruning ranges, we will calculate a simple version + loss(joiner is just addition), this simple loss also uses for + training (as a regularization item). We will scale the simple loss + with this parameter before adding to the final loss.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. 
The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, _ = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
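# [Editor's note, not part of the diff] The entries recorded below are
# per-batch sums (reduction="sum" in the losses above); MetricsTracker
# normalizes them by info["frames"] when reporting, e.g.
# loss_value = tot_loss["loss"] / tot_loss["frames"] in
# compute_validation_loss() and train_one_epoch().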
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
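# [Editor's note, not part of the diff] With --use-fp16 the loss is scaled by
# GradScaler before backward(); scaler.step() unscales the gradients and skips
# the update if they contain inf/NaN, and scaler.update() adjusts the scale,
# which is also monitored further below to detect a collapsing grad scale.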
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. 
+ The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bbpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + aishell = AishellAsrDataModule(args) + + train_cuts = aishell.train_cuts() + valid_cuts = aishell.valid_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 15 seconds + # + # Caution: There is a reason to select 15.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 15.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. 
" + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + def tokenize_and_encode_text(c: Cut): + # Text normalize for each sample + text = c.supervisions[0].text + text = byte_encode(tokenize_by_CJK_char(text)) + c.supervisions[0].text = text + return c + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_cuts = train_cuts.map(tokenize_and_encode_text) + + valid_cuts = valid_cuts.map(tokenize_and_encode_text) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = aishell.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_dl = aishell.valid_dataloaders(valid_cuts) + + if False and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The sentence piece model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() From 5dfc3ed7f995f30a4b64e57071a095566f8126ea Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Sun, 21 Jan 2024 02:10:42 +0800 Subject: [PATCH 32/46] Fix buffer size of DynamicBucketingSampler (#1468) * Fix buffer size * Fix for flake8 --------- Co-authored-by: yifanyeung --- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 3 ++- egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 ++ .../ASR/transducer_stateless_modified-2/asr_datamodule.py | 2 ++ .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 2 ++ .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 3 ++- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 3 ++- .../ASR_v2/pruned_transducer_stateless7/asr_datamodule.py | 2 ++ egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py | 2 ++ egs/ami/SURT/dprnn_zipformer/asr_datamodule.py | 2 ++ .../ASR/pruned_transducer_stateless7/asr_datamodule.py | 2 ++ .../commonvoice_fr.py | 2 ++ egs/csj/ASR/local/utils/asr_datamodule.py | 2 ++ egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py | 2 ++ .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 2 ++ egs/gigaspeech/ASR/zipformer/asr_datamodule.py | 2 ++ egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py | 2 ++ egs/libriheavy/ASR/zipformer/asr_datamodule.py | 2 ++ egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py | 2 ++ egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py | 2 ++ .../ASR/pruned_transducer_stateless3/asr_datamodule.py | 6 ++++++ .../ASR/pruned_transducer_stateless7/gigaspeech.py | 2 ++ egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 ++ egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py | 2 ++ egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py | 2 ++ egs/ljspeech/TTS/vits/tts_datamodule.py | 2 ++ egs/mgb2/ASR/conformer_ctc/asr_datamodule.py | 2 ++ egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py | 2 ++ egs/multi_zh_en/ASR/zipformer/asr_datamodule.py | 2 ++ .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 2 ++ egs/swbd/ASR/conformer_ctc/asr_datamodule.py | 3 ++- .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 3 ++- egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py | 2 ++ egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 ++ egs/vctk/TTS/vits/tts_datamodule.py 
| 2 ++ .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 3 ++- .../ASR/pruned_transducer_stateless5/asr_datamodule.py | 2 ++ egs/yesno/ASR/tdnn/asr_datamodule.py | 2 ++ 37 files changed, 78 insertions(+), 6 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py index d491996b2..e29dd8ab5 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -288,8 +288,9 @@ class Aidatatang_200zhAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, - buffer_size=50000, ) else: logging.info("Using SimpleCutSampler.") diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py index 6abe6c084..aacbd153d 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -275,6 +275,8 @@ class AishellAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py index cd8dd821c..ed453afd2 100644 --- a/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/asr_datamodule.py @@ -226,6 +226,8 @@ class AsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py index 8f6a88f59..f9cdfb621 100644 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -296,6 +296,8 @@ class AiShell2AsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py index e6db2651f..c10456da5 100644 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -306,7 +306,8 @@ class Aishell4AsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=100000, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py index 5ad80817a..410741215 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ 
b/egs/alimeeting/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -288,7 +288,8 @@ class AlimeetingAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=30000, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py index 9d288218a..6b56c8a6a 100644 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py @@ -263,6 +263,8 @@ class AlimeetingAsrDataModule: max_cuts=self.args.max_cuts, shuffle=False, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) logging.info("About to create train dataloader") diff --git a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py index 79474f1d8..554facfc1 100644 --- a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -269,6 +269,8 @@ class AmiAsrDataModule: max_cuts=self.args.max_cuts, shuffle=False, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) logging.info("About to create train dataloader") diff --git a/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py b/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py index 1549c1631..ea8b62242 100644 --- a/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py +++ b/egs/ami/SURT/dprnn_zipformer/asr_datamodule.py @@ -254,6 +254,8 @@ class AmiAsrDataModule: max_cuts=self.args.max_cuts, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py index 546e9f9dd..c40d9419b 100644 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -308,6 +308,8 @@ class CommonVoiceAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py index cafa4111d..79cf86b84 100644 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py @@ -310,6 +310,8 @@ class CommonVoiceAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/csj/ASR/local/utils/asr_datamodule.py b/egs/csj/ASR/local/utils/asr_datamodule.py index 042b6ecbf..7bf7bdef0 100644 --- a/egs/csj/ASR/local/utils/asr_datamodule.py +++ 
b/egs/csj/ASR/local/utils/asr_datamodule.py @@ -336,6 +336,8 @@ class CSJAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py index a93e224d5..569978424 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py @@ -261,6 +261,8 @@ class GigaSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index b5b27ce95..40339365c 100644 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -294,6 +294,8 @@ class GigaSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py index 6adfdbfbb..850ab7c10 100644 --- a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py +++ b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py @@ -311,6 +311,8 @@ class GigaSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py b/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py index c1abdbdb5..500df9ea4 100644 --- a/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py +++ b/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py @@ -256,6 +256,8 @@ class LibriCssAsrDataModule: max_cuts=self.args.max_cuts, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/libriheavy/ASR/zipformer/asr_datamodule.py b/egs/libriheavy/ASR/zipformer/asr_datamodule.py index df761c1b8..e23c9b1b7 100644 --- a/egs/libriheavy/ASR/zipformer/asr_datamodule.py +++ b/egs/libriheavy/ASR/zipformer/asr_datamodule.py @@ -310,6 +310,8 @@ class LibriHeavyAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py index 690003377..1a4c9a532 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py @@ -341,6 +341,8 @@ class LibriHeavyAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + 
shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py index ee7556e49..be36c06b6 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py @@ -286,6 +286,8 @@ class LibriSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, bucket_method="equal_duration", drop_last=True, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py index 057624272..87c62789e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py @@ -223,6 +223,8 @@ class AsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) @@ -256,6 +258,8 @@ class AsrDataModule: max_duration=self.args.max_duration, shuffle=False, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=False, ) logging.info("About to create dev dataloader") @@ -282,6 +286,8 @@ class AsrDataModule: max_duration=self.args.max_duration, shuffle=False, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, ) logging.debug("About to create test dataloader") test_dl = DataLoader( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py index cd432fd6f..306f30c2f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py @@ -294,6 +294,8 @@ class GigaSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index c500eb3e5..dd9e9ef1f 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -311,6 +311,8 @@ class LibriSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py index 3acd22ae4..84bd3fc4b 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py @@ -304,6 +304,8 @@ class LibriSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git 
a/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py index 2f8e658c5..e1a29bd9c 100644 --- a/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py +++ b/egs/librispeech/WSASR/conformer_ctc2/asr_datamodule.py @@ -227,6 +227,8 @@ class LibriSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py index 81bb9ed13..8ff868bc8 100644 --- a/egs/ljspeech/TTS/vits/tts_datamodule.py +++ b/egs/ljspeech/TTS/vits/tts_datamodule.py @@ -196,6 +196,8 @@ class LJSpeechTtsDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py index 7753d1674..48921d71f 100644 --- a/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py @@ -266,6 +266,8 @@ class MGB2AsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py index 02cfa1346..341579acb 100644 --- a/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py +++ b/egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py @@ -297,6 +297,8 @@ class AsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py index be6e94472..662ae01c5 100644 --- a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py +++ b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py @@ -294,6 +294,8 @@ class AsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index cf70fc0f8..7cd6771ce 100644 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -236,6 +236,8 @@ class SPGISpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=False, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) logging.info("About to create train dataloader") diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py index ce8634a1d..0f6f02e8d 100644 --- a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py @@ -298,8 +298,9 @@ class 
SwitchBoardAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, - buffer_size=50000, ) else: logging.info("Using SimpleCutSampler.") diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py index 5269a1778..6f0833db6 100644 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -306,8 +306,9 @@ class TAL_CSASRAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, num_cuts_for_bins_estimate=20000, - buffer_size=60000, drop_last=self.args.drop_last, ) else: diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py index d4a9e4bc9..a67cf8d04 100644 --- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py +++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py @@ -256,6 +256,8 @@ class TedLiumAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py index 5d1b3c367..8606a490b 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -222,6 +222,8 @@ class TimitAsrDataModule(DataModule): max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/vctk/TTS/vits/tts_datamodule.py b/egs/vctk/TTS/vits/tts_datamodule.py index 8b2a96b09..52fc5179f 100644 --- a/egs/vctk/TTS/vits/tts_datamodule.py +++ b/egs/vctk/TTS/vits/tts_datamodule.py @@ -204,6 +204,8 @@ class VctkTtsDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 1dbfb9709..58da1d68c 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -292,7 +292,8 @@ class WenetSpeechAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=300000, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py index 7594fb28e..7b37b1331 100644 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -296,6 +296,8 @@ class Xbmu_AmdoAsrDataModule: 
max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, ) else: diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py index dc66b217d..b9ce8fb4e 100644 --- a/egs/yesno/ASR/tdnn/asr_datamodule.py +++ b/egs/yesno/ASR/tdnn/asr_datamodule.py @@ -193,6 +193,8 @@ class YesNoAsrDataModule(DataModule): max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=True, ) else: From ebe97a07b082f9bee4a18b5f8e54c453187a74bb Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 23 Jan 2024 16:26:24 +0800 Subject: [PATCH 33/46] Reworked README.md (#1470) * Rework README.md Co-authored-by: Fangjun Kuang --------- Co-authored-by: Fangjun Kuang --- README.md | 439 +++++++++++++++++++++++++----------------------------- 1 file changed, 201 insertions(+), 238 deletions(-) diff --git a/README.md b/README.md index 15e9e17e6..61920be65 100644 --- a/README.md +++ b/README.md @@ -2,46 +2,83 @@ -## Introduction +# Introduction -icefall contains ASR recipes for various datasets -using . +The icefall peoject contains speech related recipes for various datasets +using [k2-fsa](https://github.com/k2-fsa/k2) and [lhotse](https://github.com/lhotse-speech/lhotse). -You can use to deploy models -trained with icefall. +You can use [sherpa](https://github.com/k2-fsa/sherpa), [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn) or [sherpa-onnx](https://github.com/k2-fsa/sherpa-onnx) for deployment with models +in icefall; these frameworks also support models not included in icefall; please refer to respective documents for more details. You can try pre-trained models from within your browser without the need -to download or install anything by visiting -See for more details. +to download or install anything by visiting this [huggingface space](https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition). +Please refer to [document](https://k2-fsa.github.io/icefall/huggingface/spaces.html) for more details. -## Installation +# Installation -Please refer to +Please refer to [document](https://icefall.readthedocs.io/en/latest/installation/index.html) for installation. -## Recipes +# Recipes -Please refer to -for more information. +Please refer to [document](https://icefall.readthedocs.io/en/latest/recipes/index.html) +for more details. 
-We provide the following recipes: +## ASR: Automatic Speech Recognition +### Supported Datasets - [yesno][yesno] - - [LibriSpeech][librispeech] - - [GigaSpeech][gigaspeech] - - [AMI][ami] + + - [Aidatatang_200zh][aidatatang_200zh] - [Aishell][aishell] - [Aishell2][aishell2] - [Aishell4][aishell4] + - [Alimeeting][alimeeting] + - [AMI][ami] + - [CommonVoice][commonvoice] + - [Corpus of Spontaneous Japanese][csj] + - [GigaSpeech][gigaspeech] + - [LibriCSS][libricss] + - [LibriSpeech][librispeech] + - [Libriheavy][libriheavy] + - [Multi-Dialect Broadcast News Arabic Speech Recognition][mgb2] + - [PeopleSpeech][peoplespeech] + - [SPGISpeech][spgispeech] + - [Switchboard][swbd] - [TIMIT][timit] - [TED-LIUM3][tedlium3] - - [Aidatatang_200zh][aidatatang_200zh] - - [WenetSpeech][wenetspeech] - - [Alimeeting][alimeeting] - - [Switchboard][swbd] - [TAL_CSASR][tal_csasr] + - [Voxpopuli][voxpopuli] + - [XBMU-AMDO31][xbmu-amdo31] + - [WenetSpeech][wenetspeech] + +More datasets will be added in the future. -### yesno +### Supported Models + +The [LibriSpeech][librispeech] recipe supports the most comprehensive set of models, you are welcome to try them out. + +#### CTC + - TDNN LSTM CTC + - Conformer CTC + - Zipformer CTC + +#### MMI + - Conformer MMI + - Zipformer MMI + +#### Transducer + - Conformer-based Encoder + - LSTM-based Encoder + - Zipformer-based Encoder + - LSTM-based Predictor + - [Stateless Predictor](https://research.google/pubs/rnn-transducer-with-stateless-prediction-network/) + +If you are willing to contribute to icefall, please refer to [contributing](https://icefall.readthedocs.io/en/latest/contributing/index.html) for more details. + +We would like to highlight the performance of some of the recipes here. + +### [yesno][yesno] This is the simplest ASR recipe in `icefall` and can be run on CPU. Training takes less than 30 seconds and gives you the following WER: @@ -52,350 +89,264 @@ Training takes less than 30 seconds and gives you the following WER: We provide a Colab notebook for this recipe: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tIjjzaJc3IvGyKiMCDWO-TSnBgkcuN3B?usp=sharing) -### LibriSpeech +### [LibriSpeech][librispeech] -Please see +Please see [RESULTS.md](https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md) for the **latest** results. 
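The transducer recipes in the tables that follow all pair an encoder with the "Stateless Predictor" linked under Supported Models above. The snippet below is a minimal sketch of that idea (embed only the last couple of emitted tokens and mix them with a small causal convolution, rather than running an LSTM); the class name, dimensions, and the depthwise-convolution choice are assumptions made for illustration, not the exact `decoder.py` used by the recipes.

```python
# Minimal sketch of a "stateless" transducer predictor (prediction network).
# Instead of an LSTM carrying unbounded history, it embeds only the last
# `context_size` emitted tokens and mixes them with a causal 1-D convolution.
# Names and sizes are illustrative, not the icefall decoder.py implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F


class StatelessPredictor(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.context_size = context_size
        self.conv = nn.Conv1d(
            embed_dim,
            embed_dim,
            kernel_size=context_size,
            groups=embed_dim,  # depthwise: each channel sees only its own history
            bias=False,
        )

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) ids of previously emitted tokens
        emb = self.embedding(y).permute(0, 2, 1)      # (N, embed_dim, U)
        emb = F.pad(emb, (self.context_size - 1, 0))  # left-pad, i.e. causal context
        out = self.conv(emb).permute(0, 2, 1)         # (N, U, embed_dim)
        return torch.relu(out)


pred = StatelessPredictor(vocab_size=500, embed_dim=512)
print(pred(torch.randint(0, 500, (4, 10))).shape)  # torch.Size([4, 10, 512])
```

The bounded context keeps the predictor cheap and simple to export, which is part of the general appeal of stateless prediction networks.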
-We provide 5 models for this recipe: - -- [conformer CTC model][LibriSpeech_conformer_ctc] -- [TDNN LSTM CTC model][LibriSpeech_tdnn_lstm_ctc] -- [Transducer: Conformer encoder + LSTM decoder][LibriSpeech_transducer] -- [Transducer: Conformer encoder + Embedding decoder][LibriSpeech_transducer_stateless] -- [Transducer: Zipformer encoder + Embedding decoder][LibriSpeech_zipformer] - -#### Conformer CTC Model - -The best WER we currently have is: +#### [Conformer CTC](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conformer_ctc) | | test-clean | test-other | |-----|------------|------------| | WER | 2.42 | 5.73 | -We provide a Colab notebook to run a pre-trained conformer CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1huyupXAcHsUrKaWfI83iMEJ6J0Nh0213?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1huyupXAcHsUrKaWfI83iMEJ6J0Nh0213?usp=sharing) -#### TDNN LSTM CTC Model - -The WER for this model is: +#### [TDNN LSTM CTC](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/tdnn_lstm_ctc) | | test-clean | test-other | |-----|------------|------------| | WER | 6.59 | 17.69 | -We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-iSfQMp2So-We_Uu49N4AAcMInB72u9z?usp=sharing) -#### Transducer: Conformer encoder + LSTM decoder +#### [Transducer (Conformer Encoder + LSTM Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/transducer) -Using Conformer as encoder and LSTM as decoder. +| | test-clean | test-other | +|---------------|------------|------------| +| greedy search | 3.07 | 7.51 | -The best WER with greedy search is: +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing) -| | test-clean | test-other | -|-----|------------|------------| -| WER | 3.07 | 7.51 | +#### [Transducer (Conformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/transducer) -We provide a Colab notebook to run a pre-trained RNN-T conformer model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing) - -#### Transducer: Conformer encoder + Embedding decoder - -Using Conformer as encoder. The decoder consists of 1 embedding layer -and 1 convolutional layer. - -The best WER using modified beam search with beam size 4 is: - -| | test-clean | test-other | -|-----|------------|------------| -| WER | 2.56 | 6.27 | - -Note: No auxiliary losses are used in the training and no LMs are used -in the decoding. 
- -We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing) +| | test-clean | test-other | +|---------------------------------------|------------|------------| +| modified_beam_search (`beam_size=4`) | 2.56 | 6.27 | -#### k2 pruned RNN-T +We provide a Colab notebook to run test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing) + + +#### [Transducer (Zipformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/zipformer) + +WER (modified_beam_search `beam_size=4` unless further stated) + +1. LibriSpeech-960hr | Encoder | Params | test-clean | test-other | epochs | devices | |-----------------|--------|------------|------------|---------|------------| -| zipformer | 65.5M | 2.21 | 4.79 | 50 | 4 32G-V100 | -| zipformer-small | 23.2M | 2.42 | 5.73 | 50 | 2 32G-V100 | -| zipformer-large | 148.4M | 2.06 | 4.63 | 50 | 4 32G-V100 | -| zipformer-large | 148.4M | 2.00 | 4.38 | 174 | 8 80G-A100 | +| Zipformer | 65.5M | 2.21 | 4.79 | 50 | 4 32G-V100 | +| Zipformer-small | 23.2M | 2.42 | 5.73 | 50 | 2 32G-V100 | +| Zipformer-large | 148.4M | 2.06 | 4.63 | 50 | 4 32G-V100 | +| Zipformer-large | 148.4M | 2.00 | 4.38 | 174 | 8 80G-A100 | -Note: No auxiliary losses are used in the training and no LMs are used -in the decoding. +2. LibriSpeech-960hr + GigaSpeech -#### k2 pruned RNN-T + GigaSpeech - -| | test-clean | test-other | -|-----|------------|------------| -| WER | 1.78 | 4.08 | - -Note: No auxiliary losses are used in the training and no LMs are used -in the decoding. - -#### k2 pruned RNN-T + GigaSpeech + CommonVoice - -| | test-clean | test-other | -|-----|------------|------------| -| WER | 1.90 | 3.98 | - -Note: No auxiliary losses are used in the training and no LMs are used -in the decoding. +| Encoder | Params | test-clean | test-other | +|-----------------|--------|------------|------------| +| Zipformer | 65.5M | 1.78 | 4.08 | -### GigaSpeech +3. LibriSpeech-960hr + GigaSpeech + CommonVoice -We provide three models for this recipe: +| Encoder | Params | test-clean | test-other | +|-----------------|--------|------------|------------| +| Zipformer | 65.5M | 1.90 | 3.98 | -- [Conformer CTC model][GigaSpeech_conformer_ctc] -- [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][GigaSpeech_pruned_transducer_stateless2]. 
-- [Transducer: Zipformer encoder + Embedding decoder][GigaSpeech_zipformer] -#### Conformer CTC +### [GigaSpeech][gigaspeech] + +#### [Conformer CTC](https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR/conformer_ctc) | | Dev | Test | |-----|-------|-------| | WER | 10.47 | 10.58 | -#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss +#### [Transducer (pruned_transducer_stateless2)](https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR/pruned_transducer_stateless2) + +Conformer Encoder + Stateless Predictor + k2 Pruned RNN-T Loss | | Dev | Test | |----------------------|-------|-------| -| greedy search | 10.51 | 10.73 | -| fast beam search | 10.50 | 10.69 | -| modified beam search | 10.40 | 10.51 | +| greedy_search | 10.51 | 10.73 | +| fast_beam_search | 10.50 | 10.69 | +| modified_beam_search | 10.40 | 10.51 | -#### Transducer: Zipformer encoder + Embedding decoder +#### [Transducer (Zipformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/gigaspeech/ASR/zipformer) | | Dev | Test | |----------------------|-------|-------| -| greedy search | 10.31 | 10.50 | -| fast beam search | 10.26 | 10.48 | -| modified beam search | 10.25 | 10.38 | +| greedy_search | 10.31 | 10.50 | +| fast_beam_search | 10.26 | 10.48 | +| modified_beam_search | 10.25 | 10.38 | -### Aishell +### [Aishell][aishell] -We provide three models for this recipe: [conformer CTC model][Aishell_conformer_ctc], -[TDNN LSTM CTC model][Aishell_tdnn_lstm_ctc], and [Transducer Stateless Model][Aishell_pruned_transducer_stateless7], - -#### Conformer CTC Model - -The best CER we currently have is: - -| | test | -|-----|------| -| CER | 4.26 | - -#### TDNN LSTM CTC Model - -The CER for this model is: +#### [TDNN LSTM CTC](https://github.com/k2-fsa/icefall/tree/master/egs/aishell/ASR/tdnn_lstm_ctc) | | test | |-----|-------| | CER | 10.16 | -We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jbyzYq3ytm6j2nlEt-diQm-6QVWyDDEa?usp=sharing) -#### Transducer Stateless Model - -The best CER we currently have is: +#### [Transducer (Conformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/aishell/ASR/transducer_stateless) | | test | |-----|------| | CER | 4.38 | -We provide a Colab notebook to run a pre-trained TransducerStateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14XaT2MhnBkK-3_RqqWq3K90Xlbin-GZC?usp=sharing) + +#### [Transducer (Zipformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/aishell/ASR/zipformer) + +WER (modified_beam_search `beam_size=4`) + +| Encoder | Params | dev | test | epochs | +|-----------------|--------|-----|------|---------| +| Zipformer | 73.4M | 4.13| 4.40 | 55 | +| Zipformer-small | 30.2M | 4.40| 4.67 | 55 | +| Zipformer-large | 157.3M | 4.03| 4.28 | 56 | -### Aishell2 +### 
[Aishell4][aishell4] -We provide one model for this recipe: [Transducer Stateless Model][Aishell2_pruned_transducer_stateless5]. - -#### Transducer Stateless Model - -The best WER we currently have is: - -| | dev-ios | test-ios | -|-----|------------|------------| -| WER | 5.32 | 5.56 | - - -### Aishell4 - -We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5]. - -#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets) - -The best CER we currently have is: +#### [Transducer (pruned_transducer_stateless5)](https://github.com/k2-fsa/icefall/tree/master/egs/aishell4/ASR/pruned_transducer_stateless5) +1 Trained with all subsets: | | test | |-----|------------| | CER | 29.08 | - -We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing) -### TIMIT +### [TIMIT][timit] -We provide two models for this recipe: [TDNN LSTM CTC model][TIMIT_tdnn_lstm_ctc] -and [TDNN LiGRU CTC model][TIMIT_tdnn_ligru_ctc]. +#### [TDNN LSTM CTC](https://github.com/k2-fsa/icefall/tree/master/egs/timit/ASR/tdnn_lstm_ctc) -#### TDNN LSTM CTC Model - -The best PER we currently have is: - -||TEST| -|--|--| +| |TEST| +|---|----| |PER| 19.71% | -We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Hs9DA4V96uapw_30uNp32OMJgkuR5VVd?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Hs9DA4V96uapw_30uNp32OMJgkuR5VVd?usp=sharing) -#### TDNN LiGRU CTC Model +#### [TDNN LiGRU CTC](https://github.com/k2-fsa/icefall/tree/master/egs/timit/ASR/tdnn_ligru_ctc) -The PER for this model is: - -||TEST| -|--|--| +| |TEST| +|---|----| |PER| 17.66% | -We provide a Colab notebook to run a pre-trained TDNN LiGRU CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing) -### TED-LIUM3 +### [TED-LIUM3][tedlium3] -We provide two models for this recipe: [Transducer Stateless: Conformer encoder + Embedding decoder][TED-LIUM3_transducer_stateless] and [Pruned Transducer Stateless: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][TED-LIUM3_pruned_transducer_stateless]. 
+#### [Transducer (Conformer Encoder + Embedding Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/tedlium3/ASR/transducer_stateless) -#### Transducer Stateless: Conformer encoder + Embedding decoder - -The best WER using modified beam search with beam size 4 is: - -| | dev | test | -|-----|-------|--------| -| WER | 6.91 | 6.33 | - -Note: No auxiliary losses are used in the training and no LMs are used in the decoding. - -We provide a Colab notebook to run a pre-trained Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1MmY5bBxwvKLNT4A2DJnwiqRXhdchUqPN?usp=sharing) - -#### Pruned Transducer Stateless: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss - -The best WER using modified beam search with beam size 4 is: - -| | dev | test | -|-----|-------|--------| -| WER | 6.77 | 6.14 | - -We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1je_1zGrOkGVVd4WLzgkXRHxl-I27yWtz?usp=sharing) +| | dev | test | +|--------------------------------------|-------|--------| +| modified_beam_search (`beam_size=4`) | 6.91 | 6.33 | -### Aidatatang_200zh +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1MmY5bBxwvKLNT4A2DJnwiqRXhdchUqPN?usp=sharing) -We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aidatatang_200zh_pruned_transducer_stateless2]. +#### [Transducer (pruned_transducer_stateless)](https://github.com/k2-fsa/icefall/tree/master/egs/tedlium3/ASR/pruned_transducer_stateless) -#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss +| | dev | test | +|--------------------------------------|-------|--------| +| modified_beam_search (`beam_size=4`) | 6.77 | 6.14 | + +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1je_1zGrOkGVVd4WLzgkXRHxl-I27yWtz?usp=sharing) + + +### [Aidatatang_200zh][aidatatang_200zh] + +#### [Transducer (pruned_transducer_stateless2)](https://github.com/k2-fsa/icefall/tree/master/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2) | | Dev | Test | |----------------------|-------|-------| -| greedy search | 5.53 | 6.59 | -| fast beam search | 5.30 | 6.34 | -| modified beam search | 5.27 | 6.33 | +| greedy_search | 5.53 | 6.59 | +| fast_beam_search | 5.30 | 6.34 | +| modified_beam_search | 5.27 | 6.33 | -We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1wNSnSj3T5oOctbh5IGCa393gKOoQw2GH?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1wNSnSj3T5oOctbh5IGCa393gKOoQw2GH?usp=sharing) -### WenetSpeech +### [WenetSpeech][wenetspeech] -We provide some models for this recipe: [Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless2] and [Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 
pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless5]. - -#### Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset, offline ASR) +#### [Transducer (pruned_transducer_stateless2)](https://github.com/k2-fsa/icefall/tree/master/egs/wenetspeech/ASR/pruned_transducer_stateless2) | | Dev | Test-Net | Test-Meeting | |----------------------|-------|----------|--------------| -| greedy search | 7.80 | 8.75 | 13.49 | -| modified beam search| 7.76 | 8.71 | 13.41 | -| fast beam search | 7.94 | 8.74 | 13.80 | +| greedy_search | 7.80 | 8.75 | 13.49 | +| fast_beam_search | 7.94 | 8.74 | 13.80 | +| modified_beam_search | 7.76 | 8.71 | 13.41 | + +We provide a Colab notebook to run the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing) + +#### [Transducer **Streaming** (pruned_transducer_stateless5) ](https://github.com/k2-fsa/icefall/tree/master/egs/wenetspeech/ASR/pruned_transducer_stateless5) -#### Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset) -**Streaming**: | | Dev | Test-Net | Test-Meeting | |----------------------|-------|----------|--------------| | greedy_search | 8.78 | 10.12 | 16.16 | -| modified_beam_search | 8.53| 9.95 | 15.81 | | fast_beam_search| 9.01 | 10.47 | 16.28 | +| modified_beam_search | 8.53| 9.95 | 15.81 | -We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless2 model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing) -### Alimeeting +### [Alimeeting][alimeeting] -We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Alimeeting_pruned_transducer_stateless2]. - -#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with far subset) +#### [Transducer (pruned_transducer_stateless2)](https://github.com/k2-fsa/icefall/tree/master/egs/alimeeting/ASR/pruned_transducer_stateless2) | | Eval | Test-Net | |----------------------|--------|----------| -| greedy search | 31.77 | 34.66 | -| fast beam search | 31.39 | 33.02 | -| modified beam search | 30.38 | 34.25 | +| greedy_search | 31.77 | 34.66 | +| fast_beam_search | 31.39 | 33.02 | +| modified_beam_search | 30.38 | 34.25 | -We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing) -### TAL_CSASR +### [TAL_CSASR][tal_csasr] -We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][TAL_CSASR_pruned_transducer_stateless5]. 
-#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss +#### [Transducer (pruned_transducer_stateless5)](https://github.com/k2-fsa/icefall/tree/master/egs/tal_csasr/ASR/pruned_transducer_stateless5) The best results for Chinese CER(%) and English WER(%) respectively (zh: Chinese, en: English): |decoding-method | dev | dev_zh | dev_en | test | test_zh | test_en | |--|--|--|--|--|--|--| |greedy_search| 7.30 | 6.48 | 19.19 |7.39| 6.66 | 19.13| -|modified_beam_search| 7.15 | 6.35 | 18.95 | 7.22| 6.50 | 18.70 | |fast_beam_search| 7.18 | 6.39| 18.90 | 7.27| 6.55 | 18.77| +|modified_beam_search| 7.15 | 6.35 | 18.95 | 7.22| 6.50 | 18.70 | -We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing) -## Deployment with C++ +## TTS: Text-to-Speech -Once you have trained a model in icefall, you may want to deploy it with C++, -without Python dependencies. +### Supported Datasets -Please refer to the documentation - + - [LJSpeech][ljspeech] + - [VCTK][vctk] + +### Supported Models + + - [VITS](https://arxiv.org/abs/2106.06103) + +# Deployment with C++ + +Once you have trained a model in icefall, you may want to deploy it with C++ without Python dependencies. + +Please refer to the [document](https://icefall.readthedocs.io/en/latest/recipes/Non-streaming-ASR/librispeech/conformer_ctc.html#deployment-with-c) for how to do this. We also provide a Colab notebook, showing you how to run a torch scripted model in [k2][k2] with C++. 
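Before the C++ side can load anything, the trained model has to be exported as TorchScript. The snippet below is a bare-bones sketch of that round trip using a toy module; the model, file name, and input shape are placeholders for illustration, not the actual `export.py` flow of the recipes.

```python
# Sketch of the TorchScript export that C++ deployment relies on: script the
# model once in Python; the saved file can then be loaded without any Python
# dependency (e.g. via torch::jit::load in libtorch). The toy model, file name
# and input shape below are placeholders, not the icefall export.py flow.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(80, 256), nn.ReLU(), nn.Linear(256, 500))
model.eval()

scripted = torch.jit.script(model)  # torch.jit.trace(...) is an alternative
scripted.save("cpu_jit.pt")

# Sanity check: the scripted module runs standalone.
loaded = torch.jit.load("cpu_jit.pt")
with torch.no_grad():
    print(loaded(torch.randn(1, 80)).shape)  # torch.Size([1, 500])
```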
Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1BIGLWzS36isskMXHKcqC9ysN6pspYXs_?usp=sharing) -[LibriSpeech_tdnn_lstm_ctc]: egs/librispeech/ASR/tdnn_lstm_ctc -[LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc -[LibriSpeech_transducer]: egs/librispeech/ASR/transducer -[LibriSpeech_transducer_stateless]: egs/librispeech/ASR/transducer_stateless -[LibriSpeech_zipformer]: egs/librispeech/ASR/zipformer -[Aishell_tdnn_lstm_ctc]: egs/aishell/ASR/tdnn_lstm_ctc -[Aishell_conformer_ctc]: egs/aishell/ASR/conformer_ctc -[Aishell_pruned_transducer_stateless7]: egs/aishell/ASR/pruned_transducer_stateless7_bbpe -[Aishell2_pruned_transducer_stateless5]: egs/aishell2/ASR/pruned_transducer_stateless5 -[Aishell4_pruned_transducer_stateless5]: egs/aishell4/ASR/pruned_transducer_stateless5 -[TIMIT_tdnn_lstm_ctc]: egs/timit/ASR/tdnn_lstm_ctc -[TIMIT_tdnn_ligru_ctc]: egs/timit/ASR/tdnn_ligru_ctc -[TED-LIUM3_transducer_stateless]: egs/tedlium3/ASR/transducer_stateless -[TED-LIUM3_pruned_transducer_stateless]: egs/tedlium3/ASR/pruned_transducer_stateless -[GigaSpeech_conformer_ctc]: egs/gigaspeech/ASR/conformer_ctc -[GigaSpeech_pruned_transducer_stateless2]: egs/gigaspeech/ASR/pruned_transducer_stateless2 -[GigaSpeech_zipformer]: egs/gigaspeech/ASR/zipformer -[Aidatatang_200zh_pruned_transducer_stateless2]: egs/aidatatang_200zh/ASR/pruned_transducer_stateless2 -[WenetSpeech_pruned_transducer_stateless2]: egs/wenetspeech/ASR/pruned_transducer_stateless2 -[WenetSpeech_pruned_transducer_stateless5]: egs/wenetspeech/ASR/pruned_transducer_stateless5 -[Alimeeting_pruned_transducer_stateless2]: egs/alimeeting/ASR/pruned_transducer_stateless2 -[TAL_CSASR_pruned_transducer_stateless5]: egs/tal_csasr/ASR/pruned_transducer_stateless5 [yesno]: egs/yesno/ASR [librispeech]: egs/librispeech/ASR [aishell]: egs/aishell/ASR @@ -411,3 +362,15 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [ami]: egs/ami [swbd]: egs/swbd/ASR [k2]: https://github.com/k2-fsa/k2 +[commonvoice]: egs/commonvoice/ASR +[csj]: egs/csj/ASR +[libricss]: egs/libricss/SURT +[libriheavy]: egs/libriheavy/ASR +[mgb2]: egs/mgb2/ASR +[peoplespeech]: egs/peoplespeech/ASR +[spgispeech]: egs/spgispeech/ASR +[voxpopuli]: egs/voxpopuli/ASR +[xbmu-amdo31]: egs/xbmu-amdo31/ASR + +[vctk]: egs/vctk/TTS +[ljspeech]: egs/ljspeech/TTS \ No newline at end of file From 559ed150bb73e3e2e89c703bac4c37744e516e8e Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 23 Jan 2024 22:51:09 +0800 Subject: [PATCH 34/46] Fix typo (#1471) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 61920be65..f92c85ad4 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ # Introduction -The icefall peoject contains speech related recipes for various datasets +The icefall project contains speech-related recipes for various datasets using [k2-fsa](https://github.com/k2-fsa/k2) and [lhotse](https://github.com/lhotse-speech/lhotse). 
You can use [sherpa](https://github.com/k2-fsa/sherpa), [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn) or [sherpa-onnx](https://github.com/k2-fsa/sherpa-onnx) for deployment with models @@ -373,4 +373,4 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [xbmu-amdo31]: egs/xbmu-amdo31/ASR [vctk]: egs/vctk/TTS -[ljspeech]: egs/ljspeech/TTS \ No newline at end of file +[ljspeech]: egs/ljspeech/TTS From 9c494a3329d531e4ed10117ec0b6f244d0a61ce3 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Thu, 25 Jan 2024 18:41:43 +0800 Subject: [PATCH 35/46] typos fixed (#1472) --- README.md | 12 ++++++------ .../ASR/local/compute_fbank_peoples_speech_splits.py | 4 ++-- .../ASR/local/compute_fbank_wenetspeech_splits.py | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index f92c85ad4..cc817702b 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ We provide a Colab notebook to test the pre-trained model: [![Open In Colab](htt | | test-clean | test-other | |---------------|------------|------------| -| greedy search | 3.07 | 7.51 | +| greedy_search | 3.07 | 7.51 | We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing) @@ -127,7 +127,7 @@ We provide a Colab notebook to test the pre-trained model: [![Open In Colab](htt | modified_beam_search (`beam_size=4`) | 2.56 | 6.27 | -We provide a Colab notebook to run test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing) #### [Transducer (Zipformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/zipformer) @@ -147,14 +147,14 @@ WER (modified_beam_search `beam_size=4` unless further stated) | Encoder | Params | test-clean | test-other | |-----------------|--------|------------|------------| -| Zipformer | 65.5M | 1.78 | 4.08 | +| Zipformer | 65.5M | 1.78 | 4.08 | 3. 
LibriSpeech-960hr + GigaSpeech + CommonVoice | Encoder | Params | test-clean | test-other | |-----------------|--------|------------|------------| -| Zipformer | 65.5M | 1.90 | 3.98 | +| Zipformer | 65.5M | 1.90 | 3.98 | ### [GigaSpeech][gigaspeech] @@ -246,7 +246,7 @@ We provide a Colab notebook to test the pre-trained model: [![Open In Colab](htt ### [TED-LIUM3][tedlium3] -#### [Transducer (Conformer Encoder + Embedding Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/tedlium3/ASR/transducer_stateless) +#### [Transducer (Conformer Encoder + Stateless Predictor)](https://github.com/k2-fsa/icefall/tree/master/egs/tedlium3/ASR/transducer_stateless) | | dev | test | |--------------------------------------|-------|--------| @@ -287,7 +287,7 @@ We provide a Colab notebook to test the pre-trained model: [![Open In Colab](htt | fast_beam_search | 7.94 | 8.74 | 13.80 | | modified_beam_search | 7.76 | 8.71 | 13.41 | -We provide a Colab notebook to run the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing) +We provide a Colab notebook to test the pre-trained model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing) #### [Transducer **Streaming** (pruned_transducer_stateless5) ](https://github.com/k2-fsa/icefall/tree/master/egs/wenetspeech/ASR/pruned_transducer_stateless5) diff --git a/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py b/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py index c2ab3d07d..6f05b9f8c 100755 --- a/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py +++ b/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py @@ -67,14 +67,14 @@ def get_args(): "--start", type=int, default=0, - help="Process pieces starting from this number (inclusive).", + help="Process pieces starting from this number (included).", ) parser.add_argument( "--stop", type=int, default=-1, - help="Stop processing pieces until this number (exclusive).", + help="Stop processing pieces until this number (excluded).", ) return parser.parse_args() diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py index 99d39bbdc..a87801462 100755 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py +++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py @@ -78,14 +78,14 @@ def get_parser(): "--start", type=int, default=0, - help="Process pieces starting from this number (inclusive).", + help="Process pieces starting from this number (included).", ) parser.add_argument( "--stop", type=int, default=-1, - help="Stop processing pieces until this number (exclusive).", + help="Stop processing pieces until this number (excluded).", ) return parser From c401a2646b347bf1fff0c2ce1a4ee13b0f482448 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Fri, 26 Jan 2024 15:50:11 +0800 Subject: [PATCH 36/46] minor fix of zipformer/optim.py (#1474) --- egs/librispeech/ASR/zipformer/optim.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 714d8db9a..aaffbfed5 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -298,11 +298,14 @@ class 
ScaledAdam(BatchedOptimizer): # case 2 or case 4 # the input is groups of parameter or named parameter. for cur_group in iterable_or_groups: - assert "named_params" in cur_group - name_list = [x[0] for x in cur_group["named_params"]] - p_list = [x[1] for x in cur_group["named_params"]] - del cur_group["named_params"] - cur_group["params"] = p_list + if "named_params" in cur_group: + name_list = [x[0] for x in cur_group["named_params"]] + p_list = [x[1] for x in cur_group["named_params"]] + del cur_group["named_params"] + cur_group["params"] = p_list + else: + assert "params" in cur_group + name_list = ["foo" for _ in cur_group["params"]] param_groups.append(cur_group) param_groups_names.append(name_list) From 8d39f9508bd5b27627165696758cad2e96dca20b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 26 Jan 2024 19:18:33 +0800 Subject: [PATCH 37/46] Fix torchscript export to use tokens.txt instead of lang_dir (#1475) --- .../pruned_transducer_stateless2/export.py | 25 +++++++++------ .../ASR/pruned_transducer_stateless2/lstmp.py | 1 + .../scaling_converter.py | 1 + .../pruned_transducer_stateless2/export.py | 24 ++++++++------ .../ASR/pruned_transducer_stateless2/lstmp.py | 1 + .../scaling_converter.py | 1 + .../ASR/lstm_transducer_stateless2/export.py | 7 ++--- .../pruned_stateless_emformer_rnnt2/export.py | 1 + .../pruned_transducer_stateless/decoder.py | 2 +- .../export.py | 8 ++--- .../pruned_transducer_stateless5/export.py | 31 +++++++++---------- .../ASR/pruned_transducer_stateless5/lstmp.py | 1 + .../scaling_converter.py | 1 + .../pruned_transducer_stateless2/export.py | 23 +++++++------- .../pruned_transducer_stateless5/export.py | 25 +++++++-------- 15 files changed, 83 insertions(+), 69 deletions(-) mode change 100644 => 100755 egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py create mode 120000 egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py create mode 120000 egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py create mode 120000 egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py create mode 120000 egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling_converter.py create mode 120000 egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py create mode 120000 egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling_converter.py diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py old mode 100644 new mode 100755 index e348f7b2b..5179bfa1c --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) # # See ../../../../LICENSE for clarification regarding multiple authors @@ -20,7 +21,7 @@ Usage: ./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 29 \ --avg 19 @@ -45,12 +46,13 @@ import argparse import logging from pathlib import Path +import k2 import torch +from scaling_converter import convert_scaled_to_non_scaled from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -85,10 +87,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", 
+ "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -122,10 +124,14 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + # Load id of the token and the vocab size + # is defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -152,6 +158,7 @@ def main(): model.eval() if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py new file mode 120000 index 000000000..b82e115fc --- /dev/null +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py @@ -0,0 +1 @@ +../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py new file mode 120000 index 000000000..db93d155b --- /dev/null +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py index b6190e8a6..4a44f7bcb 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,12 +47,13 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch +from scaling_converter import convert_scaled_to_non_scaled from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -98,10 +99,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -135,12 +136,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -183,6 +186,7 @@ def main(): model.eval() if params.jit: + 
convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py new file mode 120000 index 000000000..b82e115fc --- /dev/null +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/lstmp.py @@ -0,0 +1 @@ +../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling_converter.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling_converter.py new file mode 120000 index 000000000..db93d155b --- /dev/null +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 5712da25e..aeed58dec 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -218,10 +218,9 @@ def export_decoder_model_jit_trace( decoder_filename: The filename to save the exported model. """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = torch.tensor([False]) - - traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + # TODO(fangjun): Change the function name since we are actually using + # torch.jit.script instead of torch.jit.trace + traced_model = torch.jit.script(decoder_model) traced_model.save(decoder_filename) logging.info(f"Saved to {decoder_filename}") diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py index ec2c9d580..e42a5c6ef 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py @@ -159,6 +159,7 @@ def main(): # Load id of the token and the vocab size params.blank_id = token_table[""] + params.unk_id = token_table[""] params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index 03847b449..b961611f7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -91,7 +91,7 @@ class Decoder(nn.Module): Returns: Return a tensor of shape (N, U, embedding_dim). 
""" - embedding_out = self.embedding(y) + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py index 59a7eb589..67041012d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -87,7 +87,7 @@ cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/e ln -s pretrained.pt epoch-999.pt ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --use-averaged-model False \ --epoch 999 \ --avg 1 \ @@ -113,7 +113,7 @@ cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/e ln -s pretrained.pt epoch-999.pt ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --use-averaged-model False \ --epoch 999 \ --avg 1 \ diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py index bc33dd160..0f6190a41 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py @@ -23,7 +23,7 @@ Usage: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp \ - --lang-dir ./data/lang_char \ + --tokens ./data/lang_char/tokens.txt \ --epoch 30 \ --avg 24 \ --use-averaged-model True @@ -50,8 +50,9 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -60,8 +61,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -118,13 +118,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -160,13 +157,14 @@ def main(): logging.info(f"device: {device}") - bpe_model = params.lang_dir + "/bpe.model" - sp = spm.SentencePieceProcessor() - sp.load(bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - lexicon = 
Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + # Load id of the token and the vocab size + # is defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -256,6 +254,7 @@ def main(): model.eval() if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py new file mode 120000 index 000000000..b82e115fc --- /dev/null +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py @@ -0,0 +1 @@ +../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling_converter.py new file mode 120000 index 000000000..db93d155b --- /dev/null +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py index 5d25daf5e..2f6ef488e 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py @@ -24,7 +24,7 @@ Usage: ./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 10 \ --avg 2 \ --jit 1 @@ -47,7 +47,7 @@ for how to use them. ./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 10 \ --avg 2 \ --jit-trace 1 @@ -63,7 +63,7 @@ Check ./jit_pretrained.py for usage. 
./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 10 \ --avg 2 @@ -91,14 +91,14 @@ import argparse import logging from pathlib import Path +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -133,10 +133,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -313,10 +313,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py index cb541070e..5ff1f4a3b 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py @@ -20,7 +20,7 @@ Usage for offline: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp_L_offline \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 4 \ --avg 1 @@ -28,7 +28,7 @@ It will generate a file exp_dir/pretrained.pt for offline ASR. ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp_L_offline \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 4 \ --avg 1 \ --jit True @@ -38,7 +38,7 @@ It will generate a file exp_dir/cpu_jit.pt for offline ASR. Usage for streaming: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp_L_streaming \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 7 \ --avg 1 @@ -46,7 +46,7 @@ It will generate a file exp_dir/pretrained.pt for streaming ASR. 
./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp_L_streaming \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 7 \ --avg 1 \ --jit True @@ -73,13 +73,13 @@ import argparse import logging from pathlib import Path +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -114,10 +114,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -152,10 +152,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) From 1c30847947f9d8b1416ef3e70408c07eab807f3d Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Sat, 27 Jan 2024 00:32:30 +0800 Subject: [PATCH 38/46] Whisper Fine-tuning Recipe on Aishell1 (#1466) * add decode seamlessm4t * add requirements * add decoding with avg model * add token files * add custom tokenizer * support deepspeed to finetune large model * support large-v3 * add model saving * using monkey patch to replace models * add manifest dir option --- egs/aishell/ASR/README.md | 7 + egs/aishell/ASR/RESULTS.md | 67 +- .../ASR/local/compute_fbank_aishell.py | 45 +- egs/aishell/ASR/prepare.sh | 13 + egs/aishell/ASR/whisper/asr_datamodule.py | 1 + egs/aishell/ASR/whisper/decode.py | 503 ++++++++++ egs/aishell/ASR/whisper/ds_config_zero1.json | 38 + egs/aishell/ASR/whisper/label_smoothing.py | 1 + egs/aishell/ASR/whisper/optim.py | 1 + egs/aishell/ASR/whisper/requirements.txt | 10 + egs/aishell/ASR/whisper/train.py | 927 ++++++++++++++++++ .../whisper_encoder_forward_monkey_patch.py | 29 + .../ASR/local/compute_fbank_musan.py | 59 +- icefall/dist.py | 2 +- 14 files changed, 1682 insertions(+), 21 deletions(-) create mode 120000 egs/aishell/ASR/whisper/asr_datamodule.py create mode 100755 egs/aishell/ASR/whisper/decode.py create mode 100644 egs/aishell/ASR/whisper/ds_config_zero1.json create mode 120000 egs/aishell/ASR/whisper/label_smoothing.py create mode 120000 egs/aishell/ASR/whisper/optim.py create mode 100755 egs/aishell/ASR/whisper/requirements.txt create mode 100755 egs/aishell/ASR/whisper/train.py create mode 100644 egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md index 176f065e5..b54719162 100644 --- a/egs/aishell/ASR/README.md +++ b/egs/aishell/ASR/README.md @@ -24,3 +24,10 @@ The following table lists the differences among them. The decoder in `transducer_stateless` is modified from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). We place an additional Conv1d layer right after the input embedding layer. 
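To make the idea concrete, here is a minimal sketch of such a stateless predictor: an embedding layer followed by a small causal Conv1d over the last `context_size` labels. The class name, dimensions, and channel grouping below are illustrative and do not reproduce the exact icefall implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StatelessPredictor(nn.Module):
    """Embedding + causal Conv1d over the previous `context_size` labels."""

    def __init__(self, vocab_size: int, embed_dim: int, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Depthwise conv mixes each embedding channel over `context_size` steps.
        self.conv = nn.Conv1d(
            embed_dim, embed_dim, kernel_size=context_size,
            groups=embed_dim, bias=False,
        )
        self.context_size = context_size

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) label ids -> output: (N, U, embed_dim)
        emb = self.embedding(y).permute(0, 2, 1)      # (N, C, U)
        emb = F.pad(emb, (self.context_size - 1, 0))  # left-pad => causal
        return self.conv(emb).permute(0, 2, 1)


pred = StatelessPredictor(vocab_size=500, embed_dim=512)
print(pred(torch.randint(0, 500, (4, 7))).shape)  # torch.Size([4, 7, 512])
```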
+ +# Whisper + +Recipe to finetune large pretrained models +| | Encoder | Decoder | Comment | +|------------------------------------|-----------|--------------------|-----------------------------------------------------------------------------------| +| `whisper` | Transformer | Transformer | support fine-tuning using deepspeed diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index ff9504274..46d712fb2 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -1,5 +1,63 @@ ## Results +### Aishell training results (Fine-tuning Pretrained Models) +#### Whisper +[./whisper](./whisper) +##### fine-tuning results on Aishell test set on whisper medium, large-v2, large-v3 + +| | test (before fine-tuning) | test (after fine-tuning) | comment | +|------------------------|------|------|-----------------------------------------| +| medium | 7.23 | 3.27 | --epoch 10 --avg 4, ddp | +| large-v2 | 6.56 | 2.47 | --epoch 10 --avg 6, deepspeed zero stage1 | +| large-v3 | 6.06 | 2.84 | --epoch 5 --avg 3, deepspeed zero stage1 | + +Command for training is: +```bash +pip install -r whisper/requirements.txt + +./prepare.sh --stage 30 --stop_stage 30 + +#fine-tuning with deepspeed zero stage 1 +torchrun --nproc-per-node 8 ./whisper/train.py \ + --max-duration 200 \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --deepspeed \ + --deepspeed_config ./whisper/ds_config_zero1.json + +# fine-tuning with ddp +torchrun --nproc-per-node 8 ./whisper/train.py \ + --max-duration 200 \ + --exp-dir whisper/exp_medium \ + --base-lr 1e-5 \ + --model-name medium +``` + +Command for decoding using fine-tuned models: +```bash +git lfs install +git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper +ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt + +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --epoch 999 --avg 1 \ + --beam-size 10 --max-duration 50 +``` +Command for decoding using pretrained models (before fine-tuning): +```bash +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --epoch -1 --avg 1 \ + --remove-whisper-encoder-input-length-restriction False \ + --beam-size 10 --max-duration 50 +``` +Fine-tuned models, training logs, decoding logs, tensorboard and decoding results +are available at + + ### Aishell training result (Stateless Transducer) #### Zipformer (Byte-level BPE) @@ -71,7 +129,7 @@ It's reworked Zipformer with Pruned RNNT loss. 
Command for training is: ```bash -./prepare.sh +./prepare.sh export CUDA_VISIBLE_DEVICES="0,1" @@ -136,7 +194,7 @@ export CUDA_VISIBLE_DEVICES="0,1" --feedforward-dim 512,768,768,768,768,768 \ --encoder-dim 192,256,256,256,256,256 \ --encoder-unmasked-dim 192,192,192,192,192,192 \ - --max-duration 1200 + --max-duration 1200 ``` Command for decoding is: @@ -186,7 +244,7 @@ export CUDA_VISIBLE_DEVICES="0,1" --feedforward-dim 512,768,1536,2048,1536,768 \ --encoder-dim 192,256,512,768,512,256 \ --encoder-unmasked-dim 192,192,256,320,256,192 \ - --max-duration 800 + --max-duration 800 ``` Command for decoding is: @@ -202,7 +260,7 @@ for m in greedy_search modified_beam_search fast_beam_search ; do --num-encoder-layers 2,2,4,5,4,2 \ --feedforward-dim 512,768,1536,2048,1536,768 \ --encoder-dim 192,256,512,768,512,256 \ - --encoder-unmasked-dim 192,192,256,320,256,192 + --encoder-unmasked-dim 192,192,256,320,256,192 done ``` @@ -755,7 +813,6 @@ python3 ./transducer_stateless/decode.py \ --max-sym-per-frame 3 ``` -### Aishell training results (Transducer-stateless) #### 2022-02-18 (Pingfeng Luo) : The tensorboard log for training is available at And pretrained model is available at diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py index c7000da1c..3c48f0aa1 100755 --- a/egs/aishell/ASR/local/compute_fbank_aishell.py +++ b/egs/aishell/ASR/local/compute_fbank_aishell.py @@ -29,7 +29,14 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + WhisperFbank, + WhisperFbankConfig, +) from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor, str2bool @@ -42,9 +49,14 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) -def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False): +def compute_fbank_aishell( + num_mel_bins: int = 80, + perturb_speed: bool = False, + whisper_fbank: bool = False, + output_dir: str = "data/fbank", +): src_dir = Path("data/manifests") - output_dir = Path("data/fbank") + output_dir = Path(output_dir) num_jobs = min(15, os.cpu_count()) dataset_parts = ( @@ -68,8 +80,12 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False): list(manifests.keys()), dataset_parts, ) - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + if whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") + ) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) with get_executor() as ex: # Initialize the executor only once. for partition, m in manifests.items(): @@ -82,7 +98,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False): supervisions=m["supervisions"], ) if "train" in partition and perturb_speed: - logging.info(f"Doing speed perturb") + logging.info("Doing speed perturb") cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) @@ -111,6 +127,18 @@ def get_args(): default=False, help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", ) + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) + parser.add_argument( + "--output-dir", + type=str, + default="data/fbank", + help="Output directory. 
Default: data/fbank.", + ) return parser.parse_args() @@ -121,5 +149,8 @@ if __name__ == "__main__": args = get_args() compute_fbank_aishell( - num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed + num_mel_bins=args.num_mel_bins, + perturb_speed=args.perturb_speed, + whisper_fbank=args.whisper_fbank, + output_dir=args.output_dir, ) diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index 9f73a2073..b7be89bc8 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -376,3 +376,16 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then --vocab-size 4336 \ --master-port 12345 fi + +# whisper large-v3 using 128 mel bins, others using 80 mel bins +whisper_mel_bins=80 +output_dir=data/fbank_whisper +if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then + log "Stage 30: Compute ${whisper_mel_bins} dim fbank for whisper model fine-tuning" + if [ ! -f $output_dir/.aishell.whisper.done ]; then + mkdir -p $output_dir + ./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true --output-dir $output_dir + ./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true --output-dir $output_dir + touch $output_dir/.aishell.whisper.done + fi +fi diff --git a/egs/aishell/ASR/whisper/asr_datamodule.py b/egs/aishell/ASR/whisper/asr_datamodule.py new file mode 120000 index 000000000..fa1b8cca3 --- /dev/null +++ b/egs/aishell/ASR/whisper/asr_datamodule.py @@ -0,0 +1 @@ +../tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/whisper/decode.py b/egs/aishell/ASR/whisper/decode.py new file mode 100755 index 000000000..7f841dcb7 --- /dev/null +++ b/egs/aishell/ASR/whisper/decode.py @@ -0,0 +1,503 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, +# Fangjun Kuang, +# Wei Kang) +# 2024 Yuekai Zhang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +# Command for decoding using fine-tuned models: +git lfs install +git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper +ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt + +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --epoch 999 --avg 1 \ + --manifest-dir data/fbank_whisper \ + --beam-size 10 --max-duration 50 + +# Command for decoding using pretrained models (before fine-tuning): + +python3 ./whisper/decode.py \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --epoch -1 --avg 1 \ + --manifest-dir data/fbank_whisper \ + --remove-whisper-encoder-input-length-restriction False \ + --beam-size 10 --max-duration 50 + +""" + +import argparse +import logging +import re +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +import whisper +from asr_datamodule import AishellAsrDataModule +from tn.chinese.normalizer import Normalizer +from whisper.normalizers import BasicTextNormalizer +from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward +from zhconv import convert + +from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def average_checkpoints( + filenames: List[Path], device: torch.device = torch.device("cpu") +) -> dict: + """Average a list of checkpoints. + The function is mainly used for deepspeed converted checkpoint averaging, which only include model state_dict. + + Args: + filenames: + Filenames of the checkpoints to be averaged. We assume all + checkpoints are saved by :func:`save_checkpoint`. + device: + Move checkpoints to this device before averaging. + Returns: + Return a dict (i.e., state_dict) which is the average of all + model state dicts contained in the checkpoints. + """ + n = len(filenames) + + if "model" in torch.load(filenames[0], map_location=device): + avg = torch.load(filenames[0], map_location=device)["model"] + else: + avg = torch.load(filenames[0], map_location=device) + + # Identify shared parameters. Two parameters are said to be shared + # if they have the same data_ptr + uniqued: Dict[int, str] = dict() + + for k, v in avg.items(): + v_data_ptr = v.data_ptr() + if v_data_ptr in uniqued: + continue + uniqued[v_data_ptr] = k + + uniqued_names = list(uniqued.values()) + + for i in range(1, n): + if "model" in torch.load(filenames[i], map_location=device): + state_dict = torch.load(filenames[i], map_location=device)["model"] + else: + state_dict = torch.load(filenames[i], map_location=device) + for k in uniqued_names: + avg[k] += state_dict[k] + + for k in uniqued_names: + if avg[k].is_floating_point(): + avg[k] /= n + else: + avg[k] //= n + + return avg + + +def remove_punctuation(text: str or List[str]): + """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py + + Args: + text: It can be a string or a list of strings. + Returns: + Return a string or a list of strings without any punctuation. 
+ """ + punctuation = "!,.;:?、!,。;:?《》 " + if isinstance(text, str): + text = re.sub(r"[{}]+".format(punctuation), "", text).strip() + return text + elif isinstance(text, list): + result_text = [] + for t in text: + t = re.sub(r"[{}]+".format(punctuation), "", t).strip() + result_text.append(t) + return result_text + else: + raise Exception(f"Not support type {type(text)}") + + +def to_simple(text: str or List[str]): + """Convert traditional Chinese to simplified Chinese. + Args: + text: It can be a string or a list of strings. + Returns: + Return a string or a list of strings converted to simplified Chinese. + """ + if isinstance(text, str): + text = convert(text, "zh-cn") + return text + elif isinstance(text, list): + result_text = [] + for t in text: + t = convert(t, "zh-cn") + result_text.append(t) + return result_text + else: + raise Exception(f"Not support type{type(text)}") + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=-1, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--method", + type=str, + default="beam-search", + help="""Decoding method. + Supported values are: + - beam-search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=1, + help="beam size for beam search decoding", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="whisper/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--model-name", + type=str, + default="large-v2", + choices=["large-v2", "large-v3", "medium", "small", "tiny"], + help="""The model name to use. + """, + ) + + parser.add_argument( + "--remove-whisper-encoder-input-length-restriction", + type=str2bool, + default=True, + help="replace whisper encoder forward method to remove input length restriction", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "env_info": get_env_info(), + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + batch: dict, +) -> Dict[str, List[List[int]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: "beam-search" + - value: A list of lists. Each sublist is a list of token IDs. + Args: + params: + It is returned by :func:`get_params`. + model: + The neural model. + batch: + It is returned by :meth:`torch.utils.data.DataLoader.__iter__`. + Returns: + Return a dict, whose key may be "beam-search". 
+ """ + dtype = torch.float16 + device = torch.device("cuda") + + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device, dtype=dtype).transpose(1, 2) + if not params.remove_whisper_encoder_input_length_restriction: + T = 3000 + if feature.shape[2] < T: + feature = torch.cat( + [ + feature, + torch.zeros( + feature.shape[0], feature.shape[1], T - feature.shape[2] + ).to(device, dtype=dtype), + ], + 2, + ) + + supervisions = batch["supervisions"] + feature_len = supervisions["num_frames"] + feature_len = feature_len.to(device, dtype=dtype) + results = model.decode(feature, params.decoding_options) + hyps = [result.text for result in results] + + hyps = remove_punctuation(hyps) + hyps = to_simple(hyps) + hyps = [params.normalizer.normalize(hyp) for hyp in hyps] + + return {"beam-search": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + The dataloader. + params: + It is returned by :func:`get_params`. + model: + The neural model. + Returns: + Return a dict, whose key may be "beam-search". + """ + results = [] + + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + batch=batch, + ) + + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + + num_cuts += len(batch["supervisions"]["text"]) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + # we compute CER for aishell dataset. 
+ results_char = [] + for res in results: + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results_char, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tCER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, CER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + setup_logger( + f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" + ) + + options = whisper.DecodingOptions( + task="transcribe", + language="zh", + without_timestamps=True, + beam_size=params.beam_size, + ) + params.decoding_options = options + params.cleaner = BasicTextNormalizer() + params.normalizer = Normalizer() + + logging.info("Decoding started") + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda") + + logging.info(f"device: {device}") + + if params.remove_whisper_encoder_input_length_restriction: + replace_whisper_encoder_forward() + model = whisper.load_model(params.model_name, "cpu") + if params.epoch > 0: + if params.avg > 1: + start = params.epoch - params.avg + assert start >= 1, start + checkpoint = torch.load( + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + ) + if "model" not in checkpoint: + # deepspeed converted checkpoint only contains model state_dict + filenames = [ + f"{params.exp_dir}/epoch-{epoch}.pt" + for epoch in range(start, params.epoch + 1) + ] + model.load_state_dict(average_checkpoints(filenames)) + else: + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + # save checkpoints + filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save(model.state_dict(), filename) + else: + checkpoint = torch.load( + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + ) + if "model" not in checkpoint: + model.load_state_dict(checkpoint, strict=True) + else: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + aishell = AishellAsrDataModule(args) + valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) + test_dl = aishell.test_dataloaders(aishell.test_cuts()) + test_sets = ["valid", "test"] + test_dls = [valid_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/whisper/ds_config_zero1.json b/egs/aishell/ASR/whisper/ds_config_zero1.json new file mode 100644 index 000000000..bf8cc0452 --- /dev/null +++ b/egs/aishell/ASR/whisper/ds_config_zero1.json @@ -0,0 +1,38 @@ +{ + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 100, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 0.01 + }, + "zero_optimization": { + "stage": 1, + "allgather_partitions": true, + "allgather_bucket_size": 2e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 2e8, + "contiguous_gradients": true + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-5 + } + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": 1e-5, + "warmup_num_steps": 100 + } + }, + "gradient_accumulation_steps": 1, + "gradient_clipping": 5, + "steps_per_print": 50, + "train_micro_batch_size_per_gpu": 1, + "wall_clock_breakdown": false +} diff --git a/egs/aishell/ASR/whisper/label_smoothing.py b/egs/aishell/ASR/whisper/label_smoothing.py new file mode 120000 index 000000000..e9d239fff --- /dev/null +++ b/egs/aishell/ASR/whisper/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/aishell/ASR/whisper/optim.py b/egs/aishell/ASR/whisper/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/aishell/ASR/whisper/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/aishell/ASR/whisper/requirements.txt b/egs/aishell/ASR/whisper/requirements.txt new file mode 100755 index 000000000..0708f2344 --- /dev/null +++ b/egs/aishell/ASR/whisper/requirements.txt @@ -0,0 +1,10 @@ +k2 +kaldialign +git+https://github.com/lhotse-speech/lhotse +sentencepiece +tensorboard +librosa +git+https://github.com/yuekaizhang/whisper.git +zhconv +WeTextProcessing +deepspeed diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py new file mode 100755 index 000000000..d16793eb2 --- /dev/null +++ b/egs/aishell/ASR/whisper/train.py @@ -0,0 +1,927 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# 2024 Yuekai Zhang +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
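The `replace_whisper_encoder_forward()` helper used by these scripts relies on plain runtime monkey patching: the encoder class's `forward` method is reassigned so every existing and future instance picks up the relaxed behaviour. The sketch below illustrates the pattern on a toy class only; it is not the actual Whisper `AudioEncoder` code or the real replacement forward.

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Toy stand-in for an encoder with a fixed-length input restriction."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.shape[-1] == 3000, "fixed-length input only"
        return x


def relaxed_forward(self, x: torch.Tensor) -> torch.Tensor:
    # Replacement without the fixed-length assertion.
    return x


def replace_toy_encoder_forward() -> None:
    # Reassigning the class attribute patches all instances at once.
    ToyEncoder.forward = relaxed_forward


enc = ToyEncoder()
replace_toy_encoder_forward()
print(enc(torch.randn(2, 80, 1500)).shape)  # torch.Size([2, 80, 1500])
```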
+""" +Usage: + +#fine-tuning with deepspeed zero stage 1 +torchrun --nproc-per-node 8 ./whisper/train.py \ + --max-duration 200 \ + --exp-dir whisper/exp_large_v2 \ + --model-name large-v2 \ + --manifest-dir data/fbank_whisper \ + --deepspeed \ + --deepspeed_config ./whisper/ds_config_zero1.json + +# fine-tuning with ddp +torchrun --nproc-per-node 8 ./whisper/train.py \ + --max-duration 200 \ + --exp-dir whisper/exp_medium \ + --manifest-dir data/fbank_whisper \ + --base-lr 1e-5 \ + --model-name medium +""" + + +import argparse +import copy +import logging +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import deepspeed +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import whisper +from asr_datamodule import AishellAsrDataModule +from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict +from label_smoothing import LabelSmoothingLoss +from lhotse import CutSet, load_manifest +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.functional import pad as pad_tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import update_averaged_model +from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + filter_uneven_sized_batch, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=10, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--model-name", + type=str, + default="large-v2", + choices=["large-v2", "large-v3", "medium", "small", "tiny"], + help="""The model name to use. + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=1e-5, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser = deepspeed.add_config_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - frame_shift_ms: The frame shift in milliseconds. + - allowed_excess_duration_ratio: The allowed excess duration ratio. + - best_train_loss: The best training loss so far. + - best_valid_loss: The best validation loss so far. + - best_train_epoch: The epoch where the best training loss is achieved. + - best_valid_epoch: The epoch where the best validation loss is achieved. + - batch_idx_train: The batch index of the current batch. + - log_interval: Log training stats every `log_interval` batches. + - reset_interval: Reset the stats every `reset_interval` batches. + - valid_interval: Run validation every `valid_interval` batches. + - env_info: The environment information. 
+ """ + params = AttributeDict( + { + "frame_shift_ms": 10.0, + "subsampling_factor": 2, + "allowed_excess_duration_ratio": 0.1, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 5000, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + tokenizer: whisper.tokenizer.Tokenizer, + model: Union[nn.Module, DDP], + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute the loss for the given batch. + Args: + params: + It is returned by :func:`get_params`. + tokenizer: + The tokenizer used to encode the text. + model: + The model for training. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + Whether it is training. + Returns: + Return a tuple of two elements. The first element is the loss tensor. + """ + # For the uneven-sized batch, the total duration after padding would possibly + # cause OOM. Hence, for each batch, which is sorted descendingly by length, + # we simply drop the last few shortest samples, so that the retained total frames + # (after padding) would not exceed `allowed_max_frames`: + # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + # where `max_frames = max_duration * 1000 // frame_shift_ms`. + # We set allowed_excess_duration_ratio=0.1. + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + + def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor: + padding_size = max(tensor.shape[0] for tensor in tensors) + dims = len(tensors[0].shape) + padded_tensors = [] + for tensor in tensors: + padding = [0] * 2 * dims + padding[-1] = padding_size - tensor.shape[0] + padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value)) + return torch.stack([tensor for tensor in padded_tensors], dim=0) + + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + + assert feature.ndim == 3 + feature = feature.to(device) + feature = feature.transpose(1, 2) # (N, C, T) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + + texts = batch["supervisions"]["text"] + # remove spaces in texts + texts = [text.replace(" ", "") for text in texts] + + text_tokens_list = [ + list(tokenizer.sot_sequence_including_notimestamps) + + tokenizer.encode(text) + + [tokenizer.eot] + for text in texts + ] + # convert it to torch tensor + text_tokens_list = [ + torch.LongTensor(text_tokens) for text_tokens in text_tokens_list + ] + + # 50256 is the index of for all whisper models + prev_outputs_tokens = _batch_tensors( + [tokens[:-1] for tokens in text_tokens_list], pad_value=50256 + ) + target_tokens = _batch_tensors( + [tokens[1:] for tokens in text_tokens_list], pad_value=50256 + ) + target_lengths = torch.LongTensor( + [tokens.shape[0] - 1 for 
tokens in text_tokens_list] + ) + + decoder_criterion = LabelSmoothingLoss( + ignore_index=50256, label_smoothing=0.1, reduction="sum" + ) + + # ignore the first 3 tokens, which are always <|lang_id|>, <|transcibe|>, <|notimestampes|> + ignore_prefix_size = 3 + with torch.set_grad_enabled(is_training): + encoder_out = model.encoder(feature) + text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out) + text_logits = text_logits[:, ignore_prefix_size:, :] + target_tokens = target_tokens[:, ignore_prefix_size:] + loss = decoder_criterion(text_logits, target_tokens.to(device)) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + tokenizer: whisper.tokenizer.Tokenizer, + model: Union[nn.Module, DDP], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + tokenizer=tokenizer, + model=model, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + tokenizer: whisper.tokenizer.Tokenizer, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
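The token handling in `compute_loss` above (building `prev_outputs_tokens` and `target_tokens`, padding with 50256, and skipping the prompt positions) is easiest to see on a toy sequence. The token ids below are made up; the real ones come from whisper's tokenizer:

```python
import torch

prompt = [1, 2, 3, 4]  # stands in for tokenizer.sot_sequence_including_notimestamps
text = [10, 11, 12]    # stands in for tokenizer.encode(text)
eot = 99               # stands in for tokenizer.eot

tokens = torch.LongTensor(prompt + text + [eot])
prev_outputs = tokens[:-1]  # decoder input (teacher forcing)
targets = tokens[1:]        # what the decoder is trained to predict
# In a batch, shorter sequences are right-padded with the pad value used
# above (50256), and the loss skips the first `ignore_prefix_size` positions
# so the fixed prompt tokens do not contribute to training.
```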
+ """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + tokenizer=tokenizer, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + tokenizer=tokenizer, + model=model, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + if params.deepspeed: + # deepspeed's backward() is different from torch's backward() + # in that it does not accept a loss tensor as input. + # It computes the loss internally. + model.backward(loss) + model.step() + else: + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + and not params.deepspeed + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + if batch_idx % params.log_interval == 0: + try: + cur_lr = scheduler.get_last_lr()[0] + except: # noqa + cur_lr = 0.0 + cur_grad_scale = ( + scaler._scale.item() + if (params.use_fp16 and not params.deepspeed) + else 1.0 + ) + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + ( + f"grad_scale: {scaler._scale.item()}" + if (params.use_fp16 and not params.deepspeed) + else "" + ) + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info(params) + + logging.info("About to create model") + + replace_whisper_encoder_forward() + model = whisper.load_model(params.model_name, "cpu") + del model.alignment_heads + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + tokenizer = whisper.tokenizer.get_tokenizer( + model.is_multilingual, + num_languages=model.num_languages, + language="zh", + task="transcribe", + ) + + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + else: + device = torch.device("cpu") + logging.info(f"Device: {device}") + model.to(device) + + optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr) + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if world_size > 1: + if params.deepspeed: + logging.info("Using DeepSpeed") + model, optimizer, _, scheduler = deepspeed.initialize( + args=params, model=model, 
model_parameters=model.parameters() + ) + else: + logging.info("Using DDP") + setup_dist(use_ddp_launch=True) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + aishell = AishellAsrDataModule(args) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = aishell.train_dataloaders(aishell.train_cuts()) + valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + logging.info(f"start training from epoch {params.start_epoch}") + for epoch in range(params.start_epoch, params.num_epochs + 1): + if not params.deepspeed: + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + tokenizer=tokenizer, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if params.deepspeed: + model.save_checkpoint( + save_dir=params.exp_dir, + tag=f"epoch-{params.cur_epoch}", + client_state={}, + ) + if rank == 0: + convert_zero_checkpoint_to_fp32_state_dict( + params.exp_dir, + f"{params.exp_dir}/epoch-{params.cur_epoch}.pt", + tag=f"epoch-{params.cur_epoch}", + ) + else: + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1 and not params.deepspeed: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. 
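In the deepspeed branch above, `convert_zero_checkpoint_to_fp32_state_dict` consolidates the ZeRO shards into a single `epoch-N.pt`. A hedged sketch of loading such a file later for inference, run from the recipe directory, assuming the converted file holds a plain fp32 state dict; the path and epoch number are examples only:

```python
import torch
import whisper
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

# Patch the encoder first, as in run(), so audio shorter than 30s is accepted.
replace_whisper_encoder_forward()
model = whisper.load_model("large-v2", "cpu")

state_dict = torch.load("whisper/exp_large_v2/epoch-10.pt", map_location="cpu")
# strict=False tolerates keys (e.g. alignment_heads) that differ between the
# stock model and the fine-tuned checkpoint.
model.load_state_dict(state_dict, strict=False)
model.eval()
```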
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = get_world_size() + rank = get_rank() + + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + run(rank=rank, world_size=world_size, args=args) + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py b/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py new file mode 100644 index 000000000..5bfbdce3b --- /dev/null +++ b/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py @@ -0,0 +1,29 @@ +import torch +import torch.nn.functional as F +import whisper + + +def forward(self, x: torch.Tensor): + """ + x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) + the mel spectrogram of the audio + """ + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) + + x = (x + self.positional_embedding[: x.shape[1], :]).to(x.dtype) + + for block in self.blocks: + x = block(x) + + x = self.ln_post(x) + return x + + +def replace_whisper_encoder_forward(): + """ + This function monkey patches the forward method of the whisper encoder. + To be called before the model is loaded, it changes whisper to process audio with any length < 30s. + """ + whisper.model.AudioEncoder.forward = forward diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index 62036467e..d7781687f 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -22,16 +22,25 @@ It looks for manifests in the directory data/manifests. The generated fbank features are saved in data/fbank. """ - +import argparse import logging import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + MonoCut, + WhisperFbank, + WhisperFbankConfig, + combine, +) from lhotse.recipes.utils import read_manifests_if_cached -from icefall.utils import get_executor +from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -45,11 +54,12 @@ def is_cut_long(c: MonoCut) -> bool: return c.duration > 5 -def compute_fbank_musan(): +def compute_fbank_musan( + num_mel_bins: int = 80, whisper_fbank: bool = False, output_dir: str = "data/fbank" +): src_dir = Path("data/manifests") - output_dir = Path("data/fbank") + output_dir = Path(output_dir) num_jobs = min(15, os.cpu_count()) - num_mel_bins = 80 dataset_parts = ( "music", @@ -81,7 +91,12 @@ def compute_fbank_musan(): logging.info("Extracting features for Musan") - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + if whisper_fbank: + extractor = WhisperFbank( + WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") + ) + else: + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) with get_executor() as ex: # Initialize the executor only once. 
# create chunks of Musan with duration 5 - 10 seconds @@ -102,8 +117,36 @@ def compute_fbank_musan(): musan_cuts.to_file(musan_cuts_path) +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bins for Fbank""", + ) + parser.add_argument( + "--whisper-fbank", + type=str2bool, + default=False, + help="Use WhisperFbank instead of Fbank. Default: False.", + ) + parser.add_argument( + "--output-dir", + type=str, + default="data/fbank", + help="Output directory. Default: data/fbank.", + ) + return parser.parse_args() + + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_musan() + args = get_args() + compute_fbank_musan( + num_mel_bins=args.num_mel_bins, + whisper_fbank=args.whisper_fbank, + output_dir=args.output_dir, + ) diff --git a/icefall/dist.py b/icefall/dist.py index 922f31a2f..ee76e994a 100644 --- a/icefall/dist.py +++ b/icefall/dist.py @@ -22,7 +22,7 @@ from torch import distributed as dist def setup_dist( - rank, world_size, master_port=None, use_ddp_launch=False, master_addr=None + rank=None, world_size=None, master_port=None, use_ddp_launch=False, master_addr=None ): """ rank and world_size are used only if use_ddp_launch is False. From 37b975cac9ac237b71e957dd0770b091124fcb60 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sat, 27 Jan 2024 06:41:56 +0800 Subject: [PATCH 39/46] fixed a CI test for `wenetspeech` (#1476) * Comply to issue #1149 https://github.com/k2-fsa/icefall/issues/1149 --- ...enetspeech-pruned-transducer-stateless2.sh | 6 ++--- .../pruned_transducer_stateless2/export.py | 19 +++++++------- .../pruned_transducer_stateless3/export.py | 19 +++++++------- .../export-onnx.py | 21 +++++++--------- .../ASR/transducer_stateless/export.py | 19 +++++++------- .../transducer_stateless_modified-2/export.py | 18 ++++++------- .../transducer_stateless_modified/export.py | 19 +++++++------- .../pruned_transducer_stateless5/export.py | 20 +++++++-------- .../pruned_transducer_stateless5/export.py | 19 ++++++-------- .../pruned_transducer_stateless2/export.py | 19 +++++++------- .../pruned_transducer_stateless7/export.py | 23 ++++++++--------- .../export-onnx-zh.py | 18 ++++++------- egs/swbd/ASR/conformer_ctc/export.py | 19 +++++++------- egs/tedlium3/ASR/conformer_ctc2/export.py | 18 ++++++------- .../export-onnx.py | 25 ++++++++----------- .../export-onnx-streaming.py | 19 +++++++------- .../export-onnx.py | 18 ++++++------- 17 files changed, 150 insertions(+), 169 deletions(-) diff --git a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh index a3a2d3080..981b74b76 100755 --- a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh +++ b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh @@ -30,7 +30,7 @@ log "Test exporting to ONNX format" ./pruned_transducer_stateless2/export-onnx.py \ --exp-dir $repo/exp \ - --lang-dir $repo/data/lang_char \ + --tokens $repo/data/lang_char/tokens.txt \ --epoch 99 \ --avg 1 @@ -38,14 +38,14 @@ log "Export to torchscript model" ./pruned_transducer_stateless2/export.py \ --exp-dir $repo/exp \ - --lang-dir $repo/data/lang_char \ + --tokens $repo/data/lang_char/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 ./pruned_transducer_stateless2/export.py \ --exp-dir $repo/exp \ - --lang-dir $repo/data/lang_char \ + --tokens 
$repo/data/lang_char/tokens.txt \ --epoch 99 \ --avg 1 \ --jit-trace 1 diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/export.py b/egs/aishell/ASR/pruned_transducer_stateless2/export.py index 2ce5cfe69..c2dc0d5f3 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/export.py @@ -47,12 +47,12 @@ import argparse import logging from pathlib import Path +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -106,10 +106,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help="The lang dir", + "--tokens", + type=str, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -136,10 +136,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py index 723414167..2248c7a08 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py @@ -47,6 +47,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -57,8 +58,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -123,10 +123,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help="The lang dir", + "--tokens", + type=str, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -153,10 +153,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 params.datatang_prob = 0 logging.info(params) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py index 39d988cd0..4981fb71a 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py @@ -49,14 +49,14 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder2 import Decoder +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled -from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from zipformer 
import Zipformer from icefall.checkpoint import ( @@ -65,8 +65,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -123,12 +122,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -404,9 +401,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py index 01de5d772..bfd0ecb0c 100755 --- a/egs/aishell/ASR/transducer_stateless/export.py +++ b/egs/aishell/ASR/transducer_stateless/export.py @@ -23,7 +23,7 @@ Usage: ./transducer_stateless/export.py \ --exp-dir ./transducer_stateless/exp \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,6 +47,7 @@ import argparse import logging from pathlib import Path +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -56,8 +57,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -92,10 +92,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -192,10 +192,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py index c1081c32b..4f2c71d18 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py @@ -46,6 +46,7 @@ import argparse import logging from pathlib import Path +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -56,7 +57,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -99,10 +100,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help="The lang dir", + "--tokens", + type=str, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -190,10 +191,9 @@ def main(): logging.info(f"device: {device}") - lexicon = 
Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py index 3e14ad69c..487748947 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/export.py +++ b/egs/aishell/ASR/transducer_stateless_modified/export.py @@ -46,6 +46,7 @@ import argparse import logging from pathlib import Path +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -55,8 +56,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -99,10 +99,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help="The lang dir", + "--tokens", + type=str, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -190,10 +190,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py index 8a5be94d0..c92c7ab83 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp \ - --lang-dir data/lang_char + --tokens ./data/lang_char/tokens.txt \ --epoch 25 \ --avg 5 @@ -48,6 +48,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model @@ -57,8 +58,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -115,10 +115,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -154,10 +154,10 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.unk_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py index bf9856c60..246820833 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py @@ -48,6 +48,7 @@ import argparse import logging from pathlib import Path 
+import k2 import torch from train import add_model_arguments, get_params, get_transducer_model @@ -57,8 +58,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -115,13 +115,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -157,9 +154,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py index 8e5cc6075..5dc73c52b 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py @@ -20,7 +20,7 @@ Usage: ./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ + --tokens ./data/lang_char/tokens.txt \ --epoch 29 \ --avg 18 @@ -45,12 +45,12 @@ import argparse import logging from pathlib import Path +import k2 import torch from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -85,10 +85,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -122,10 +122,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py index 23a88dd29..8bafaef44 100755 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_char/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. 
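The recurring change in this commit replaces `Lexicon(lang_dir)` with a token table read directly from `tokens.txt`. A minimal sketch of the new pattern, assuming the blank symbol is spelled `<blk>` in `tokens.txt` (the path is illustrative):

```python
import k2

from icefall.utils import num_tokens

token_table = k2.SymbolTable.from_file("data/lang_char/tokens.txt")
blank_id = token_table["<blk>"]           # integer id of the blank symbol
vocab_size = num_tokens(token_table) + 1  # +1 for the blank
```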
./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_char/tokens.txt \ --epoch 20 \ --avg 10 @@ -86,9 +86,8 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch -import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -98,8 +97,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -156,10 +154,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -199,10 +197,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py index 2a52e2eec..1ce770128 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./lstm_transducer_stateless2/export-onnx-zh.py \ - --lang-dir ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char \ + --tokens ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char/tokens.txt \ --use-averaged-model 1 \ --epoch 11 \ --avg 1 \ @@ -55,6 +55,7 @@ import logging from pathlib import Path from typing import Dict, Optional, Tuple +import k2 import onnx import torch import torch.nn as nn @@ -70,8 +71,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -128,10 +128,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -441,9 +441,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/swbd/ASR/conformer_ctc/export.py b/egs/swbd/ASR/conformer_ctc/export.py index 1bb6277ad..44b2e95d6 100755 --- a/egs/swbd/ASR/conformer_ctc/export.py +++ b/egs/swbd/ASR/conformer_ctc/export.py @@ -23,12 +23,12 @@ import argparse import logging from pathlib import Path +import k2 import torch from conformer import Conformer from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -63,11 +63,10 @@ def 
get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="""It contains language related input files such as "lexicon.txt" - """, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -105,9 +104,9 @@ def main(): logging.info(params) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 device = torch.device("cpu") if torch.cuda.is_available(): @@ -119,7 +118,7 @@ def main(): num_features=params.feature_dim, nhead=params.nhead, d_model=params.attention_dim, - num_classes=num_classes, + num_classes=params.vocab_size, subsampling_factor=params.subsampling_factor, num_decoder_layers=params.num_decoder_layers, vgg_frontend=False, diff --git a/egs/tedlium3/ASR/conformer_ctc2/export.py b/egs/tedlium3/ASR/conformer_ctc2/export.py index 009bea230..b5bf911c2 100755 --- a/egs/tedlium3/ASR/conformer_ctc2/export.py +++ b/egs/tedlium3/ASR/conformer_ctc2/export.py @@ -45,6 +45,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from conformer import Conformer from scaling_converter import convert_scaled_to_non_scaled @@ -56,8 +57,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser() -> argparse.ArgumentParser: @@ -118,10 +118,10 @@ def get_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="The lang dir", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -166,9 +166,9 @@ def main(): params = get_params() params.update(vars(args)) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 device = torch.device("cpu") if torch.cuda.is_available(): @@ -182,7 +182,7 @@ def main(): model = Conformer( num_features=params.feature_dim, - num_classes=num_classes, + num_classes=params.vocab_size, subsampling_factor=params.subsampling_factor, d_model=params.dim_model, nhead=params.nhead, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py index 140b1d37f..8aea79fe3 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py @@ -28,7 +28,7 @@ popd 2. 
Export the model to ONNX ./pruned_transducer_stateless2/export-onnx.py \ - --lang-dir $repo/data/lang_char \ + --tokens $repo/data/lang_char/tokens.txt \ --epoch 99 \ --avg 1 \ --exp-dir $repo/exp @@ -48,6 +48,7 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx import torch import torch.nn as nn @@ -57,14 +58,8 @@ from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -110,10 +105,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -397,9 +392,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py index 921766ad4..30068d01a 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py @@ -58,13 +58,13 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -from icefall.lexicon import Lexicon import torch import torch.nn as nn from conformer import Conformer -from onnxruntime.quantization import QuantType, quantize_dynamic from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -74,7 +74,8 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.lexicon import Lexicon +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -131,10 +132,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -490,9 +491,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py index 037c7adf1..1c9eb8648 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -28,7 +28,7 @@ popd 2. 
Export the model to ONNX ./pruned_transducer_stateless5/export-onnx.py \ - --lang-dir $repo/data/lang_char \ + --tokens $repo/data/lang_char/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -55,6 +55,7 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx import torch import torch.nn as nn @@ -70,8 +71,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -128,10 +128,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -417,9 +417,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) From b07d5472c58eb11bc86cdcf787ecfdb634a55c73 Mon Sep 17 00:00:00 2001 From: Henry Li Xinyuan <57513260+HSTEHSTEHSTE@users.noreply.github.com> Date: Wed, 31 Jan 2024 09:53:36 -0500 Subject: [PATCH 40/46] Implement recipe for Fluent Speech Commands dataset (#1469) --------- Signed-off-by: Xinyuan Li --- egs/fluent_speech_commands/SLU/README.md | 9 + .../SLU/local/compile_hlg.py | 136 ++++ .../SLU/local/compute_fbank_slu.py | 97 +++ .../SLU/local/generate_lexicon.py | 59 ++ .../SLU/local/prepare_lang.py | 371 +++++++++++ egs/fluent_speech_commands/SLU/prepare.sh | 103 +++ egs/fluent_speech_commands/SLU/shared | 1 + .../SLU/transducer/__init__.py | 0 .../SLU/transducer/beam_search.py | 71 ++ .../SLU/transducer/conformer.py | 1 + .../SLU/transducer/decode.py | 346 ++++++++++ .../SLU/transducer/decoder.py | 1 + .../SLU/transducer/encoder_interface.py | 1 + .../SLU/transducer/joiner.py | 1 + .../SLU/transducer/model.py | 1 + .../SLU/transducer/slu_datamodule.py | 289 ++++++++ .../SLU/transducer/subsampling.py | 1 + .../SLU/transducer/test_conformer.py | 1 + .../SLU/transducer/test_decoder.py | 1 + .../SLU/transducer/test_joiner.py | 1 + .../SLU/transducer/test_transducer.py | 1 + .../SLU/transducer/train.py | 625 ++++++++++++++++++ .../SLU/transducer/transformer.py | 1 + icefall/shared/make_kn_lm.py | 9 +- 24 files changed, 2124 insertions(+), 3 deletions(-) create mode 100755 egs/fluent_speech_commands/SLU/README.md create mode 100755 egs/fluent_speech_commands/SLU/local/compile_hlg.py create mode 100755 egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py create mode 100755 egs/fluent_speech_commands/SLU/local/generate_lexicon.py create mode 100755 egs/fluent_speech_commands/SLU/local/prepare_lang.py create mode 100755 egs/fluent_speech_commands/SLU/prepare.sh create mode 120000 egs/fluent_speech_commands/SLU/shared create mode 100755 egs/fluent_speech_commands/SLU/transducer/__init__.py create mode 100755 egs/fluent_speech_commands/SLU/transducer/beam_search.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/conformer.py create mode 100755 egs/fluent_speech_commands/SLU/transducer/decode.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/decoder.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/encoder_interface.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/joiner.py 
create mode 120000 egs/fluent_speech_commands/SLU/transducer/model.py create mode 100755 egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/subsampling.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/test_conformer.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/test_decoder.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/test_joiner.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/test_transducer.py create mode 100755 egs/fluent_speech_commands/SLU/transducer/train.py create mode 120000 egs/fluent_speech_commands/SLU/transducer/transformer.py diff --git a/egs/fluent_speech_commands/SLU/README.md b/egs/fluent_speech_commands/SLU/README.md new file mode 100755 index 000000000..a203a9bfb --- /dev/null +++ b/egs/fluent_speech_commands/SLU/README.md @@ -0,0 +1,9 @@ +## Fluent Speech Commands recipe + +This is a recipe for the Fluent Speech Commands dataset, a speech dataset which transcribes short utterances (such as "turn the lights on in the kitchen") into action frames (such as {"action": "activate", "object": "lights", "location": "kitchen"}). The training set contains 23,132 utterances, whereas the test set contains 3793 utterances. + +Dataset Paper link: + +cd icefall/egs/fluent_speech_commands/ +Training: python transducer/train.py +Decoding: python transducer/decode.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/local/compile_hlg.py b/egs/fluent_speech_commands/SLU/local/compile_hlg.py new file mode 100755 index 000000000..a7df8f966 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/local/compile_hlg.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 + +""" +This script takes as input lang_dir and generates HLG from + + - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt + - L, the lexicon, built from lang_dir/L_disambig.pt + + Caution: We use a lexicon that contains disambiguation symbols + + - G, the LM, built from data/lm/G.fst.txt + +The generated HLG is saved in $lang_dir/HLG.pt +""" +import argparse +import logging +from pathlib import Path + +import k2 +import torch + +from icefall.lexicon import Lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + """, + ) + + return parser.parse_args() + + +def compile_HLG(lang_dir: str) -> k2.Fsa: + """ + Args: + lang_dir: + The language directory, e.g., data/lang_phone or data/lang_bpe_5000. + + Return: + An FSA representing HLG. + """ + lexicon = Lexicon(lang_dir) + max_token_id = max(lexicon.tokens) + logging.info(f"Building ctc_topo. 
max_token_id: {max_token_id}") + H = k2.ctc_topo(max_token_id) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + + logging.info("Loading G.fst.txt") + with open(lang_dir / "G.fst.txt") as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + + first_token_disambig_id = lexicon.token_table["#0"] + first_word_disambig_id = lexicon.word_table["#0"] + + L = k2.arc_sort(L) + G = k2.arc_sort(G) + + logging.info("Intersecting L and G") + LG = k2.compose(L, G) + logging.info(f"LG shape: {LG.shape}") + + logging.info("Connecting LG") + LG = k2.connect(LG) + logging.info(f"LG shape after k2.connect: {LG.shape}") + + logging.info(type(LG.aux_labels)) + logging.info("Determinizing LG") + + LG = k2.determinize(LG) + logging.info(type(LG.aux_labels)) + + logging.info("Connecting LG after k2.determinize") + LG = k2.connect(LG) + + logging.info("Removing disambiguation symbols on LG") + + # LG.labels[LG.labels >= first_token_disambig_id] = 0 + # see https://github.com/k2-fsa/k2/pull/1140 + labels = LG.labels + labels[labels >= first_token_disambig_id] = 0 + LG.labels = labels + + assert isinstance(LG.aux_labels, k2.RaggedTensor) + LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 + + LG = k2.remove_epsilon(LG) + logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") + + LG = k2.connect(LG) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) + + logging.info("Arc sorting LG") + LG = k2.arc_sort(LG) + + logging.info("Composing H and LG") + # CAUTION: The name of the inner_labels is fixed + # to `tokens`. If you want to change it, please + # also change other places in icefall that are using + # it. + HLG = k2.compose(H, LG, inner_labels="tokens") + + logging.info("Connecting LG") + HLG = k2.connect(HLG) + + logging.info("Arc sorting LG") + HLG = k2.arc_sort(HLG) + logging.info(f"HLG.shape: {HLG.shape}") + + return HLG + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + if (lang_dir / "HLG.pt").is_file(): + logging.info(f"{lang_dir}/HLG.pt already exists - skipping") + return + + logging.info(f"Processing {lang_dir}") + + HLG = compile_HLG(lang_dir) + logging.info(f"Saving HLG.pt to {lang_dir}") + torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py b/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py new file mode 100755 index 000000000..a51b7b47b --- /dev/null +++ b/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 + +""" +This file computes fbank features of the Fluent Speech Commands dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path + +import torch +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or it wastes a +# lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_slu(manifest_dir, fbanks_dir): + src_dir = Path(manifest_dir) + output_dir = Path(fbanks_dir) + + # This dataset is rather small, so we use only one job + num_jobs = min(1, os.cpu_count()) + num_mel_bins = 23 + + dataset_parts = ( + "train", + "valid", + "test", + ) + prefix = "slu" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + extractor = Fbank(FbankConfig(sampling_rate=16000, num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + for partition, m in manifests.items(): + cuts_file = output_dir / f"{prefix}_cuts_{partition}.{suffix}" + if cuts_file.is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + if "train" in partition: + cut_set = ( + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + ) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 1, # use one job + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(cuts_file) + + +parser = argparse.ArgumentParser() +parser.add_argument("manifest_dir") +parser.add_argument("fbanks_dir") + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + args = parser.parse_args() + + logging.basicConfig(format=formatter, level=logging.INFO) + + compute_fbank_slu(args.manifest_dir, args.fbanks_dir) diff --git a/egs/fluent_speech_commands/SLU/local/generate_lexicon.py b/egs/fluent_speech_commands/SLU/local/generate_lexicon.py new file mode 100755 index 000000000..6263e062f --- /dev/null +++ b/egs/fluent_speech_commands/SLU/local/generate_lexicon.py @@ -0,0 +1,59 @@ +import argparse + +import pandas +from tqdm import tqdm + + +def generate_lexicon(corpus_dir, lm_dir): + data = pandas.read_csv( + str(corpus_dir) + "/data/train_data.csv", index_col=0, header=0 + ) + vocab_transcript = set() + vocab_frames = set() + transcripts = data["transcription"].tolist() + frames = list( + i + for i in zip( + data["action"].tolist(), data["object"].tolist(), data["location"].tolist() + ) + ) + + for transcript in tqdm(transcripts): + for word in transcript.split(): + vocab_transcript.add(word) + + for frame in tqdm(frames): + for word in frame: + vocab_frames.add("_".join(word.split())) + + with open(lm_dir + "/words_transcript.txt", "w") as lexicon_transcript_file: + lexicon_transcript_file.write(" 1" + "\n") + lexicon_transcript_file.write(" 2" + "\n") + lexicon_transcript_file.write(" 0" + "\n") + id = 3 + for vocab in vocab_transcript: + lexicon_transcript_file.write(vocab + " " + str(id) + "\n") + id += 1 + + with open(lm_dir + "/words_frames.txt", "w") as lexicon_frames_file: + lexicon_frames_file.write(" 1" + "\n") + lexicon_frames_file.write(" 2" + "\n") + lexicon_frames_file.write(" 0" + "\n") + id = 3 + for vocab in vocab_frames: + lexicon_frames_file.write(vocab + " " + str(id) + "\n") + id += 1 + + +parser = 
argparse.ArgumentParser() +parser.add_argument("corpus_dir") +parser.add_argument("lm_dir") + + +def main(): + args = parser.parse_args() + + generate_lexicon(args.corpus_dir, args.lm_dir) + + +main() diff --git a/egs/fluent_speech_commands/SLU/local/prepare_lang.py b/egs/fluent_speech_commands/SLU/local/prepare_lang.py new file mode 100755 index 000000000..2a71dcf81 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/local/prepare_lang.py @@ -0,0 +1,371 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +""" +This script takes as input a lexicon file "data/lang_phone/lexicon.txt" +consisting of words and tokens (i.e., phones) and does the following: + +1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt + +2. Generate tokens.txt, the token table mapping a token to a unique integer. + +3. Generate words.txt, the word table mapping a word to a unique integer. + +4. Generate L.pt, in k2 format. It can be loaded by + + d = torch.load("L.pt") + lexicon = k2.Fsa.from_dict(d) + +5. Generate L_disambig.pt, in k2 format. +""" +import argparse +import math +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import k2 +import torch + +from icefall.lexicon import read_lexicon, write_lexicon + +Lexicon = List[Tuple[str, List[str]]] + + +def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: + """Write a symbol to ID mapping to a file. + + Note: + No need to implement `read_mapping` as it can be done + through :func:`k2.SymbolTable.from_file`. + + Args: + filename: + Filename to save the mapping. + sym2id: + A dict mapping symbols to IDs. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf-8") as f: + for sym, i in sym2id.items(): + f.write(f"{sym} {i}\n") + + +def get_tokens(lexicon: Lexicon) -> List[str]: + """Get tokens from a lexicon. + + Args: + lexicon: + It is the return value of :func:`read_lexicon`. + Returns: + Return a list of unique tokens. + """ + ans = set() + for _, tokens in lexicon: + ans.update(tokens) + sorted_ans = sorted(list(ans)) + return sorted_ans + + +def get_words(lexicon: Lexicon) -> List[str]: + """Get words from a lexicon. + + Args: + lexicon: + It is the return value of :func:`read_lexicon`. + Returns: + Return a list of unique words. + """ + ans = set() + for word, _ in lexicon: + ans.add(word) + sorted_ans = sorted(list(ans)) + return sorted_ans + + +def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: + """It adds pseudo-token disambiguation symbols #1, #2 and so on + at the ends of tokens to ensure that all pronunciations are different, + and that none is a prefix of another. + + See also add_lex_disambig.pl from kaldi. + + Args: + lexicon: + It is returned by :func:`read_lexicon`. + Returns: + Return a tuple with two elements: + + - The output lexicon with disambiguation symbols + - The ID of the max disambiguation symbol that appears + in the lexicon + """ + + # (1) Work out the count of each token-sequence in the + # lexicon. + count = defaultdict(int) + for _, tokens in lexicon: + count[" ".join(tokens)] += 1 + + # (2) For each left sub-sequence of each token-sequence, note down + # that it exists (for identifying prefixes of longer strings). 
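
The behaviour of `add_disambig_symbols` is easiest to see on a toy lexicon. A minimal illustration (not part of the script, assuming the Kaldi-style rule described in the docstring above):

```python
# Two words share the token sequence "r eh d", so add_disambig_symbols() is
# expected to append #1 and #2 so that the two entries no longer collide and
# the lexicon FST stays determinizable.
lexicon = [("read", ["r", "eh", "d"]), ("red", ["r", "eh", "d"])]
expected = [("read", ["r", "eh", "d", "#1"]), ("red", ["r", "eh", "d", "#2"])]
```

In this recipe each word is mapped to itself as its only token (see `main()` further below), so the token sequences are already unique and no per-word disambiguation symbols should be needed in practice; only `#0` is still added for composing L with G.
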
+ issubseq = defaultdict(int) + for _, tokens in lexicon: + tokens = tokens.copy() + tokens.pop() + while tokens: + issubseq[" ".join(tokens)] = 1 + tokens.pop() + + # (3) For each entry in the lexicon: + # if the token sequence is unique and is not a + # prefix of another word, no disambig symbol. + # Else output #1, or #2, #3, ... if the same token-seq + # has already been assigned a disambig symbol. + ans = [] + + # We start with #1 since #0 has its own purpose + first_allowed_disambig = 1 + max_disambig = first_allowed_disambig - 1 + last_used_disambig_symbol_of = defaultdict(int) + + for word, tokens in lexicon: + tokenseq = " ".join(tokens) + assert tokenseq != "" + if issubseq[tokenseq] == 0 and count[tokenseq] == 1: + ans.append((word, tokens)) + continue + + cur_disambig = last_used_disambig_symbol_of[tokenseq] + if cur_disambig == 0: + cur_disambig = first_allowed_disambig + else: + cur_disambig += 1 + + if cur_disambig > max_disambig: + max_disambig = cur_disambig + last_used_disambig_symbol_of[tokenseq] = cur_disambig + tokenseq += f" #{cur_disambig}" + ans.append((word, tokenseq.split())) + return ans, max_disambig + + +def generate_id_map(symbols: List[str]) -> Dict[str, int]: + """Generate ID maps, i.e., map a symbol to a unique ID. + + Args: + symbols: + A list of unique symbols. + Returns: + A dict containing the mapping between symbols and IDs. + """ + return {sym: i for i, sym in enumerate(symbols)} + + +def add_self_loops( + arcs: List[List[Any]], disambig_token: int, disambig_word: int +) -> List[List[Any]]: + """Adds self-loops to states of an FST to propagate disambiguation symbols + through it. They are added on each state with non-epsilon output symbols + on at least one arc out of the state. + + See also fstaddselfloops.pl from Kaldi. One difference is that + Kaldi uses OpenFst style FSTs and it has multiple final states. + This function uses k2 style FSTs and it does not need to add self-loops + to the final state. + + The input label of a self-loop is `disambig_token`, while the output + label is `disambig_word`. + + Args: + arcs: + A list-of-list. The sublist contains + `[src_state, dest_state, label, aux_label, score]` + disambig_token: + It is the token ID of the symbol `#0`. + disambig_word: + It is the word ID of the symbol `#0`. + + Return: + Return new `arcs` containing self-loops. + """ + states_needs_self_loops = set() + for arc in arcs: + src, dst, ilabel, olabel, score = arc + if olabel != 0: + states_needs_self_loops.add(src) + + ans = [] + for s in states_needs_self_loops: + ans.append([s, s, disambig_token, disambig_word, 0]) + + return arcs + ans + + +def lexicon_to_fst( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + sil_token: str = "!SIL", + sil_prob: float = 0.5, + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format) with optional silence at + the beginning and end of each word. + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + sil_token: + The silence token. + sil_prob: + The probability for adding a silence at the beginning and end + of the word. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. 
+ """ + assert sil_prob > 0.0 and sil_prob < 1.0 + # CAUTION: we use score, i.e, negative cost. + sil_score = math.log(sil_prob) + no_sil_score = math.log(1.0 - sil_prob) + + start_state = 0 + loop_state = 1 # words enter and leave from here + sil_state = 2 # words terminate here when followed by silence; this state + # has a silence transition to loop_state. + next_state = 3 # the next un-allocated state, will be incremented as we go. + arcs = [] + + # assert token2id[""] == 0 + # assert word2id[""] == 0 + + eps = 0 + sil_token = word2id[sil_token] + + arcs.append([start_state, loop_state, eps, eps, no_sil_score]) + arcs.append([start_state, sil_state, eps, eps, sil_score]) + arcs.append([sil_state, loop_state, sil_token, eps, 0]) + + for word, tokens in lexicon: + assert len(tokens) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + tokens = [word2id[i] for i in tokens] + + for i in range(len(tokens) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, tokens[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last token of this word + # It has two out-going arcs, one to the loop state, + # the other one to the sil_state. + i = len(tokens) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) + arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) + + if need_self_loops: + disambig_token = word2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +parser = argparse.ArgumentParser() +parser.add_argument("lm_dir") + + +def main(): + args = parser.parse_args() + + out_dir = Path(args.lm_dir) + lexicon_filenames = [out_dir / "words_frames.txt", out_dir / "words_transcript.txt"] + names = ["frames", "transcript"] + sil_token = "!SIL" + sil_prob = 0.5 + + for name, lexicon_filename in zip(names, lexicon_filenames): + lexicon = read_lexicon(lexicon_filename) + tokens = get_words(lexicon) + words = get_words(lexicon) + new_lexicon = [] + for lexicon_item in lexicon: + new_lexicon.append((lexicon_item[0], [lexicon_item[0]])) + lexicon = new_lexicon + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in tokens + tokens.append(f"#{i}") + + tokens = [""] + tokens + words = ["eps"] + words + ["#0", "!SIL"] + + token2id = generate_id_map(tokens) + word2id = generate_id_map(words) + + write_mapping(out_dir / ("tokens_" + name + ".txt"), token2id) + write_mapping(out_dir / ("words_" + name + ".txt"), word2id) + write_lexicon(out_dir / ("lexicon_disambig_" + name + ".txt"), lexicon_disambig) + + L = lexicon_to_fst( + lexicon, + token2id=word2id, + word2id=word2id, + sil_token=sil_token, + sil_prob=sil_prob, + ) + + L_disambig = lexicon_to_fst( + lexicon_disambig, + token2id=word2id, + word2id=word2id, + sil_token=sil_token, + sil_prob=sil_prob, + need_self_loops=True, + ) + torch.save(L.as_dict(), out_dir / ("L_" + name + ".pt")) + torch.save(L_disambig.as_dict(), out_dir / ("L_disambig_" + name + ".pt")) + + if False: + # Just for debugging, will remove it + 
L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt") + L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt") + L_disambig.labels_sym = L.labels_sym + L_disambig.aux_labels_sym = L.aux_labels_sym + L.draw(out_dir / "L.png", title="L") + L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig") + + +main() diff --git a/egs/fluent_speech_commands/SLU/prepare.sh b/egs/fluent_speech_commands/SLU/prepare.sh new file mode 100755 index 000000000..3ff339d91 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/prepare.sh @@ -0,0 +1,103 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=1 +stop_stage=5 + +data_dir=path/to/fluent/speech/commands +target_root_dir=data/ + +lang_dir=${target_root_dir}/lang_phone +lm_dir=${target_root_dir}/lm +manifest_dir=${target_root_dir}/manifests +fbanks_dir=${target_root_dir}/fbanks + +. shared/parse_options.sh || exit 1 + +mkdir -p $lang_dir +mkdir -p $lm_dir + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "data_dir: $data_dir" + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare slu manifest" + mkdir -p $manifest_dir + lhotse prepare slu $data_dir $manifest_dir +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compute fbank for SLU" + mkdir -p $fbanks_dir + python ./local/compute_fbank_slu.py $manifest_dir $fbanks_dir +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Prepare lang" + # NOTE: " SIL" is added for implementation convenience + # as the graph compiler code requires that there is a OOV word + # in the lexicon. + python ./local/generate_lexicon.py $data_dir $lm_dir +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Train LM" + # We use a unigram G + ./shared/make_kn_lm.py \ + -ngram-order 1 \ + -text $lm_dir/words_transcript.txt \ + -lm $lm_dir/G_transcript.arpa + + ./shared/make_kn_lm.py \ + -ngram-order 1 \ + -text $lm_dir/words_frames.txt \ + -lm $lm_dir/G_frames.arpa + + python ./local/prepare_lang.py $lm_dir + + if [ ! -f $lm_dir/G_transcript.fst.txt ]; then + python -m kaldilm \ + --read-symbol-table="$lm_dir/words_transcript.txt" \ + $lm_dir/G_transcript.arpa > $lm_dir/G_transcript.fst.txt + fi + + if [ ! -f $lm_dir/G_frames.fst.txt ]; then + python -m kaldilm \ + --read-symbol-table="$lm_dir/words_frames.txt" \ + $lm_dir/G_frames.arpa > $lm_dir/G_frames.fst.txt + fi + + mkdir -p $lm_dir/frames + mkdir -p $lm_dir/transcript + + chmod -R +777 . 
+ + for i in G_frames.arpa G_frames.fst.txt L_disambig_frames.pt L_frames.pt lexicon_disambig_frames.txt tokens_frames.txt words_frames.txt; + do + j=${i//"_frames"/} + mv "$lm_dir/$i" $lm_dir/frames/$j + done + + for i in G_transcript.arpa G_transcript.fst.txt L_disambig_transcript.pt L_transcript.pt lexicon_disambig_transcript.txt tokens_transcript.txt words_transcript.txt; + do + j=${i//"_transcript"/} + mv "$lm_dir/$i" $lm_dir/transcript/$j + done +fi + + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Compile HLG" + ./local/compile_hlg.py --lang-dir $lm_dir/frames + ./local/compile_hlg.py --lang-dir $lm_dir/transcript + +fi diff --git a/egs/fluent_speech_commands/SLU/shared b/egs/fluent_speech_commands/SLU/shared new file mode 120000 index 000000000..32a374a7f --- /dev/null +++ b/egs/fluent_speech_commands/SLU/shared @@ -0,0 +1 @@ +../../icefall/shared/ \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/__init__.py b/egs/fluent_speech_commands/SLU/transducer/__init__.py new file mode 100755 index 000000000..e69de29bb diff --git a/egs/fluent_speech_commands/SLU/transducer/beam_search.py b/egs/fluent_speech_commands/SLU/transducer/beam_search.py new file mode 100755 index 000000000..a16aa0123 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/beam_search.py @@ -0,0 +1,71 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import torch +from transducer.model import Transducer + + +def greedy_search( + model: Transducer, encoder_out: torch.Tensor, id2word: dict +) -> List[str]: + """ + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + Returns: + Return the decoded result. 
+ """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + device = model.device + + sos = torch.tensor([blank_id], device=device).reshape(1, 1) + decoder_out, (h, c) = model.decoder(sos) + T = encoder_out.size(1) + t = 0 + hyp = [] + max_u = 1000 # terminate after this number of steps + u = 0 + + while t < T and u < max_u: + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :] + # fmt: on + logits = model.joiner(current_encoder_out, decoder_out) + + log_prob = logits.log_softmax(dim=-1) + # log_prob is (N, 1, 1) + # TODO: Use logits.argmax() + y = log_prob.argmax() + if y != blank_id: + hyp.append(y.item()) + y = y.reshape(1, 1) + decoder_out, (h, c) = model.decoder(y, (h, c)) + u += 1 + else: + t += 1 + # id2word = {1: "YES", 2: "NO"} + + hyp = [id2word[i] for i in hyp] + + return hyp diff --git a/egs/fluent_speech_commands/SLU/transducer/conformer.py b/egs/fluent_speech_commands/SLU/transducer/conformer.py new file mode 120000 index 000000000..8be0dc864 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/conformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/conformer.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/decode.py b/egs/fluent_speech_commands/SLU/transducer/decode.py new file mode 100755 index 000000000..ba2b9aaea --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/decode.py @@ -0,0 +1,346 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +import torch +import torch.nn as nn +from transducer.beam_search import greedy_search +from transducer.conformer import Conformer +from transducer.decoder import Decoder +from transducer.joiner import Joiner +from transducer.model import Transducer +from transducer.slu_datamodule import SluDataModule + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + write_error_stats, +) + + +def get_id2word(params): + id2word = {} + + # 0 is blank + id = 1 + try: + with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file: + for line in lexicon_file: + if len(line.strip()) > 0: + id2word[id] = line.split()[0] + id += 1 + except: + pass + + return id2word + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=6, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + parser.add_argument( + "--exp-dir", + type=str, + default="transducer/exp", + help="Directory from which to load the checkpoints", + ) + parser.add_argument("--lang-dir", type=str, default="data/lm/frames") + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 23, + "lang_dir": Path("data/lm/frames"), + # encoder/decoder params + "vocab_size": 3, # blank, yes, no + "blank_id": 0, + "embedding_dim": 32, + "hidden_dim": 16, + "num_decoder_layers": 4, + } + ) + + vocab_size = 1 + with open(params.lang_dir / "lexicon_disambig.txt") as lexicon_file: + for line in lexicon_file: + if ( + len(line.strip()) > 0 + ): # and '' not in line and '' not in line and '' not in line: + vocab_size += 1 + params.vocab_size = vocab_size + + return params + + +def decode_one_batch( + params: AttributeDict, model: nn.Module, batch: dict, id2word: dict +) -> List[List[int]]: + """Decode one batch and return the result in a list-of-list. + Each sub list contains the word IDs for an utterance in the batch. + + Args: + params: + It's the return value of :func:`get_params`. + + - params.method is "1best", it uses 1best decoding. + - params.method is "nbest", it uses nbest decoding. + + model: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + (https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py) + Returns: + Return the decoding result. `len(ans)` == batch size. + """ + device = model.device + feature = batch["inputs"] + feature = feature.to(device) + # at entry, feature is (N, T, C) + feature_lens = batch["supervisions"]["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + hyp = greedy_search(model=model, encoder_out=encoder_out_i, id2word=id2word) + hyps.append(hyp) + + # hyps = [[word_table[i] for i in ids] for ids in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, +) -> List[Tuple[List[int], List[int]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + Returns: + Return a tuple contains two elements (ref_text, hyp_text): + The first is the reference transcript, and the second is the + predicted result. + """ + results = [] + + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
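
In `decode_dataset` below, the reference "text" for each cut is rebuilt from the action frame stored in `cut.supervisions[0].custom["frames"]`. A small illustration of that mapping (the frame values are taken from the README example; the exact structure of `custom["frames"]` and the boundary-token handling are assumptions here):

```python
# An SLU action frame and the frame-token target sequence derived from it.
frame = {"action": "activate", "object": "lights", "location": "kitchen"}
slots = (frame["action"], frame["object"], frame["location"])

# Multi-word slot values are joined with underscores, matching
# generate_lexicon.py ("change language" -> "change_language").
target = " ".join("_".join(v.split()) for v in slots)
assert target == "activate lights kitchen"
```
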
+ + id2word = get_id2word(params) + + results = [] + for batch_idx, batch in enumerate(dl): + texts = [ + " ".join(a.supervisions[0].custom["frames"]) + for a in batch["supervisions"]["cut"] + ] + texts = [ + " " + a.replace("change language", "change_language") + " " + for a in texts + ] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps = decode_one_batch( + params=params, model=model, batch=batch, id2word=id2word + ) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(batch["supervisions"]["text"]) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + exp_dir: Path, + test_set_name: str, + results: List[Tuple[List[int], List[int]]], +) -> None: + """Save results to `exp_dir`. + Args: + exp_dir: + The output directory. This function create the following files inside + this directory: + + - recogs-{test_set_name}.text + + It contains the reference and hypothesis results, like below:: + + ref=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES'] + hyp=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES'] + ref=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES'] + hyp=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES'] + + - errs-{test_set_name}.txt + + It contains the detailed WER. + test_set_name: + The name of the test set, which will be part of the result filename. + results: + A list of tuples, each of which contains (ref_words, hyp_words). + Returns: + Return None. + """ + recog_path = exp_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = exp_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + write_error_stats(f, f"{test_set_name}", results) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + +def get_transducer_model(params: AttributeDict): + # encoder = Tdnn( + # num_features=params.feature_dim, + # output_dim=params.hidden_dim, + # ) + encoder = Conformer( + num_features=params.feature_dim, + output_dim=params.hidden_dim, + ) + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + blank_id=params.blank_id, + num_layers=params.num_decoder_layers, + hidden_dim=params.hidden_dim, + embedding_dropout=0.4, + rnn_dropout=0.4, + ) + joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size) + transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner) + return transducer + + +@torch.no_grad() +def main(): + parser = get_parser() + SluDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + params["env_info"] = get_env_info() + + setup_logger(f"{params.exp_dir}/log/log-decode") + logging.info("Decoding started") + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = get_transducer_model(params) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to(device) + model.eval() + model.device = device + + # we need cut ids to display recognition results. 
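
A worked example of the checkpoint-averaging window used just above (the values are hypothetical):

```python
# With --epoch 6 --avg 3, start = 6 - 3 + 1 = 4, so epoch-4.pt, epoch-5.pt
# and epoch-6.pt are averaged before decoding.
epoch, avg = 6, 3
start = epoch - avg + 1
filenames = [f"transducer/exp/epoch-{i}.pt" for i in range(start, epoch + 1) if start >= 0]
assert filenames == [
    "transducer/exp/epoch-4.pt",
    "transducer/exp/epoch-5.pt",
    "transducer/exp/epoch-6.pt",
]
```
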
+ args.return_cuts = True + slu = SluDataModule(args) + test_dl = slu.test_dataloaders() + results = decode_dataset( + dl=test_dl, + params=params, + model=model, + ) + + test_set_name = str(args.feature_dir).split("/")[-2] + save_results(exp_dir=params.exp_dir, test_set_name=test_set_name, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/fluent_speech_commands/SLU/transducer/decoder.py b/egs/fluent_speech_commands/SLU/transducer/decoder.py new file mode 120000 index 000000000..e99310f91 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/decoder.py @@ -0,0 +1 @@ +../../../yesno/ASR/transducer/decoder.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/encoder_interface.py b/egs/fluent_speech_commands/SLU/transducer/encoder_interface.py new file mode 120000 index 000000000..653c5b09a --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/joiner.py b/egs/fluent_speech_commands/SLU/transducer/joiner.py new file mode 120000 index 000000000..75fa64868 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer/joiner.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/model.py b/egs/fluent_speech_commands/SLU/transducer/model.py new file mode 120000 index 000000000..10f6ddad1 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer/model.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py b/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py new file mode 100755 index 000000000..fa715abdd --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py @@ -0,0 +1,289 @@ +# Copyright 2021 Piotr Żelasko +# 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import List + +from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy +from lhotse.dataset import ( + CutConcatenate, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from torch.utils.data import DataLoader + +from icefall.dataset.datamodule import DataModule +from icefall.utils import str2bool + + +class SluDataModule(DataModule): + """ + DataModule for k2 ASR experiments. + It assumes there is always one train dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). 
+ + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + """ + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + super().add_arguments(parser) + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--feature-dir", + type=Path, + default=Path("data/fbanks"), + help="Path to directory with train/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=30.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=False, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=10, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + def train_dataloaders(self) -> DataLoader: + logging.info("About to get train cuts") + cuts_train = self.train_cuts() + + logging.info("About to create train dataset") + transforms = [] + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. 
+ transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + FbankConfig(sampling_rate=8000, num_mel_bins=23) + ), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=True, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=True, + ) + + return train_dl + + def valid_dataloaders(self) -> DataLoader: + logging.info("About to get valid cuts") + cuts_valid = self.valid_cuts() + + logging.debug("About to create valid dataset") + valid = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create valid dataloader") + valid_dl = DataLoader( + valid, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + persistent_workers=True, + ) + return valid_dl + + def test_dataloaders(self) -> DataLoader: + logging.info("About to get test cuts") + cuts_test = self.test_cuts() + + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts_test, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + persistent_workers=True, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + cuts_train = load_manifest_lazy( + self.args.feature_dir / "slu_cuts_train.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def valid_cuts(self) -> List[CutSet]: + logging.info("About to get valid cuts") + cuts_valid = load_manifest_lazy( + self.args.feature_dir / "slu_cuts_valid.jsonl.gz" + ) + return cuts_valid + + @lru_cache() + def test_cuts(self) -> List[CutSet]: + 
logging.info("About to get test cuts") + cuts_test = load_manifest_lazy(self.args.feature_dir / "slu_cuts_test.jsonl.gz") + return cuts_test diff --git a/egs/fluent_speech_commands/SLU/transducer/subsampling.py b/egs/fluent_speech_commands/SLU/transducer/subsampling.py new file mode 120000 index 000000000..fd7ca8b30 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/subsampling.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/test_conformer.py b/egs/fluent_speech_commands/SLU/transducer/test_conformer.py new file mode 120000 index 000000000..3060dd70c --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/test_conformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer/test_conformer.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/test_decoder.py b/egs/fluent_speech_commands/SLU/transducer/test_decoder.py new file mode 120000 index 000000000..d1bc718ce --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/test_decoder.py @@ -0,0 +1 @@ +../../../yesno/ASR/transducer/test_decoder.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/test_joiner.py b/egs/fluent_speech_commands/SLU/transducer/test_joiner.py new file mode 120000 index 000000000..248222a8a --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/test_joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer/test_joiner.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/test_transducer.py b/egs/fluent_speech_commands/SLU/transducer/test_transducer.py new file mode 120000 index 000000000..df104bad7 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/test_transducer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer/test_transducer.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/SLU/transducer/train.py b/egs/fluent_speech_commands/SLU/transducer/train.py new file mode 100755 index 000000000..a59c0b754 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/train.py @@ -0,0 +1,625 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
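
Both `train.py` (below) and `decode.py` drive the data pipeline through the `SluDataModule` defined above. A minimal sketch of that flow, assuming it is run from `egs/fluent_speech_commands/SLU` with icefall on `PYTHONPATH` (the option values are hypothetical):

```python
import argparse

from transducer.slu_datamodule import SluDataModule

parser = argparse.ArgumentParser()
SluDataModule.add_arguments(parser)
args = parser.parse_args(["--feature-dir", "data/fbanks", "--max-duration", "100"])

slu = SluDataModule(args)
train_dl = slu.train_dataloaders()  # DynamicBucketingSampler when --bucketing-sampler is set
test_dl = slu.test_dataloaders()    # train.py also uses the test set for validation
```
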
+ +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import List, Optional, Tuple + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +from lhotse.utils import fix_random_seed +from slu_datamodule import SluDataModule +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from transducer.conformer import Conformer + +# from torch.utils.tensorboard import SummaryWriter +from transducer.decoder import Decoder +from transducer.joiner import Joiner +from transducer.model import Transducer + +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + + +def get_word2id(params): + word2id = {} + + # 0 is blank + id = 1 + with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file: + for line in lexicon_file: + if len(line.strip()) > 0: + word2id[line.split()[0]] = id + id += 1 + + return word2id + + +def get_labels(texts: List[str], word2id) -> k2.RaggedTensor: + """ + Args: + texts: + A list of transcripts. + Returns: + Return a ragged tensor containing the corresponding word ID. + """ + # blank is 0 + word_ids = [] + for t in texts: + words = t.split() + ids = [word2id[w] for w in words] + word_ids.append(ids) + + return k2.RaggedTensor(word_ids) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=7, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + tdnn/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer/exp", + help="Directory to save results", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument("--lang-dir", type=str, default="data/lm/frames") + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + is saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - lr: It specifies the initial learning rate + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - weight_decay: The weight_decay for the optimizer. + + - subsampling_factor: The subsampling factor for the model. + + - start_epoch: If it is not zero, load checkpoint `start_epoch-1` + and continue training from that checkpoint. 
+ + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + + """ + params = AttributeDict( + { + "lr": 1e-4, + "feature_dim": 23, + "weight_decay": 1e-6, + "start_epoch": 0, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 100, + "reset_interval": 20, + "valid_interval": 3000, + "exp_dir": Path("transducer/exp"), + "lang_dir": Path("data/lm/frames"), + # encoder/decoder params + "vocab_size": 3, # blank, yes, no + "blank_id": 0, + "embedding_dim": 32, + "hidden_dim": 16, + "num_decoder_layers": 4, + } + ) + + vocab_size = 1 + with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file: + for line in lexicon_file: + if ( + len(line.strip()) > 0 + ): # and '' not in line and '' not in line and '' not in line: + vocab_size += 1 + params.vocab_size = vocab_size + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler._LRScheduler, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, model: nn.Module, batch: dict, is_training: bool, word2ids +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute RNN-T loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Tdnn in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = model.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + feature_lens = batch["supervisions"]["num_frames"].to(device) + + texts = [ + " ".join(a.supervisions[0].custom["frames"]) + for a in batch["supervisions"]["cut"] + ] + texts = [ + " " + a.replace("change language", "change_language") + " " + for a in texts + ] + labels = get_labels(texts, word2ids).to(device) + + with torch.set_grad_enabled(is_training): + loss = model(x=feature, x_lens=feature_lens, y=labels) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = feature.size(0) + info["loss"] = loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + valid_dl: torch.utils.data.DataLoader, + word2ids, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process. The validation loss + is saved in `params.valid_loss`. + """ + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + is_training=False, + word2ids=word2ids, + ) + assert loss.requires_grad is False + + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + word2ids, + tb_writer: None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. 
If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, model=model, batch=batch, is_training=True, word2ids=word2ids + ) + # summary stats. + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + if batch_idx % params.log_interval == 0: + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + word2ids=word2ids, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, + "train/valid_", + params.batch_idx_train, + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def get_transducer_model(params: AttributeDict): + encoder = Conformer( + num_features=params.feature_dim, + output_dim=params.hidden_dim, + ) + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + blank_id=params.blank_id, + num_layers=params.num_decoder_layers, + hidden_dim=params.hidden_dim, + embedding_dropout=0.4, + rnn_dropout=0.4, + ) + joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size) + transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner) + + return transducer + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + + params.update(vars(args)) + params["env_info"] = get_env_info() + + word2ids = get_word2id(params) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + # if args.tensorboard and rank == 0: + # tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + # else: + # tb_writer = None + tb_writer = None + + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + else: + device = torch.device("cpu") + logging.info(f"device: {device}") + + model = get_transducer_model(params) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + model.device = device + + optimizer = optim.Adam( + model.parameters(), + lr=params.lr, + weight_decay=params.weight_decay, + ) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + + slu = SluDataModule(args) + train_dl = slu.train_dataloaders() + + # There are only 60 waves: 30 files are used for training + # and the remaining 30 files are used for testing. + # We use test data as validation. + valid_dl = slu.test_dataloaders() + + for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) + train_dl.sampler.set_epoch(epoch) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + word2ids=word2ids, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + scheduler=None, + rank=rank, + ) + + logging.info("Done!") + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + SluDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + main() diff --git a/egs/fluent_speech_commands/SLU/transducer/transformer.py b/egs/fluent_speech_commands/SLU/transducer/transformer.py new file mode 120000 index 000000000..214afed39 --- /dev/null +++ b/egs/fluent_speech_commands/SLU/transducer/transformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/transformer.py \ No newline at end of file diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py index 231aca7f1..42ed44fdd 100755 --- a/icefall/shared/make_kn_lm.py +++ b/icefall/shared/make_kn_lm.py @@ -33,7 +33,7 @@ parser.add_argument( "-ngram-order", type=int, default=4, - choices=[2, 3, 4, 5, 6, 7], + choices=[1, 2, 3, 4, 5, 6, 7], help="Order of n-gram", ) parser.add_argument("-text", type=str, default=None, help="Path to the corpus file") @@ -105,7 +105,7 @@ class NgramCounts: # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict. 
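
The nested layout described in the comment above can be sketched in plain Python (an illustration only, not the actual `NgramCounts` implementation):

```python
from collections import defaultdict

ngram_order = 4
# counts[n] maps an n-word history to the counts of the words that follow it.
counts = [defaultdict(lambda: defaultdict(float)) for _ in range(ngram_order)]

# "self.counts[3][[5,6,7]][8] += 1.0" from the comment then corresponds to
# using the history as a hashable tuple key:
counts[3][(5, 6, 7)][8] += 1.0
assert counts[3][(5, 6, 7)][8] == 1.0
```
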
def __init__(self, ngram_order, bos_symbol="<s>", eos_symbol="</s>"): - assert ngram_order >= 2 + assert ngram_order >= 1 self.ngram_order = ngram_order self.bos_symbol = bos_symbol @@ -169,7 +169,10 @@ class NgramCounts: with open(filename, encoding=default_encoding) as fp: for line in fp: line = line.strip(strip_chars) - self.add_raw_counts_from_line(line) + if self.ngram_order == 1: + self.add_raw_counts_from_line(line.split()[0]) + else: + self.add_raw_counts_from_line(line) lines_processed += 1 if lines_processed == 0 or args.verbose > 0: print( From b9e6327adfe0def4989ad508d0df8f6347da2c0f Mon Sep 17 00:00:00 2001 From: Teo Wen Shen <36886809+teowenshen@users.noreply.github.com> Date: Sat, 3 Feb 2024 07:25:27 +0900 Subject: [PATCH 41/46] Fixing torch.ctc err (#1485) * fixing torch.ctc err * Move targets & lengths to CPU --- egs/librispeech/ASR/zipformer/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index f2f86af47..73009d35c 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -164,9 +164,9 @@ class AsrModel(nn.Module): ctc_loss = torch.nn.functional.ctc_loss( log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) - targets=targets, - input_lengths=encoder_out_lens, - target_lengths=target_lengths, + targets=targets.cpu(), + input_lengths=encoder_out_lens.cpu(), + target_lengths=target_lengths.cpu(), reduction="sum", ) return ctc_loss From a813186f6463b4d3f6d73460dd1a57b2e975b803 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 5 Feb 2024 12:47:52 +0800 Subject: [PATCH 42/46] minor fix for docstr and default param. (#1490) * Update train.py and README.md --- README.md | 3 +++ egs/aishell/ASR/whisper/train.py | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index cc817702b..770066166 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,9 @@ The [LibriSpeech][librispeech] recipe supports the most comprehensive set of mod - LSTM-based Predictor - [Stateless Predictor](https://research.google/pubs/rnn-transducer-with-stateless-prediction-network/) +#### Whisper + - [OpenAI Whisper](https://arxiv.org/abs/2212.04356) (We support fine-tuning on AiShell-1.) + If you are willing to contribute to icefall, please refer to [contributing](https://icefall.readthedocs.io/en/latest/contributing/index.html) for more details. We would like to highlight the performance of some of the recipes here. diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py index d16793eb2..073b23713 100755 --- a/egs/aishell/ASR/whisper/train.py +++ b/egs/aishell/ASR/whisper/train.py @@ -19,7 +19,7 @@ Usage: #fine-tuning with deepspeed zero stage 1 -torchrun --nproc-per-node 8 ./whisper/train.py \ +torchrun --nproc_per_node 8 ./whisper/train.py \ --max-duration 200 \ --exp-dir whisper/exp_large_v2 \ --model-name large-v2 \ --deepspeed \ --deepspeed_config ./whisper/ds_config_zero1.json # fine-tuning with ddp -torchrun --nproc-per-node 8 ./whisper/train.py \ +torchrun --nproc_per_node 8 ./whisper/train.py \ --max-duration 200 \ --exp-dir whisper/exp_medium \ --manifest-dir data/fbank_whisper \ @@ -136,7 +136,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless7/exp", + default="whisper/exp", help="""The experiment dir.
It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 777074046d6cb0f9b7ed7a98115299c04b8bcdf1 Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Tue, 6 Feb 2024 18:25:43 +0800 Subject: [PATCH 43/46] Fine-tune recipe for Zipformer (#1484) 1. support finetune zipformer 2. update the usage; set a very large batch count --- .../local/compute_fbank_gigaspeech_splits.py | 18 +- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 17 +- .../ASR/zipformer/decode_gigaspeech.py | 1114 ++++++++++++ egs/librispeech/ASR/zipformer/finetune.py | 1521 +++++++++++++++++ 4 files changed, 2664 insertions(+), 6 deletions(-) create mode 100755 egs/librispeech/ASR/zipformer/decode_gigaspeech.py create mode 100755 egs/librispeech/ASR/zipformer/finetune.py diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py index 1c71be0f9..176eb8a84 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -51,6 +51,14 @@ def get_parser(): "Determines batch size dynamically.", ) + parser.add_argument( + "--subset", + type=str, + default="XL", + choices=["XL", "L", "M", "S", "XS"], + help="Which subset to work with", + ) + parser.add_argument( "--num-splits", type=int, @@ -76,7 +84,7 @@ def get_parser(): def compute_fbank_gigaspeech_splits(args): num_splits = args.num_splits - output_dir = "data/fbank/XL_split" + output_dir = f"data/fbank/{args.subset}_split" output_dir = Path(output_dir) assert output_dir.exists(), f"{output_dir} does not exist!" @@ -96,15 +104,15 @@ def compute_fbank_gigaspeech_splits(args): logging.info(f"device: {device}") for i in range(start, stop): - idx = f"{i + 1}".zfill(num_digits) + idx = f"{i}".zfill(num_digits) logging.info(f"Processing {idx}/{num_splits}") - cuts_path = output_dir / f"cuts_XL.{idx}.jsonl.gz" + cuts_path = output_dir / f"cuts_{args.subset}.{idx}.jsonl.gz" if cuts_path.is_file(): logging.info(f"{cuts_path} exists - skipping") continue - raw_cuts_path = output_dir / f"cuts_XL_raw.{idx}.jsonl.gz" + raw_cuts_path = output_dir / f"cuts_{args.subset}_raw.{idx}.jsonl.gz" logging.info(f"Loading {raw_cuts_path}") cut_set = CutSet.from_file(raw_cuts_path) @@ -113,7 +121,7 @@ def compute_fbank_gigaspeech_splits(args): cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, - storage_path=f"{output_dir}/feats_XL_{idx}", + storage_path=f"{output_dir}/feats_{args.subset}_{idx}", num_workers=args.num_workers, batch_duration=args.batch_duration, overwrite=True, diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index dd9e9ef1f..814390ad6 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -1,4 +1,4 @@ -# Copyright 2021 Piotr Żelasko +# Copyright 2021 Piotr Żelasko # Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors @@ -475,3 +475,18 @@ class LibriSpeechAsrDataModule: return load_manifest_lazy( self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" ) + + @lru_cache() + def gigaspeech_subset_small_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech subset-S cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz") + + @lru_cache() + def gigaspeech_dev_cuts(self) -> CutSet: + 
logging.info("About to get Gigaspeech dev cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + + @lru_cache() + def gigaspeech_test_cuts(self) -> CutSet: + logging.info("About to get Gigaspeech test cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") diff --git a/egs/librispeech/ASR/zipformer/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer/decode_gigaspeech.py new file mode 100755 index 000000000..3cda337c0 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/decode_gigaspeech.py @@ -0,0 +1,1114 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + 
modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from train import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + +conversational_filler = [ + "UH", + "UHH", + "UM", + "EH", + "MM", + "HM", + "AH", + "HUH", + "HA", + "ER", + "OOF", + "HEE", + "ACH", + "EEE", + "EW", +] +unk_tags = ["<UNK>", "<unk>"] +gigaspeech_punctuations = [ + "<COMMA>", + "<PERIOD>", + "<QUESTIONMARK>", + "<EXCLAMATIONPOINT>", +] +gigaspeech_garbage_utterance_tags = ["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"] +non_scoring_words = ( + conversational_filler + + unk_tags + + gigaspeech_punctuations + + gigaspeech_garbage_utterance_tags +) + + +def asr_text_post_processing(text: str) -> str: + # 1. convert to uppercase + text = text.upper() + + # 2. remove hyphen + # "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART" + text = text.replace("-", " ") + + # 3. remove non-scoring words from evaluation + remaining_words = [] + for word in text.split(): + if word in non_scoring_words: + continue + remaining_words.append(word) + + return " ".join(remaining_words) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`.
+ """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. 
+ """, + ) + add_model_arguments(parser) + + return parser + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
+ pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == 
"modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = post_processing(results) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." 
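+ # During training these options may be comma-separated lists to sample from; for decoding a single value must be chosen.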
+ assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # <blk> and <unk> are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("<blk>") + params.unk_id = sp.piece_to_id("<unk>") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start}
(excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
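+ # (with return_cuts=True the test dataloader attaches the original cuts to each batch under batch["supervisions"]["cut"], which decode_dataset() uses to recover the cut ids)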
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts() + gigaspeech_test_cuts = librispeech.gigaspeech_test_cuts() + + dev_dl = librispeech.test_dataloaders(gigaspeech_dev_cuts) + test_dl = librispeech.test_dataloaders(gigaspeech_test_cuts) + + test_sets = ["dev", "test"] + test_dl = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py new file mode 100755 index 000000000..843d103cc --- /dev/null +++ b/egs/librispeech/ASR/zipformer/finetune.py @@ -0,0 +1,1521 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# Fine-tune without mux (i.e not mixing with original training data): +./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --do-finetune 1 \ + --finetune-ckpt path/to/ckpt \ + --base-lr 0.0045 \ + --use-mux 0 \ + --exp-dir zipformer/exp_finetune \ + --max-duration 1000 + +# Fine-tune without mux (i.e mixing with original training data): +./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --do-finetune 1 \ + --finetune-ckpt path/to/ckpt \ + --base-lr 0.0045 \ + --use-mux 1 \ + --exp-dir zipformer/exp_finetune \ + --max-duration 1000 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut, CutSet +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + # Note that we add a very large constant here to make the ScheduledFloat + # variable as their end value. + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + 100000 + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--do-finetune", + type=str2bool, + default=True, + help="If true, finetune from a pre-trained checkpoint", + ) + parser.add_argument( + "--use-mux", + type=str2bool, + default=False, + help=""" + Whether to adapt. If true, we will mix 5% of the new data + with 95% of the original data to fine-tune. 
This is useful + if you want to maintain the performance on the original domain + """, + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given comma-separated. If None, + all modules will be initialized. For example, if you only want to + initialize all parameters starting with "encoder", use "encoder"; + if you want to initialize parameters starting with encoder or joiner, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (path to a .pt file)", + ) + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training.
" + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", + type=float, + default=0.0045, + help="""The base learning rate. + It is set to a very small value as we are doing fine-tuning""", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000.0, + help="""Number of steps that affects how rapidly the learning rate + decreases. It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100.0, + help="""Number of epochs that affects how rapidly the learning rate decreases. + It is set to a very large value here to prevent the lr from decaying too fast + during fine-tuning. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. 
It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. 
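+ # For example, with feature_dim=80 and an input of T=1000 feature frames, encoder_embed outputs (T - 7) // 2 = 496 frames, each with the first entry of encoder_dim (192 by default) channels.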
+ encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. 
+ """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + init_modules (list[str]): List of modules to be initialized + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
+ info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dls: torch.utils.data.DataLoader, + valid_sets: List[str], + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + for valid_set, valid_dl in zip(valid_sets, valid_dls): + logging.info(f"Computing validation loss on {valid_set}") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info( + f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}" + ) + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by 
`mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + # load model parameters for model fine-tuning + if params.do_finetune: + assert params.start_epoch == 1, "Fine-tune must start from epoch 1" + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules + ) + # Need to update the model_avg if use initialisation + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + else: + # resuming training + assert params.start_epoch > 1, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + gigaspeech_cuts = librispeech.gigaspeech_subset_small_cuts() + if params.use_mux: + librispeech_cuts = librispeech.train_all_shuf_cuts() + train_cuts = CutSet.mux( + gigaspeech_cuts, # num cuts = 688182 + librispeech_cuts, # num cuts = 843723 + weights=[688182, 843723], + stop_early=True, + ) + else: + train_cuts = gigaspeech_cuts + logging.info(train_cuts) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 
second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts() + + valid_sets = ["librispeech", "gigaspeech"] + valid_dls = [ + librispeech.valid_dataloaders(valid_cuts), + librispeech.valid_dataloaders(gigaspeech_dev_cuts), + ] + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dls=valid_dls, + valid_sets=valid_sets, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. 
See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() From 4ed88d948494d060a8c2827584d71bff616c3ca9 Mon Sep 17 00:00:00 2001 From: Tiance Wang Date: Wed, 7 Feb 2024 10:16:02 +0800 Subject: [PATCH 44/46] Update shared (#1487) There should be one more ../ --- egs/fluent_speech_commands/SLU/shared | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/fluent_speech_commands/SLU/shared b/egs/fluent_speech_commands/SLU/shared index 32a374a7f..9115c7e17 120000 --- a/egs/fluent_speech_commands/SLU/shared +++ b/egs/fluent_speech_commands/SLU/shared @@ -1 +1 @@ -../../icefall/shared/ \ No newline at end of file +../../../icefall/shared/ From 711d6bc46297ba799174bae4f3ebb965224ad76f Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Fri, 9 Feb 2024 10:44:19 +0800 Subject: [PATCH 45/46] Refactor prepare.sh in librispeech (#1493) * Refactor prepare.sh in librispeech, break it into three parts, prepare.sh (basic, minimal requirement for transducer), prepare_lm.sh (ngram & nnlm staff), prepare_mmi.sh (for MMI training). 
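A sketch of the intended workflow (illustrative only; the default stage
ranges are documented inside each script below):

```bash
# Minimal data preparation for transducer training with BPE units.
./prepare.sh

# Optional: n-gram / NNLM files used by decoding methods that need an
# external LM (it first re-runs the prepare.sh stages it depends on).
./prepare_lm.sh

# Optional: bigram token-level P for MMI training.
./prepare_mmi.sh
```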
--- egs/librispeech/ASR/RESULTS.md | 2 +- egs/librispeech/ASR/generate-lm.sh | 20 -- egs/librispeech/ASR/local/train_bpe_model.py | 15 + egs/librispeech/ASR/prepare.sh | 334 ++++--------------- egs/librispeech/ASR/prepare_lm.sh | 262 +++++++++++++++ egs/librispeech/ASR/prepare_mmi.sh | 45 +++ 6 files changed, 393 insertions(+), 285 deletions(-) delete mode 100755 egs/librispeech/ASR/generate-lm.sh create mode 100755 egs/librispeech/ASR/prepare_lm.sh create mode 100755 egs/librispeech/ASR/prepare_mmi.sh diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index ebf5e89c4..ee5422aba 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1526,7 +1526,7 @@ done You may also decode using LODR + LM shallow fusion. This decoding method is proposed in . It subtracts the internal language model score during shallow fusion, which is approximated by a bi-gram model. The bi-gram can be -generated by `generate-lm.sh`, or you may download it from . +generated by `prepare_lm.sh` at stage 4, or you may download it from . The decoding command is as follows: diff --git a/egs/librispeech/ASR/generate-lm.sh b/egs/librispeech/ASR/generate-lm.sh deleted file mode 100755 index dacd276d1..000000000 --- a/egs/librispeech/ASR/generate-lm.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/usr/bin/env bash - -lang_dir=data/lang_bpe_500 - -for ngram in 2 3 4 5; do - if [ ! -f $lang_dir/${ngram}gram.arpa ]; then - ./shared/make_kn_lm.py \ - -ngram-order ${ngram} \ - -text $lang_dir/transcript_tokens.txt \ - -lm $lang_dir/${ngram}gram.arpa - fi - - if [ ! -f $lang_dir/${ngram}gram.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="$lang_dir/tokens.txt" \ - --disambig-symbol='#0' \ - --max-order=${ngram} \ - $lang_dir/${ngram}gram.arpa > $lang_dir/${ngram}gram.fst.txt - fi -done diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py index 43142aee4..5979d5b98 100755 --- a/egs/librispeech/ASR/local/train_bpe_model.py +++ b/egs/librispeech/ASR/local/train_bpe_model.py @@ -28,6 +28,7 @@ import argparse import shutil from pathlib import Path +from typing import Dict import sentencepiece as spm @@ -57,6 +58,18 @@ def get_args(): return parser.parse_args() +def generate_tokens(lang_dir: Path): + """ + Generate the tokens.txt from a bpe model. + """ + sp = spm.SentencePieceProcessor() + sp.load(str(lang_dir / "bpe.model")) + token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} + with open(lang_dir / "tokens.txt", "w", encoding="utf-8") as f: + for sym, i in token2id.items(): + f.write(f"{sym} {i}\n") + + def main(): args = get_args() vocab_size = args.vocab_size @@ -95,6 +108,8 @@ def main(): shutil.copyfile(model_file, f"{lang_dir}/bpe.model") + generate_tokens(lang_dir) + if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 4a5072cc0..40dc3260d 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -6,8 +6,21 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail nj=15 -stage=-1 -stop_stage=100 +# run step 0 to step 5 by default +stage=0 +stop_stage=5 + +# Note: This script just prepare the minimal requirements that needed by a +# transducer training with bpe units. +# +# If you want to use ngram or nnlm, please continue running prepare_lm.sh after +# you succeed running this script. 
+# +# This script also contains the steps to generate phone based units, but they +# will not run automatically, you can generate the phone based units by +# bash prepare.sh --stage -1 --stop-stage -1 +# bash prepare.sh --stage 6 --stop-stage 6 + # We assume dl_dir (download dir) contains the following # directories and files. If not, they will be downloaded @@ -17,6 +30,18 @@ stop_stage=100 # You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it. # You can download them from https://www.openslr.org/12 # +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech +# +# lm directory is not necessary for transducer training with bpe units, but it +# is needed by phone based modeling, you can download it by running +# bash prepare.sh --stage -1 --stop-stage -1 +# then you can see the following files in the directory. # - $dl_dir/lm # This directory contains the following files downloaded from # http://www.openslr.org/resources/11 @@ -28,14 +53,7 @@ stop_stage=100 # - librispeech-vocab.txt # - librispeech-lexicon.txt # - librispeech-lm-norm.txt.gz -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech + dl_dir=$PWD/download . shared/parse_options.sh || exit 1 @@ -60,6 +78,8 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +log "Running prepare.sh" + log "dl_dir: $dl_dir" if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then @@ -159,13 +179,49 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Prepare phone based lang" + log "Stage 5: Prepare BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + + if [ ! -f $lang_dir/transcript_words.txt ]; then + log "Generate data for BPE training" + files=$( + find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt" + find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt" + find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt" + ) + for f in ${files[@]}; do + cat $f | cut -d " " -f 2- + done > $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt + fi + done +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Prepare phone based lang" lang_dir=data/lang_phone mkdir -p $lang_dir - (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | - cat - $dl_dir/lm/librispeech-lexicon.txt | - sort | uniq > $lang_dir/lexicon.txt + if [ ! -f $dl_dir/lm/librispeech-lexicon.txt ]; then + log "No lexicon file in $dl_dir/lm, please run :" + log "prepare.sh --stage -1 --stop-stage -1" + exit -1 + fi + + if [ ! -f $lang_dir/lexicon.txt ]; then + (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | + cat - $dl_dir/lm/librispeech-lexicon.txt | + sort | uniq > $lang_dir/lexicon.txt + fi if [ ! 
-f $lang_dir/L_disambig.pt ]; then ./local/prepare_lang.py --lang-dir $lang_dir @@ -187,253 +243,3 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then $lang_dir/L_disambig.fst fi fi - - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Prepare BPE based lang" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - mkdir -p $lang_dir - # We reuse words.txt from phone based lexicon - # so that the two can share G.pt later. - cp data/lang_phone/words.txt $lang_dir - - if [ ! -f $lang_dir/transcript_words.txt ]; then - log "Generate data for BPE training" - files=$( - find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt" - find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt" - find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt" - ) - for f in ${files[@]}; do - cat $f | cut -d " " -f 2- - done > $lang_dir/transcript_words.txt - fi - - if [ ! -f $lang_dir/bpe.model ]; then - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $lang_dir/transcript_words.txt - fi - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py --lang-dir $lang_dir - - log "Validating $lang_dir/lexicon.txt" - ./local/validate_bpe_lexicon.py \ - --lexicon $lang_dir/lexicon.txt \ - --bpe-model $lang_dir/bpe.model - fi - - if [ ! -f $lang_dir/L.fst ]; then - log "Converting L.pt to L.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L.pt \ - $lang_dir/L.fst - fi - - if [ ! -f $lang_dir/L_disambig.fst ]; then - log "Converting L_disambig.pt to L_disambig.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L_disambig.pt \ - $lang_dir/L_disambig.fst - fi - done -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Prepare bigram token-level P for MMI training" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - - if [ ! -f $lang_dir/transcript_tokens.txt ]; then - ./local/convert_transcript_words_to_tokens.py \ - --lexicon $lang_dir/lexicon.txt \ - --transcript $lang_dir/transcript_words.txt \ - --oov "" \ - > $lang_dir/transcript_tokens.txt - fi - - if [ ! -f $lang_dir/P.arpa ]; then - ./shared/make_kn_lm.py \ - -ngram-order 2 \ - -text $lang_dir/transcript_tokens.txt \ - -lm $lang_dir/P.arpa - fi - - if [ ! -f $lang_dir/P.fst.txt ]; then - python3 -m kaldilm \ - --read-symbol-table="$lang_dir/tokens.txt" \ - --disambig-symbol='#0' \ - --max-order=2 \ - $lang_dir/P.arpa > $lang_dir/P.fst.txt - fi - done -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Prepare G" - # We assume you have installed kaldilm, if not, please install - # it using: pip install kaldilm - - mkdir -p data/lm - if [ ! -f data/lm/G_3_gram.fst.txt ]; then - # It is used in building HLG - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=3 \ - $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt - fi - - if [ ! -f data/lm/G_4_gram.fst.txt ]; then - # It is used for LM rescoring - python3 -m kaldilm \ - --read-symbol-table="data/lang_phone/words.txt" \ - --disambig-symbol='#0' \ - --max-order=4 \ - $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt - fi - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - - if [ ! 
-f $lang_dir/HL.fst ]; then - ./local/prepare_lang_fst.py \ - --lang-dir $lang_dir \ - --ngram-G ./data/lm/G_3_gram.fst.txt - fi - done -fi - -if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Compile HLG" - ./local/compile_hlg.py --lang-dir data/lang_phone - - # Note If ./local/compile_hlg.py throws OOM, - # please switch to the following command - # - # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - ./local/compile_hlg.py --lang-dir $lang_dir - - # Note If ./local/compile_hlg.py throws OOM, - # please switch to the following command - # - # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir - done -fi - -# Compile LG for RNN-T fast_beam_search decoding -if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then - log "Stage 10: Compile LG" - ./local/compile_lg.py --lang-dir data/lang_phone - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - ./local/compile_lg.py --lang-dir $lang_dir - done -fi - -if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then - log "Stage 11: Generate LM training data" - - for vocab_size in ${vocab_sizes[@]}; do - log "Processing vocab_size == ${vocab_size}" - lang_dir=data/lang_bpe_${vocab_size} - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - - ./local/prepare_lm_training_data.py \ - --bpe-model $lang_dir/bpe.model \ - --lm-data $dl_dir/lm/librispeech-lm-norm.txt \ - --lm-archive $out_dir/lm_data.pt - done -fi - -if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then - log "Stage 12: Generate LM validation data" - - for vocab_size in ${vocab_sizes[@]}; do - log "Processing vocab_size == ${vocab_size}" - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - - if [ ! -f $out_dir/valid.txt ]; then - files=$( - find "$dl_dir/LibriSpeech/dev-clean" -name "*.trans.txt" - find "$dl_dir/LibriSpeech/dev-other" -name "*.trans.txt" - ) - for f in ${files[@]}; do - cat $f | cut -d " " -f 2- - done > $out_dir/valid.txt - fi - - lang_dir=data/lang_bpe_${vocab_size} - ./local/prepare_lm_training_data.py \ - --bpe-model $lang_dir/bpe.model \ - --lm-data $out_dir/valid.txt \ - --lm-archive $out_dir/lm_data-valid.pt - done -fi - -if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then - log "Stage 13: Generate LM test data" - - for vocab_size in ${vocab_sizes[@]}; do - log "Processing vocab_size == ${vocab_size}" - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - - if [ ! -f $out_dir/test.txt ]; then - files=$( - find "$dl_dir/LibriSpeech/test-clean" -name "*.trans.txt" - find "$dl_dir/LibriSpeech/test-other" -name "*.trans.txt" - ) - for f in ${files[@]}; do - cat $f | cut -d " " -f 2- - done > $out_dir/test.txt - fi - - lang_dir=data/lang_bpe_${vocab_size} - ./local/prepare_lm_training_data.py \ - --bpe-model $lang_dir/bpe.model \ - --lm-data $out_dir/test.txt \ - --lm-archive $out_dir/lm_data-test.pt - done -fi - -if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then - log "Stage 14: Sort LM training data" - # Sort LM training data by sentence length in descending order - # for ease of training. - # - # Sentence length equals to the number of BPE tokens - # in a sentence. 
- - for vocab_size in ${vocab_sizes[@]}; do - out_dir=data/lm_training_bpe_${vocab_size} - mkdir -p $out_dir - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data.pt \ - --out-lm-data $out_dir/sorted_lm_data.pt \ - --out-statistics $out_dir/statistics.txt - - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data-valid.pt \ - --out-lm-data $out_dir/sorted_lm_data-valid.pt \ - --out-statistics $out_dir/statistics-valid.txt - - ./local/sort_lm_training_data.py \ - --in-lm-data $out_dir/lm_data-test.pt \ - --out-lm-data $out_dir/sorted_lm_data-test.pt \ - --out-statistics $out_dir/statistics-test.txt - done -fi diff --git a/egs/librispeech/ASR/prepare_lm.sh b/egs/librispeech/ASR/prepare_lm.sh new file mode 100755 index 000000000..a8eb5ca78 --- /dev/null +++ b/egs/librispeech/ASR/prepare_lm.sh @@ -0,0 +1,262 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +# This script generate Ngram LM / NNLM and related files that needed by decoding. + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/lm +# This directory contains the following files downloaded from +# http://www.openslr.org/resources/11 +# +# - 3-gram.pruned.1e-7.arpa.gz +# - 3-gram.pruned.1e-7.arpa +# - 4-gram.arpa.gz +# - 4-gram.arpa +# - librispeech-vocab.txt +# - librispeech-lexicon.txt +# - librispeech-lm-norm.txt.gz +# + +. prepare.sh --stage -1 --stop-stage 6 || exit 1 + +log "Running prepare_lm.sh" + +stage=0 +stop_stage=100 + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Prepare BPE based lexicon." + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/words.txt $lang_dir + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bpe.model + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi + done +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare word level G" + # We assume you have installed kaldilm, if not, please install + # it using: pip install kaldilm + + mkdir -p data/lm + if [ ! -f data/lm/G_3_gram.fst.txt ]; then + # It is used in building HLG + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt + fi + + if [ ! -f data/lm/G_4_gram.fst.txt ]; then + # It is used for LM rescoring + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=4 \ + $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt + fi + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! 
-f $lang_dir/HL.fst ]; then + ./local/prepare_lang_fst.py \ + --lang-dir $lang_dir \ + --ngram-G ./data/lm/G_3_gram.fst.txt + fi + done +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compile HLG" + ./local/compile_hlg.py --lang-dir data/lang_phone + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir + done +fi + +# Compile LG for RNN-T fast_beam_search decoding +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compile LG" + ./local/compile_lg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir + done +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Prepare token level ngram G" + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! -f $lang_dir/transcript_tokens.txt ]; then + ./local/convert_transcript_words_to_tokens.py \ + --lexicon $lang_dir/lexicon.txt \ + --transcript $lang_dir/transcript_words.txt \ + --oov "" \ + > $lang_dir/transcript_tokens.txt + fi + + for ngram in 2 3 4 5; do + if [ ! -f $lang_dir/${ngram}gram.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order ${ngram} \ + -text $lang_dir/transcript_tokens.txt \ + -lm $lang_dir/${ngram}gram.arpa + fi + + if [ ! -f $lang_dir/${ngram}gram.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/tokens.txt" \ + --disambig-symbol='#0' \ + --max-order=${ngram} \ + $lang_dir/${ngram}gram.arpa > $lang_dir/${ngram}gram.fst.txt + fi + done + done +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Generate NNLM training data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + lang_dir=data/lang_bpe_${vocab_size} + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $dl_dir/lm/librispeech-lm-norm.txt \ + --lm-archive $out_dir/lm_data.pt + done +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Generate NNLM validation data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/valid.txt ]; then + files=$( + find "$dl_dir/LibriSpeech/dev-clean" -name "*.trans.txt" + find "$dl_dir/LibriSpeech/dev-other" -name "*.trans.txt" + ) + for f in ${files[@]}; do + cat $f | cut -d " " -f 2- + done > $out_dir/valid.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/valid.txt \ + --lm-archive $out_dir/lm_data-valid.pt + done +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Generate NNLM test data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! 
-f $out_dir/test.txt ]; then + files=$( + find "$dl_dir/LibriSpeech/test-clean" -name "*.trans.txt" + find "$dl_dir/LibriSpeech/test-other" -name "*.trans.txt" + ) + for f in ${files[@]}; do + cat $f | cut -d " " -f 2- + done > $out_dir/test.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/test.txt \ + --lm-archive $out_dir/lm_data-test.pt + done +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Sort NNLM training data" + # Sort LM training data by sentence length in descending order + # for ease of training. + # + # Sentence length equals to the number of BPE tokens + # in a sentence. + + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-valid.pt \ + --out-lm-data $out_dir/sorted_lm_data-valid.pt \ + --out-statistics $out_dir/statistics-valid.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-test.pt \ + --out-lm-data $out_dir/sorted_lm_data-test.pt \ + --out-statistics $out_dir/statistics-test.txt + done +fi diff --git a/egs/librispeech/ASR/prepare_mmi.sh b/egs/librispeech/ASR/prepare_mmi.sh new file mode 100755 index 000000000..d8a6e0caf --- /dev/null +++ b/egs/librispeech/ASR/prepare_mmi.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + + +. prepare.sh --stage -1 --stop-stage 6 || exit 1 + +log "Running prepare_mmi.sh" + +stage=0 +stop_stage=100 + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Prepare bigram token-level P for MMI training" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! -f $lang_dir/transcript_tokens.txt ]; then + ./local/convert_transcript_words_to_tokens.py \ + --lexicon $lang_dir/lexicon.txt \ + --transcript $lang_dir/transcript_words.txt \ + --oov "" \ + > $lang_dir/transcript_tokens.txt + fi + + if [ ! -f $lang_dir/P.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order 2 \ + -text $lang_dir/transcript_tokens.txt \ + -lm $lang_dir/P.arpa + fi + + if [ ! -f $lang_dir/P.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/tokens.txt" \ + --disambig-symbol='#0' \ + --max-order=2 \ + $lang_dir/P.arpa > $lang_dir/P.fst.txt + fi + done +fi From d9ae8c02a0abdeddc5a4cf9fad72293eda134de3 Mon Sep 17 00:00:00 2001 From: safarisadegh <37085305+safarisadegh@users.noreply.github.com> Date: Fri, 9 Feb 2024 10:35:01 +0330 Subject: [PATCH 46/46] Update README.md (#1497) --- icefall/ctc/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/ctc/README.md b/icefall/ctc/README.md index 0096bc096..1e342f6a3 100644 --- a/icefall/ctc/README.md +++ b/icefall/ctc/README.md @@ -12,6 +12,6 @@ pip install kaldifst kaldi-decoder ``` to install the dependencies. -[kaldi-decoder]: https://github.com/i2-fsa/kaldi-decoder +[kaldi-decoder]: https://github.com/k2-fsa/kaldi-decoder [kaldifst]: https://github.com/k2-fsa/kaldifst [k2]: https://github.com/k2-fsa/k2