diff --git a/egs/mls_english/ASR/README.md b/egs/mls_english/ASR/README.md
new file mode 100644
index 000000000..cb8f51f46
--- /dev/null
+++ b/egs/mls_english/ASR/README.md
@@ -0,0 +1,19 @@
+# Introduction
+
+
+
+**Multilingual LibriSpeech (MLS)** is a large multilingual corpus suitable for speech research. The dataset is derived from read audiobooks from LibriVox and covers 8 languages: English, German, Dutch, Spanish, French, Italian, Portuguese, and Polish. It includes about 44.5K hours of English and a total of about 6K hours across the other languages. This icefall training recipe targets the restructured version of the English split of the dataset available on Hugging Face (linked below).
+
+
+The dataset is available on Hugging Face. For more details, please visit:
+
+- Dataset: https://huggingface.co/datasets/parler-tts/mls_eng
+- Original MLS dataset link: https://www.openslr.org/94
+
+
+## On-the-fly feature computation
+
+This recipe currently only supports on-the-fly feature bank computation, since `lhotse` manifests and feature banks are not pre-computed in this recipe. This should mean that the dataset can be streamed from Hugging Face, but we have not tested this yet. We may add a version that supports pre-computed features to better match existing recipes.
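+
+For reference, the preparation script wraps the Hugging Face dataset into `lhotse` cuts roughly as follows (a minimal sketch based on `local/compute_fbank_mls_english.py`; the `parler-tts/mls_eng` identifier and the `transcript` text key are taken from that script):
+
+```python
+from datasets import load_dataset
+from lhotse import CutSet
+
+# Load the restructured MLS English corpus from Hugging Face (or from a local snapshot path).
+dataset = load_dataset("parler-tts/mls_eng")
+
+# Wrap a split as a lhotse CutSet; transcripts are stored in the "transcript" column.
+train_cuts = CutSet.from_huggingface_dataset(dataset["train"], text_key="transcript")
+```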
+
+
+[./RESULTS.md](./RESULTS.md) contains the latest results. This MLS English recipe was primarily developed for use in the `multi_ja_en` Japanese-English bilingual pipeline, which is based on MLS English and ReazonSpeech.
diff --git a/egs/mls_english/ASR/RESULTS.md b/egs/mls_english/ASR/RESULTS.md
new file mode 100644
index 000000000..5c29fb631
--- /dev/null
+++ b/egs/mls_english/ASR/RESULTS.md
@@ -0,0 +1,41 @@
+## Results
+
+### MLS English training results with the zipformer model
+
+#### Non-streaming
+
+**WER (%) on Test Set (Epoch 20)**
+
+| Type | Greedy | Beam search |
+|---------------|--------|-------------|
+| Non-streaming | 6.65 | 6.57 |
+
+
+The training command:
+
+```
+./zipformer/train.py \
+  --world-size 8 \
+  --num-epochs 20 \
+  --start-epoch 9 \
+  --use-fp16 1 \
+  --exp-dir zipformer/exp \
+  --lang-dir data/lang/bpe_2000/
+```
+
+The decoding command:
+
+```
+./zipformer/decode.py \
+ --epoch 20 \
+ --exp-dir ./zipformer/exp \
+ --lang-dir data/lang/bpe_2000/ \
+ --decoding-method greedy_search
+```
+
+
+The pre-trained model is available here: [reazon-research/mls-english](https://huggingface.co/reazon-research/mls-english)
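+
+A minimal sketch for fetching the checkpoint locally, assuming the `huggingface_hub` library (the same library the recipe's download script uses); the exact file layout of the model repository is not described here:
+
+```python
+from huggingface_hub import snapshot_download
+
+# Download the released checkpoint repository to a local directory.
+local_dir = snapshot_download(repo_id="reazon-research/mls-english", repo_type="model")
+print(local_dir)
+```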
+
+
+Please note that this recipe was developed primarily as the source of English input in the bilingual Japanese-English recipe `multi_ja_en`, which uses ReazonSpeech and MLS English.
diff --git a/egs/multi_ja_en/ASR/local/compute_fbank_reazonspeech.py b/egs/mls_english/ASR/local/compute_fbank_mls_english.py
similarity index 60%
rename from egs/multi_ja_en/ASR/local/compute_fbank_reazonspeech.py
rename to egs/mls_english/ASR/local/compute_fbank_mls_english.py
index af7841406..25ef6c74b 100644
--- a/egs/multi_ja_en/ASR/local/compute_fbank_reazonspeech.py
+++ b/egs/mls_english/ASR/local/compute_fbank_mls_english.py
@@ -33,6 +33,7 @@ from lhotse import ( # See the following for why LilcomChunkyWriter is preferre
RecordingSet,
SupervisionSet,
)
+from lhotse.utils import is_module_available
# fmt: on
@@ -48,55 +49,54 @@ concat_params = {"gap": 1.0, "maxlen": 10.0}
def make_cutset_blueprints(
- manifest_dir: Path,
+ mls_eng_hf_dataset_path: str = "parler-tts/mls_eng",
) -> List[Tuple[str, CutSet]]:
cut_sets = []
+ if not is_module_available("datasets"):
+ raise ImportError(
+ "To process the MLS English HF corpus, please install optional dependency: pip install datasets"
+ )
+
+ from datasets import load_dataset
+
+ print(f"{mls_eng_hf_dataset_path=}")
+ dataset = load_dataset(str(mls_eng_hf_dataset_path))
+
# Create test dataset
logging.info("Creating test cuts.")
cut_sets.append(
(
"test",
- CutSet.from_manifests(
- recordings=RecordingSet.from_file(
- manifest_dir / "reazonspeech_recordings_test.jsonl.gz"
- ),
- supervisions=SupervisionSet.from_file(
- manifest_dir / "reazonspeech_supervisions_test.jsonl.gz"
- ),
- ),
+ CutSet.from_huggingface_dataset(dataset["test"], text_key="transcript"),
)
)
# Create dev dataset
logging.info("Creating dev cuts.")
- cut_sets.append(
- (
- "dev",
- CutSet.from_manifests(
- recordings=RecordingSet.from_file(
- manifest_dir / "reazonspeech_recordings_dev.jsonl.gz"
- ),
- supervisions=SupervisionSet.from_file(
- manifest_dir / "reazonspeech_supervisions_dev.jsonl.gz"
- ),
- ),
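+    # Some exports of this dataset name the held-out split "validation" instead of
+    # "dev", so fall back to it when a "dev" split is not present.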
+ try:
+ cut_sets.append(
+ (
+ "dev",
+ CutSet.from_huggingface_dataset(dataset["dev"], text_key="transcript"),
+ )
+ )
+ except KeyError:
+ cut_sets.append(
+ (
+ "dev",
+ CutSet.from_huggingface_dataset(
+ dataset["validation"], text_key="transcript"
+ ),
+ )
)
- )
# Create train dataset
logging.info("Creating train cuts.")
cut_sets.append(
(
"train",
- CutSet.from_manifests(
- recordings=RecordingSet.from_file(
- manifest_dir / "reazonspeech_recordings_train.jsonl.gz"
- ),
- supervisions=SupervisionSet.from_file(
- manifest_dir / "reazonspeech_supervisions_train.jsonl.gz"
- ),
- ),
+ CutSet.from_huggingface_dataset(dataset["train"], text_key="transcript"),
)
)
return cut_sets
@@ -107,6 +107,8 @@ def get_args():
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("-m", "--manifest-dir", type=Path)
+ parser.add_argument("-a", "--audio-dir", type=Path)
+ parser.add_argument("-d", "--dl-dir", type=Path)
return parser.parse_args()
@@ -120,26 +122,33 @@ def main():
logging.basicConfig(format=formatter, level=logging.INFO)
- if (args.manifest_dir / ".reazonspeech-fbank.done").exists():
+ if (args.manifest_dir / ".mls-eng-fbank.done").exists():
logging.info(
- "Previous fbank computed for ReazonSpeech found. "
- f"Delete {args.manifest_dir / '.reazonspeech-fbank.done'} to allow recomputing fbank."
+ "Previous fbank computed for MLS English found. "
+ f"Delete {args.manifest_dir / '.mls-eng-fbank.done'} to allow recomputing fbank."
)
return
else:
- cut_sets = make_cutset_blueprints(args.manifest_dir)
+        mls_eng_hf_dataset_path = args.dl_dir  # e.g. "download/mls_english" (see prepare.sh)
+ cut_sets = make_cutset_blueprints(mls_eng_hf_dataset_path)
for part, cut_set in cut_sets:
logging.info(f"Processing {part}")
+ cut_set = cut_set.save_audios(
+ num_jobs=num_jobs,
+ storage_path=(args.audio_dir / part).as_posix(),
+            )  # returns a new cut set whose cuts load audio from the files saved on disk
+
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
num_jobs=num_jobs,
storage_path=(args.manifest_dir / f"feats_{part}").as_posix(),
storage_type=LilcomChunkyWriter,
)
- cut_set.to_file(args.manifest_dir / f"reazonspeech_cuts_{part}.jsonl.gz")
- logging.info("All fbank computed for ReazonSpeech.")
- (args.manifest_dir / ".reazonspeech-fbank.done").touch()
+ cut_set.to_file(args.manifest_dir / f"mls_eng_cuts_{part}.jsonl.gz")
+
+ logging.info("All fbank computed for MLS English.")
+ (args.manifest_dir / ".mls-eng-fbank.done").touch()
if __name__ == "__main__":
diff --git a/egs/mls_english/ASR/local/compute_fbank_musan.py b/egs/mls_english/ASR/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/mls_english/ASR/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/multi_ja_en/ASR/local/display_manifest_statistics.py b/egs/mls_english/ASR/local/display_manifest_statistics.py
similarity index 94%
rename from egs/multi_ja_en/ASR/local/display_manifest_statistics.py
rename to egs/mls_english/ASR/local/display_manifest_statistics.py
index ace1dd73f..b128a08e0 100644
--- a/egs/multi_ja_en/ASR/local/display_manifest_statistics.py
+++ b/egs/mls_english/ASR/local/display_manifest_statistics.py
@@ -45,8 +45,8 @@ def get_parser():
def main():
args = get_parser()
- for part in ["train", "dev"]:
- path = args.manifest_dir / f"reazonspeech_cuts_{part}.jsonl.gz"
+ for part in ["dev", "test", "train"]:
+ path = args.manifest_dir / f"mls_eng_cuts_{part}.jsonl.gz"
cuts: CutSet = load_manifest(path)
print("\n---------------------------------\n")
diff --git a/egs/mls_english/ASR/local/train_bpe_model.py b/egs/mls_english/ASR/local/train_bpe_model.py
new file mode 100644
index 000000000..59e79be1e
--- /dev/null
+++ b/egs/mls_english/ASR/local/train_bpe_model.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2024 Xiaomi Corp. (authors: Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# You can install sentencepiece via:
+#
+# pip install sentencepiece
+#
+# Due to an issue reported in
+# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
+#
+# Please install a version >=0.1.96
+
+import argparse
+import shutil
+from pathlib import Path
+
+import sentencepiece as spm
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Input and output directory.
+ The generated bpe.model is saved to this directory.
+ """,
+ )
+
+ parser.add_argument(
+ "--byte-fallback",
+ action="store_true",
+ help="""Whether to enable byte_fallback when training bpe.""",
+ )
+
+ parser.add_argument(
+ "--character-coverage",
+ type=float,
+ default=1.0,
+ help="Character coverage in vocabulary.",
+ )
+
+ parser.add_argument(
+ "--transcript",
+ type=str,
+ help="Training transcript.",
+ )
+
+ parser.add_argument(
+ "--vocab-size",
+ type=int,
+ help="Vocabulary size for BPE training",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+ vocab_size = args.vocab_size
+ lang_dir = Path(args.lang_dir)
+
+ model_type = "bpe"
+
+ model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
+ train_text = args.transcript
+ input_sentence_size = 100000000
+
+    user_defined_symbols = ["<blk>", "<sos/eos>"]
+ unk_id = len(user_defined_symbols)
+ # Note: unk_id is fixed to 2.
+ # If you change it, you should also change other
+ # places that are using it.
+
+ model_file = Path(model_prefix + ".model")
+ if not model_file.is_file():
+ spm.SentencePieceTrainer.train(
+ input=train_text,
+ vocab_size=vocab_size,
+ model_type=model_type,
+ model_prefix=model_prefix,
+ input_sentence_size=input_sentence_size,
+ character_coverage=args.character_coverage,
+ user_defined_symbols=user_defined_symbols,
+ byte_fallback=args.byte_fallback,
+ unk_id=unk_id,
+ bos_id=-1,
+ eos_id=-1,
+ )
+ else:
+ print(f"{model_file} exists - skipping")
+ return
+
+ shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/mls_english/ASR/local/utils/asr_datamodule.py b/egs/mls_english/ASR/local/utils/asr_datamodule.py
new file mode 100644
index 000000000..912606bab
--- /dev/null
+++ b/egs/mls_english/ASR/local/utils/asr_datamodule.py
@@ -0,0 +1,365 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import (
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SimpleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class MLSEnglishHFAsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+ This class should be derived for specific corpora used in ASR tasks.
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/manifests"),
+ help="Path to directory with train/dev/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=float,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=False,
+ help="When enabled, select noise from MUSAN and mix it"
+ "with training dataset. ",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ cuts_musan: Optional[CutSet] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+
+ transforms = []
+ if cuts_musan is not None:
+ logging.info("Enable MUSAN")
+ transforms.append(
+                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
+ )
+ else:
+ logging.info("Disable MUSAN")
+ input_transforms = []
+
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # In different Lhotse's versions, the default of num_frame_masks is
+ # different.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=num_frame_masks,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.info("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info("About to get train cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "mls_eng_cuts_train.jsonl.gz"
+ )
+
+ @lru_cache()
+ def valid_cuts(self) -> CutSet:
+ logging.info("About to get dev cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "mls_eng_cuts_dev.jsonl.gz"
+ )
+
+ @lru_cache()
+    def test_cuts(self) -> CutSet:
+ logging.info("About to get test cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "mls_eng_cuts_test.jsonl.gz"
+ )
diff --git a/egs/mls_english/ASR/local/utils/create_subsets_greedy.py b/egs/mls_english/ASR/local/utils/create_subsets_greedy.py
new file mode 100644
index 000000000..1a7823182
--- /dev/null
+++ b/egs/mls_english/ASR/local/utils/create_subsets_greedy.py
@@ -0,0 +1,341 @@
+import argparse
+import glob
+import os
+import random
+import re
+import sys
+
+from datasets import Audio, DatasetDict, load_dataset
+
+
+def create_subset_by_hours(
+ full_dataset_path,
+ output_base_dir,
+ target_train_hours,
+ target_dev_hours, # New parameter
+ target_test_hours, # New parameter
+ random_seed=42,
+ duration_column_name="audio_duration",
+):
+ random.seed(random_seed)
+
+ output_subset_dir = os.path.join(
+ output_base_dir,
+ f"mls_english_subset_train{int(target_train_hours)}h_dev{int(target_dev_hours)}h_test{int(target_test_hours)}h",
+ )
+ os.makedirs(output_subset_dir, exist_ok=True)
+ output_subset_data_dir = os.path.join(output_subset_dir, "data")
+ os.makedirs(output_subset_data_dir, exist_ok=True)
+
+ print(
+ f"Attempting to load full dataset from '{full_dataset_path}' using load_dataset..."
+ )
+
+ full_data_dir = os.path.join(full_dataset_path, "data")
+ if not os.path.isdir(full_data_dir):
+ print(
+ f"Error: Expected a 'data' subdirectory at '{full_data_dir}' containing parquet files. "
+ "Please ensure 'full_dataset_path' points to the root of your MLS English download "
+ "(e.g., /path/to/mls_english_downloaded_dir) where 'data' is a direct child.",
+ file=sys.stderr,
+ )
+ sys.exit(1)
+
+ all_parquet_files = glob.glob(os.path.join(full_data_dir, "*.parquet"))
+ if not all_parquet_files:
+ print(f"Error: No parquet files found in '{full_data_dir}'.", file=sys.stderr)
+ sys.exit(1)
+
+ data_files = {}
+ # Expanded pattern to also detect 'validation' if it's in filenames
+ split_pattern = re.compile(r"^(train|dev|test|validation)-\d{5}-of-\d{5}\.parquet$")
+
+ print(f" Discovering splits from filenames in '{full_data_dir}'...")
+ for fpath in all_parquet_files:
+ fname = os.path.basename(fpath)
+ match = split_pattern.match(fname)
+ if match:
+ split_name = match.group(1)
+ if split_name not in data_files:
+ data_files[split_name] = []
+ data_files[split_name].append(fpath)
+ else:
+ print(
+ f"Warning: Skipping unrecognized parquet file: {fname}", file=sys.stderr
+ )
+
+ if not data_files:
+ print(
+ "Error: No recognized train, dev, test, or validation parquet files found.",
+ file=sys.stderr,
+ )
+ sys.exit(1)
+
+ print(f"Found splits and their parquet files: {list(data_files.keys())}")
+
+ try:
+ full_dataset = load_dataset("parquet", data_files=data_files)
+ except Exception as e:
+ print(
+ f"Error loading dataset from '{full_data_dir}' with load_dataset: {e}",
+ file=sys.stderr,
+ )
+ sys.exit(1)
+
+ if not isinstance(full_dataset, DatasetDict):
+ print(
+ "Error: The loaded dataset is not a DatasetDict. Expected a DatasetDict structure.",
+ file=sys.stderr,
+ )
+ sys.exit(1)
+
+ # --- Renaming 'validation' split to 'dev' if necessary ---
+ if "validation" in full_dataset:
+ if "dev" in full_dataset:
+ print(
+ "Warning: Both 'dev' and 'validation' splits found in the original dataset. Keeping 'dev' and skipping rename of 'validation'.",
+ file=sys.stderr,
+ )
+ else:
+ print("Renaming 'validation' split to 'dev' for consistent keying.")
+ full_dataset["dev"] = full_dataset.pop("validation")
+ # --- End Renaming ---
+
+ subset_dataset = DatasetDict()
+ total_final_duration_ms = 0
+
+ def get_duration_from_column(example):
+ """Helper to safely get duration from the specified column, in milliseconds."""
+ if duration_column_name in example:
+ return float(example[duration_column_name]) * 1000
+ else:
+ print(
+ f"Warning: Duration column '{duration_column_name}' not found in example. Returning 0.",
+ file=sys.stderr,
+ )
+ return 0
+
+ # --- NEW: Generalized sampling function ---
+ def sample_split_by_hours(split_name, original_split, target_hours):
+ """
+ Samples a dataset split to reach a target number of hours.
+ Returns the sampled Dataset object and its actual duration in milliseconds.
+ """
+ target_duration_ms = target_hours * 3600 * 1000
+ current_duration_ms = 0
+ indices_to_include = []
+
+ if original_split is None or len(original_split) == 0:
+ print(
+ f" Warning: Original '{split_name}' split is empty or not found. Cannot sample.",
+ file=sys.stderr,
+ )
+ return None, 0
+
+ print(
+ f"\n Processing '{split_name}' split to reach approximately {target_hours} hours..."
+ )
+ print(
+ f" Total samples in original '{split_name}' split: {len(original_split)}"
+ )
+
+ all_original_indices = list(range(len(original_split)))
+ random.shuffle(all_original_indices) # Shuffle indices for random sampling
+
+ num_samples_processed = 0
+ for original_idx in all_original_indices:
+ if current_duration_ms >= target_duration_ms and target_hours > 0:
+ print(
+ f" Target {split_name} hours reached ({target_hours}h). Stopping processing."
+ )
+ break
+
+ example = original_split[original_idx]
+ duration_ms = get_duration_from_column(example)
+
+ if duration_ms > 0:
+ indices_to_include.append(original_idx)
+ current_duration_ms += duration_ms
+
+ num_samples_processed += 1
+ if num_samples_processed % 10000 == 0: # Print progress periodically
+ print(
+ f" Processed {num_samples_processed} samples for '{split_name}'. Current duration: {current_duration_ms / (3600*1000):.2f} hours"
+ )
+
+ # If target_hours was 0, but there were samples, we should include none.
+ # Otherwise, select the chosen indices.
+ if target_hours == 0:
+ sampled_split = original_split.select([]) # Select an empty dataset
+ else:
+ sampled_split = original_split.select(
+ sorted(indices_to_include)
+ ) # Sort to preserve order
+
+ # Ensure the 'audio' column is correctly typed as Audio feature before saving
+ if "audio" in sampled_split.features and not isinstance(
+ sampled_split.features["audio"], Audio
+ ):
+ sampling_rate = (
+ sampled_split.features["audio"].sampling_rate
+ if isinstance(sampled_split.features["audio"], Audio)
+ else 16000
+ )
+ new_features = sampled_split.features
+ new_features["audio"] = Audio(sampling_rate=sampling_rate)
+ sampled_split = sampled_split.cast(new_features)
+
+ print(
+ f" Final '{split_name}' split duration: {current_duration_ms / (3600*1000):.2f} hours ({len(sampled_split)} samples)"
+ )
+ return sampled_split, current_duration_ms
+
+ # --- END NEW: Generalized sampling function ---
+
+ # --- Apply sampling for train, dev, and test splits ---
+ splits_to_process = {
+ "train": target_train_hours,
+ "dev": target_dev_hours,
+ "test": target_test_hours,
+ }
+
+ for split_name, target_hours in splits_to_process.items():
+ if split_name in full_dataset:
+ original_split = full_dataset[split_name]
+ sampled_split, actual_duration_ms = sample_split_by_hours(
+ split_name, original_split, target_hours
+ )
+ if sampled_split is not None:
+ subset_dataset[split_name] = sampled_split
+ total_final_duration_ms += actual_duration_ms
+ else:
+ print(
+ f"Warning: '{split_name}' split not found in original dataset. Skipping sampling.",
+ file=sys.stderr,
+ )
+
+ # --- Handle other splits if any, just copy them ---
+ # This loop now excludes 'validation' since it's handled by renaming to 'dev'
+ for split_name in full_dataset.keys():
+ if split_name not in [
+ "train",
+ "dev",
+ "test",
+ "validation",
+ ]: # Ensure 'validation' is not re-copied if not renamed
+ print(f"Copying unrecognized split '{split_name}' directly.")
+ other_split = full_dataset[split_name]
+ subset_dataset[split_name] = other_split
+ other_duration_ms = sum(get_duration_from_column(ex) for ex in other_split)
+ total_final_duration_ms += other_duration_ms
+ print(
+ f" Copied '{split_name}' split: {len(other_split)} samples ({other_duration_ms / (3600*1000):.2f} hours)"
+ )
+
+ final_total_hours = total_final_duration_ms / (3600 * 1000)
+ print(
+ f"\nOverall subset dataset duration (train + dev + test + others): {final_total_hours:.2f} hours"
+ )
+
+ print(
+ f"Saving subset dataset to '{output_subset_dir}' in Parquet format, matching original 'data' structure..."
+ )
+ try:
+ for split_name, ds_split in subset_dataset.items():
+ ds_split.to_parquet(
+ os.path.join(output_subset_data_dir, f"{split_name}.parquet")
+ )
+ print(f" Saved split '{split_name}' to '{output_subset_data_dir}'")
+
+ print(f"Successfully created and saved subset dataset to '{output_subset_dir}'")
+ except Exception as e:
+ print(
+ f"Error saving subset dataset to '{output_subset_dir}': {e}",
+ file=sys.stderr,
+ )
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Create a smaller subset of a downloaded Hugging Face audio dataset. "
+ "Samples train, dev, and test splits to target durations using pre-existing duration column. "
+ "Ensures 'validation' split is renamed to 'dev'."
+ )
+ parser.add_argument(
+ "--full-dataset-path",
+ type=str,
+ required=True,
+ help="The local path to the already downloaded Hugging Face dataset. "
+ "This should be the root directory containing the 'data' subdirectory "
+ "(e.g., /path/to/mls_english_download).",
+ )
+ parser.add_argument(
+ "--output-base-dir",
+ type=str,
+ required=True,
+ help="The base directory where the new subset dataset(s) will be saved. "
+ "A subdirectory 'mls_english_subset_trainXh_devYh_testZh' will be created within it.",
+ )
+ parser.add_argument(
+ "--target-train-hours",
+ type=float,
+ required=True,
+ help="The approximate total duration of the 'train' split in hours (e.g., 1000 for 1000 hours).",
+ )
+ parser.add_argument(
+ "--target-dev-hours",
+ type=float,
+ default=0.0,
+ help="The approximate total duration of the 'dev' split in hours (e.g., 10 for 10 hours). Set to 0 to exclude this split.",
+ )
+ parser.add_argument(
+ "--target-test-hours",
+ type=float,
+ default=0.0,
+ help="The approximate total duration of the 'test' split in hours (e.g., 10 for 10 hours). Set to 0 to exclude this split.",
+ )
+ parser.add_argument(
+ "--random-seed",
+ type=int,
+ default=42,
+ help="Seed for random number generation to ensure reproducibility (default: 42).",
+ )
+ parser.add_argument(
+ "--duration-column-name",
+ type=str,
+ default="audio_duration",
+ help="The name of the column in the dataset that contains the audio duration (assumed to be in seconds). Default: 'audio_duration'.",
+ )
+ args = parser.parse_args()
+
+ create_subset_by_hours(
+ args.full_dataset_path,
+ args.output_base_dir,
+ args.target_train_hours,
+ args.target_dev_hours,
+ args.target_test_hours,
+ args.random_seed,
+ args.duration_column_name,
+ )
+
+ # Simplified load path message for clarity
+ output_subset_full_path_name = f"mls_english_subset_train{int(args.target_train_hours)}h_dev{int(args.target_dev_hours)}h_test{int(args.target_test_hours)}h"
+ output_subset_data_path = os.path.join(
+ args.output_base_dir, output_subset_full_path_name, "data"
+ )
+
+ print(f"\nTo use your new subset dataset, you can load it like this:")
+ print(f"from datasets import load_dataset")
+ print(f"import os, glob")
+ print(f"data_files = {{}}")
+ print(
+ f"for split_name in ['train', 'dev', 'test']: # Or iterate through actual splits created"
+ )
+ print(
+ f" split_path = os.path.join('{output_subset_data_path}', f'{{split_name}}*.parquet')"
+ )
+ print(f" files = glob.glob(split_path)")
+ print(f" if files: data_files[split_name] = files")
+ print(f"subset = load_dataset('parquet', data_files=data_files)")
+ print(f"print(subset)")
diff --git a/egs/mls_english/ASR/local/utils/download_mls_english.py b/egs/mls_english/ASR/local/utils/download_mls_english.py
new file mode 100644
index 000000000..b4d8bd936
--- /dev/null
+++ b/egs/mls_english/ASR/local/utils/download_mls_english.py
@@ -0,0 +1,48 @@
+import argparse
+import os
+import sys
+
+from huggingface_hub import snapshot_download
+
+
+def download_dataset(dl_dir):
+ """
+ Downloads the MLS English dataset from Hugging Face to `$dl_dir/mls_english`.
+ """
+ repo_id = "parler-tts/mls_eng"
+ local_dataset_dir = os.path.join(dl_dir, "mls_english")
+
+ print(f"Attempting to download '{repo_id}' to '{local_dataset_dir}'...")
+
+ # Ensure the parent directory exists
+ os.makedirs(dl_dir, exist_ok=True)
+
+ try:
+ # snapshot_download handles LFS and large files robustly
+ # local_dir_use_symlinks=False is generally safer for datasets,
+ # especially on network file systems or if you intend to move the data
+ snapshot_download(
+ repo_id=repo_id,
+ repo_type="dataset",
+ local_dir=local_dataset_dir,
+ local_dir_use_symlinks=False,
+ )
+ print(f"Successfully downloaded '{repo_id}' to '{local_dataset_dir}'")
+ except Exception as e:
+ print(f"Error downloading dataset '{repo_id}': {e}", file=sys.stderr)
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Download MLS English dataset from Hugging Face."
+ )
+ parser.add_argument(
+ "--dl-dir",
+ type=str,
+ required=True,
+ help="The base directory where the 'mls_english' dataset will be downloaded.",
+ )
+ args = parser.parse_args()
+
+ download_dataset(args.dl_dir)
diff --git a/egs/mls_english/ASR/local/utils/generate_transcript.py b/egs/mls_english/ASR/local/utils/generate_transcript.py
new file mode 100644
index 000000000..bf2ab53de
--- /dev/null
+++ b/egs/mls_english/ASR/local/utils/generate_transcript.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python3
+# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Optional
+
+from lhotse import CutSet
+from tqdm import tqdm
+
+
+def get_args():
+ parser = argparse.ArgumentParser(
+ description="Generate transcripts for BPE training from MLS English dataset",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ parser.add_argument(
+ "--dataset-path",
+ type=str,
+ default="parler-tts/mls_eng",
+ help="Path to HuggingFace MLS English dataset (name or local path)",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default=Path("data/lang"),
+ help="Directory to store output transcripts",
+ )
+
+ parser.add_argument(
+ "--split",
+ type=str,
+ default="train",
+ help="Dataset split to use for generating transcripts (train/dev/test)",
+ )
+
+ return parser.parse_args()
+
+
+def generate_transcript_from_cuts(cuts: CutSet, output_file: Path) -> None:
+ """Generate transcript text file from Lhotse CutSet."""
+ with open(output_file, "w") as f:
+ for cut in tqdm(cuts, desc="Processing cuts"):
+ for sup in cut.supervisions:
+ f.write(f"{sup.text}\n")
+
+
+def main():
+ args = get_args()
+ logging.basicConfig(
+ format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
+ level=logging.INFO,
+ )
+
+ args.lang_dir.mkdir(parents=True, exist_ok=True)
+ output_file = args.lang_dir / "transcript.txt"
+
+ logging.info(f"Loading {args.split} split from dataset: {args.dataset_path}")
+ try:
+ cuts = CutSet.from_huggingface_dataset(
+ args.dataset_path, split=args.split, text_key="transcript"
+ )
+ except Exception as e:
+ logging.error(f"Failed to load dataset: {e}")
+ raise
+
+ logging.info(f"Generating transcript to {output_file}")
+ generate_transcript_from_cuts(cuts, output_file)
+ logging.info("Transcript generation completed")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/multi_ja_en/ASR/local/validate_manifest.py b/egs/mls_english/ASR/local/validate_manifest.py
similarity index 100%
rename from egs/multi_ja_en/ASR/local/validate_manifest.py
rename to egs/mls_english/ASR/local/validate_manifest.py
diff --git a/egs/mls_english/ASR/prepare.sh b/egs/mls_english/ASR/prepare.sh
new file mode 100755
index 000000000..c9afca976
--- /dev/null
+++ b/egs/mls_english/ASR/prepare.sh
@@ -0,0 +1,143 @@
+#!/usr/bin/env bash
+
+# Prepare script for MLS English ASR recipe in icefall
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+
+# Configuration for BPE tokenizer
+vocab_sizes=(2000) # You can add more sizes like (500 1000 2000) for comparison
+
+# Directory where dataset will be downloaded
+dl_dir=$PWD/download
+
+# - $dl_dir/musan
+# This directory contains the following directories downloaded from
+# http://www.openslr.org/17/
+#
+# - music
+# - noise
+# - speech
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+mkdir -p data
+mkdir -p data/audio
+mkdir -p data/manifests
+mkdir -p data/lang
+
+log() {
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "Starting MLS English data preparation"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+log "Stage 0: Download data"
+ # Check if huggingface_hub is installed
+ if ! python -c "import huggingface_hub" &> /dev/null; then
+ log "huggingface_hub Python library not found. Installing it now..."
+ # Using --break-system-packages for Debian/Ubuntu environments where pip install might fail without it
+ python -m pip install huggingface_hub || \
+ python -m pip install huggingface_hub --break-system-packages || { \
+ log "Failed to install huggingface_hub. Please install it manually: pip install huggingface_hub"; \
+ exit 1; \
+ }
+ log "huggingface_hub installed successfully."
+ fi
+
+ # Check if the dataset already exists to avoid re-downloading
+ if [ ! -d "$dl_dir/mls_english" ]; then
+ log "Dataset not found at $dl_dir/mls_english. Starting download..."
+ if ! python ./local/utils/download_mls_english.py --dl-dir "$dl_dir"; then
+ log "Failed to download MLS English dataset via download_mls_english.py"
+ exit 1
+ fi
+ else
+ log "Dataset already exists at $dl_dir/mls_english. Skipping download."
+ fi
+  # If you have pre-downloaded it to /path/to/musan,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/musan $dl_dir/
+ #
+ if [ ! -d $dl_dir/musan ] ; then
+ log "Downloading musan."
+ lhotse download musan $dl_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Compute MLS English fbank"
+ if [ ! -e data/manifests/.mls_english-validated.done ]; then
+ python local/compute_fbank_mls_english.py \
+ --manifest-dir data/manifests \
+ --audio-dir data/audio \
+ --dl-dir $dl_dir/mls_english
+ python local/validate_manifest.py --manifest data/manifests/mls_eng_cuts_train.jsonl.gz
+ python local/validate_manifest.py --manifest data/manifests/mls_eng_cuts_dev.jsonl.gz
+ python local/validate_manifest.py --manifest data/manifests/mls_eng_cuts_test.jsonl.gz
+ touch data/manifests/.mls_english-validated.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Prepare musan manifest"
+ # We assume that you have downloaded the musan corpus
+ # to $dl_dir/musan
+ if [ ! -e data/manifests/.musan_prep.done ]; then
+ lhotse prepare musan $dl_dir/musan data/manifests
+ touch data/manifests/.musan_prep.done
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Compute fbank for musan"
+ if [ ! -e data/manifests/.musan_fbank.done ]; then
+ ./local/compute_fbank_musan.py
+ touch data/manifests/.musan_fbank.done
+ fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Prepare transcript for BPE training"
+ if [ ! -f data/lang/transcript.txt ]; then
+ log "Generating transcripts for BPE training"
+ python local/utils/generate_transcript.py \
+ --dataset-path $dl_dir/mls_english \
+ --lang-dir data/lang \
+ --split train
+ fi
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Prepare BPE tokenizer"
+ for vocab_size in ${vocab_sizes[@]}; do
+ log "Training BPE model with vocab_size=${vocab_size}"
+ bpe_dir=data/lang/bpe_${vocab_size}
+ mkdir -p $bpe_dir
+
+ if [ ! -f $bpe_dir/bpe.model ]; then
+ python local/train_bpe_model.py \
+ --lang-dir $bpe_dir \
+ --vocab-size $vocab_size \
+ --transcript data/lang/transcript.txt
+ fi
+ done
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Show manifest statistics"
+ python local/display_manifest_statistics.py --manifest-dir data/manifests > data/manifests/manifest_statistics.txt
+ cat data/manifests/manifest_statistics.txt
+fi
+
+log "MLS English data preparation completed successfully"
diff --git a/egs/mls_english/ASR/shared b/egs/mls_english/ASR/shared
new file mode 120000
index 000000000..e9461a6d7
--- /dev/null
+++ b/egs/mls_english/ASR/shared
@@ -0,0 +1 @@
+../../librispeech/ASR/shared
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/asr_datamodule.py b/egs/mls_english/ASR/zipformer/asr_datamodule.py
new file mode 120000
index 000000000..a48591198
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/asr_datamodule.py
@@ -0,0 +1 @@
+../local/utils/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/beam_search.py b/egs/mls_english/ASR/zipformer/beam_search.py
new file mode 120000
index 000000000..8e2c0a65c
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/beam_search.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/ctc_decode.py b/egs/mls_english/ASR/zipformer/ctc_decode.py
new file mode 120000
index 000000000..faa8bd562
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/ctc_decode.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/ctc_decode.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/decode.py b/egs/mls_english/ASR/zipformer/decode.py
new file mode 100755
index 000000000..220cdcc9d
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/decode.py
@@ -0,0 +1,1085 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from asr_datamodule import MLSEnglishHFAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+ modified_beam_search_lm_rescore,
+ modified_beam_search_lm_rescore_LODR,
+ modified_beam_search_lm_shallow_fusion,
+ modified_beam_search_LODR,
+)
+
+# import sentencepiece as spm
+from tokenizer import Tokenizer
+
+# from gigaspeech_scoring import asr_text_post_processing
+from train import add_model_arguments, get_model, get_params
+
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ # parser.add_argument(
+ # "--bpe-model",
+ # type=str,
+ # default="data/lang_bpe_500/bpe.model",
+ # help="Path to the BPE model",
+ # )
+
+ # parser.add_argument(
+ # "--lang-dir",
+ # type=Path,
+ # default="data/lang_bpe_500",
+ # help="The lang dir containing word table and LG graph",
+ # )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="Path to the lang dir with the BPE model (`bpe.model`)",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - modified_beam_search_LODR
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding-method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding-method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--use-shallow-fusion",
+ type=str2bool,
+ default=False,
+ help="""Use neural network LM for shallow fusion.
+ If you want to use LODR, you will also need to set this to true
+ """,
+ )
+
+ parser.add_argument(
+ "--lm-type",
+ type=str,
+ default="rnn",
+ help="Type of NN lm",
+ choices=["rnn", "transformer"],
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.3,
+ help="""The scale of the neural network LM
+ Used only when `--use-shallow-fusion` is set to True.
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens-ngram",
+ type=int,
+ default=2,
+ help="""The order of the ngram lm.
+ """,
+ )
+
+ parser.add_argument(
+ "--backoff-id",
+ type=int,
+ default=500,
+ help="ID of the backoff symbol in the ngram LM",
+ )
+
+ parser.add_argument(
+ "--context-score",
+ type=float,
+ default=2,
+ help="""
+ The bonus score of each token for the context biasing words/phrases.
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-file",
+ type=str,
+ default="",
+ help="""
+ The path of the context biasing lists, one word/phrase each line
+ Used only when --decoding-method is modified_beam_search and
+ modified_beam_search_LODR.
+ """,
+ )
+ add_model_arguments(parser)
+
+ return parser
+
+
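+# Identity placeholder for gigaspeech_scoring.asr_text_post_processing (see the
+# commented-out import above): MLS English hypotheses and references are scored
+# without extra text normalization.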
+def asr_text_post_processing(inp):
+ return inp
+
+
+def post_processing(
+ results: List[Tuple[str, List[str], List[str]]],
+) -> List[Tuple[str, List[str], List[str]]]:
+ new_results = []
+ for key, ref, hyp in results:
+ new_ref = asr_text_post_processing(" ".join(ref)).split()
+ new_hyp = asr_text_post_processing(" ".join(hyp)).split()
+ new_results.append((key, new_ref, new_hyp))
+ return new_results
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: Tokenizer,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+ LM:
+ A neural network language model.
+ ngram_lm:
+ A ngram language model
+ ngram_lm_scale:
+ The scale for the ngram language model.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
+ hyp_tokens = modified_beam_search_lm_shallow_fusion(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_LODR":
+ hyp_tokens = modified_beam_search_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LODR_lm=ngram_lm,
+ LODR_lm_scale=ngram_lm_scale,
+ LM=LM,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_rescore":
+ lm_scale_list = [0.01 * i for i in range(10, 50)]
+ ans_dict = modified_beam_search_lm_rescore(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ lm_scale_list=lm_scale_list,
+ )
+ elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ lm_scale_list = [0.02 * i for i in range(2, 30)]
+ ans_dict = modified_beam_search_lm_rescore_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ LODR_lm=ngram_lm,
+ sp=sp,
+ lm_scale_list=lm_scale_list,
+ )
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ elif "modified_beam_search" in params.decoding_method:
+ prefix = f"beam_size_{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ ):
+ ans = dict()
+ assert ans_dict is not None
+ for key, hyps in ans_dict.items():
+ hyps = [sp.decode(hyp).split() for hyp in hyps]
+ ans[f"{prefix}_{key}"] = hyps
+ return ans
+ else:
+ if params.has_contexts:
+ prefix += f"-context-score-{params.context_score}"
+ return {prefix: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: Tokenizer,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+      The decoding graph. It can be either a `k2.trivial_graph` or an HLG.
+      It is used only when --decoding-method is one of fast_beam_search,
+      fast_beam_search_nbest, fast_beam_search_nbest_oracle, or
+      fast_beam_search_nbest_LG.
+    Returns:
+      Return a dict, whose key may be "greedy_search" if greedy search
+      is used, or it may be "beam_7" if a beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ word_table=word_table,
+ batch=batch,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = post_processing(results)
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
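+    # Rank the decoding settings by WER (ascending) so the best setting is listed
+    # first in the summary below.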
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ MLSEnglishHFAsrDataModule.add_arguments(parser)
+ LmScorer.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if os.path.exists(params.context_file):
+ params.has_contexts = True
+ else:
+ params.has_contexts = False
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ ):
+ if params.has_contexts:
+ params.suffix += f"-context-score-{params.context_score}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_shallow_fusion:
+ params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+ if "LODR" in params.decoding_method:
+ params.suffix += (
+ f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+ )
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ # sp = spm.SentencePieceProcessor()
+ # sp.load(params.bpe_model)
+
+ sp = Tokenizer.load(Path(args.lang_dir), "bpe") # force bpe model
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ # only load the neural network LM if required
+ if params.use_shallow_fusion or params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_LODR",
+ ):
+ LM = LmScorer(
+ lm_type=params.lm_type,
+ params=params,
+ device=device,
+ lm_scale=params.lm_scale,
+ )
+ LM.to(device)
+ LM.eval()
+ else:
+ LM = None
+
+ # only load N-gram LM when needed
+ if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ try:
+ import kenlm
+ except ImportError:
+ print("Please install kenlm first. You can use")
+ print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
+ print("to install it")
+ import sys
+
+ sys.exit(-1)
+ ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
+ logging.info(f"lm filename: {ngram_file_name}")
+ ngram_lm = kenlm.Model(ngram_file_name)
+ ngram_lm_scale = None # use a list to search
+
+ elif params.decoding_method == "modified_beam_search_LODR":
+ lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+ logging.info(f"Loading token level lm: {lm_filename}")
+ ngram_lm = NgramLm(
+ str(params.lang_dir / lm_filename),
+ backoff_id=params.backoff_id,
+ is_binary=False,
+ )
+ logging.info(f"num states: {ngram_lm.lm.num_states}")
+ ngram_lm_scale = params.ngram_lm_scale
+ else:
+ ngram_lm = None
+ ngram_lm_scale = None
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
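+            # The trivial graph accepts any token sequence, so no lexicon/LG is
+            # required for the plain fast_beam_search variants.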
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ if "modified_beam_search" in params.decoding_method:
+ if os.path.exists(params.context_file):
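+            # The context file is expected to contain one biasing phrase per line.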
+ contexts = []
+ for line in open(params.context_file).readlines():
+ contexts.append(line.strip())
+ context_graph = ContextGraph(params.context_score)
+ context_graph.build(sp.encode(contexts))
+ else:
+ context_graph = None
+ else:
+ context_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ mls_english_corpus = MLSEnglishHFAsrDataModule(args)
+
+    # Only the "test" split is decoded here. A "dev" split could be handled the
+    # same way, provided the data module exposes the corresponding cuts and
+    # dataloaders.
+    test_cuts = mls_english_corpus.test_cuts()
+    test_dl = mls_english_corpus.test_dataloaders(test_cuts)
+
+    test_sets = ["test"]
+    test_dls = [test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/mls_english/ASR/zipformer/decode_stream.py b/egs/mls_english/ASR/zipformer/decode_stream.py
new file mode 120000
index 000000000..b8d8ddfc4
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/decode_stream.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/decode_stream.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/decoder.py b/egs/mls_english/ASR/zipformer/decoder.py
new file mode 120000
index 000000000..5a8018680
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/decoder.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/do_not_use_it_directly.py b/egs/mls_english/ASR/zipformer/do_not_use_it_directly.py
new file mode 100755
index 000000000..072679cfc
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/do_not_use_it_directly.py
@@ -0,0 +1,1261 @@
+#!/usr/bin/env python3
+# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,)
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7_streaming/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --exp-dir pruned_transducer_stateless7_streaming/exp \
+ --lang data/lang_char \
+ --max-duration 300
+
+# For mix precision training:
+
+./pruned_transducer_stateless7_streaming/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir pruned_transducer_stateless7_streaming/exp \
+ --lang data/lang_char \
+ --max-duration 550
+"""
+
+
+import argparse
+import copy
+import logging
+import math
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import ReazonSpeechAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from tokenizer import Tokenizer
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer_for_ncnn_export_only import Zipformer
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LOG_EPS = math.log(1e-10)
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for module in model.modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,4,3,2,4",
+ help="Number of zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dims",
+ type=str,
+ default="1024,1024,2048,2048,1024",
+ help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--nhead",
+ type=str,
+ default="8,8,8,8,8",
+ help="Number of attention heads in the zipformer encoder layers.",
+ )
+
+ parser.add_argument(
+ "--encoder-dims",
+ type=str,
+ default="384,384,384,384,384",
+ help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+ )
+
+ parser.add_argument(
+ "--attention-dims",
+ type=str,
+ default="192,192,192,192,192",
+ help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+ not the same as embedding dimension.""",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dims",
+ type=str,
+ default="256,256,256,256,256",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
+ " worse.",
+ )
+
+ parser.add_argument(
+ "--zipformer-downsampling-factors",
+ type=str,
+ default="1,2,4,8,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernels",
+ type=str,
+ default="31,31,31,31,31",
+ help="Sizes of kernels in convolution modules",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--short-chunk-size",
+ type=int,
+ default=50,
+ help="""Chunk length of dynamic training, the chunk size would be either
+ max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
+ """,
+ )
+
+ parser.add_argument(
+ "--num-left-chunks",
+ type=int,
+ default=4,
+ help="How many left context can be seen in chunks when calculating attention.",
+ )
+
+ parser.add_argument(
+ "--decode-chunk-len",
+ type=int,
+ default=32,
+ help="The chunk size for decoding (in frames before subsampling)",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=Path,
+ default="pruned_transducer_stateless7_streaming/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.05, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network) part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=2000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ parser.add_argument(
+ "--pad-feature",
+ type=int,
+ default=0,
+ help="""
+ Number of frames to pad at the end.
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+                       contains the number of batches trained so far across
+ epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 1000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Zipformer and Transformer
+ def to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+ encoder = Zipformer(
+ num_features=params.feature_dim,
+ output_downsampling_factor=2,
+ zipformer_downsampling_factors=to_int_tuple(
+ params.zipformer_downsampling_factors
+ ),
+ encoder_dims=to_int_tuple(params.encoder_dims),
+ attention_dim=to_int_tuple(params.attention_dims),
+ encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+ nhead=to_int_tuple(params.nhead),
+ feedforward_dim=to_int_tuple(params.feedforward_dims),
+ cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+ num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+ num_left_chunks=params.num_left_chunks,
+ short_chunk_size=params.short_chunk_size,
+ decode_chunk_size=params.decode_chunk_len // 2,
+ is_pnnx=True,
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mix precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: Tokenizer,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute transducer loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+      sp:
+        The BPE tokenizer used to encode the transcripts.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.pad_feature:
+ feature_lens += params.pad_feature
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, params.pad_feature),
+ value=LOG_EPS,
+ )
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
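+        # For example, with the default s = 0.5, halfway through warm-up the
+        # scales are simple_loss_scale = 0.75 and pruned_loss_scale = 0.55.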
+
+ loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: Tokenizer,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: Tokenizer,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mix precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
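+            # (an exponentially smoothed running total, decay = 1 - 1/reset_interval)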
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ set_batch_count(model, params.batch_idx_train)
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except Exception as e: # noqa
+ logging.error(e, exc_info=True)
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise e
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+ if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ log_mode = logging.info
+ log_mode(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ log_mode(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, master_port=params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = Tokenizer.load(args.lang, args.lang_type)
+
+ # is defined in local/prepare_lang_char.py
+ params.blank_id = sp.piece_to_id("")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ parameters_names = []
+ parameters_names.append(
+ [name_param_pair[0] for name_param_pair in model.named_parameters()]
+ )
+ optimizer = ScaledAdam(
+ model.parameters(),
+ lr=params.base_lr,
+ clipping_scale=2.0,
+ parameters_names=parameters_names,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 0.3 seconds and 30 seconds
+        #
+        # Caution: There is a reason to select 30.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 0.3 or c.duration > 30.0:
+ logging.debug(
+ f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
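+        # e.g. c.num_frames = 1000 -> T = ((993 // 2) + 1) // 2 = 248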
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.info(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
+ train_cuts = reazonspeech_corpus.train_cuts()
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = reazonspeech_corpus.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = reazonspeech_corpus.valid_cuts()
+ valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts)
+
+ if params.start_batch <= 0 and not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
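+        # Advance the LR schedule and re-seed so that the sampler shuffles
+        # differently (but reproducibly) in each epoch.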
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: Tokenizer,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: Tokenizer,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ raise RuntimeError("Please don't use this file directly!")
+ parser = get_parser()
+ ReazonSpeechAsrDataModule.add_arguments(parser)
+ Tokenizer.add_arguments(parser)
+ args = parser.parse_args()
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/mls_english/ASR/zipformer/encoder_interface.py b/egs/mls_english/ASR/zipformer/encoder_interface.py
new file mode 120000
index 000000000..c2eaca671
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/encoder_interface.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/export-onnx.py b/egs/mls_english/ASR/zipformer/export-onnx.py
new file mode 120000
index 000000000..70a15683c
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/export-onnx.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/export.py b/egs/mls_english/ASR/zipformer/export.py
new file mode 120000
index 000000000..dfc1bec08
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/generate_averaged_model.py b/egs/mls_english/ASR/zipformer/generate_averaged_model.py
new file mode 120000
index 000000000..5a015ee6c
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/generate_averaged_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/generate_averaged_model.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/joiner.py b/egs/mls_english/ASR/zipformer/joiner.py
new file mode 120000
index 000000000..5b8a36332
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/joiner.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/model.py b/egs/mls_english/ASR/zipformer/model.py
new file mode 120000
index 000000000..cd7e07d72
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/model.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/my_profile.py b/egs/mls_english/ASR/zipformer/my_profile.py
new file mode 120000
index 000000000..3a90b2628
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/my_profile.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/my_profile.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/onnx_pretrained.py b/egs/mls_english/ASR/zipformer/onnx_pretrained.py
new file mode 120000
index 000000000..8f32f4ee7
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/onnx_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_pretrained.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/optim.py b/egs/mls_english/ASR/zipformer/optim.py
new file mode 120000
index 000000000..5eaa3cffd
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/pretrained.py b/egs/mls_english/ASR/zipformer/pretrained.py
new file mode 120000
index 000000000..0bd71dde4
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/pretrained.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/scaling.py b/egs/mls_english/ASR/zipformer/scaling.py
new file mode 120000
index 000000000..6f398f431
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/scaling_converter.py b/egs/mls_english/ASR/zipformer/scaling_converter.py
new file mode 120000
index 000000000..b0ecee05e
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling_converter.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/streaming_beam_search.py b/egs/mls_english/ASR/zipformer/streaming_beam_search.py
new file mode 120000
index 000000000..b1ed54557
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/streaming_beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/streaming_beam_search.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/streaming_decode.py b/egs/mls_english/ASR/zipformer/streaming_decode.py
new file mode 100755
index 000000000..e8e330481
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/streaming_decode.py
@@ -0,0 +1,900 @@
+#!/usr/bin/env python3
+# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang,
+# Fangjun Kuang,
+# Zengwei Yao)
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+./zipformer/streaming_decode.py \
+  --epoch 28 \
+  --avg 15 \
+  --causal 1 \
+  --chunk-size 32 \
+  --left-context-frames 256 \
+  --exp-dir ./zipformer/exp-large \
+  --lang data/lang_char \
+  --num-encoder-layers 2,2,4,5,4,2 \
+  --feedforward-dim 512,768,1536,2048,1536,768 \
+  --encoder-dim 192,256,512,768,512,256 \
+  --encoder-unmasked-dim 192,192,256,320,256,192
+
+"""
+
+import argparse
+import logging
+import math
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import numpy as np
+import torch
+from asr_datamodule import MLSEnglishHFAsrDataModule
+from decode_stream import DecodeStream
+from kaldifeat import Fbank, FbankOptions
+from lhotse import CutSet
+from streaming_beam_search import (
+ fast_beam_search_one_best,
+ greedy_search,
+ modified_beam_search,
+)
+from tokenizer import Tokenizer
+from torch import Tensor, nn
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_model, get_params
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_char",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Supported decoding methods are:
+ greedy_search
+ modified_beam_search
+ fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "--num_active_paths",
+ type=int,
+ default=4,
+ help="""An interger indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=32,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--num-decode-streams",
+ type=int,
+ default=2000,
+ help="The number of streams that can be decoded parallel.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_init_states(
+ model: nn.Module,
+ batch_size: int = 1,
+ device: torch.device = torch.device("cpu"),
+) -> List[torch.Tensor]:
+ """
+ Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+ is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ states[-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ states[-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+ """
+ states = model.encoder.get_init_states(batch_size, device)
+
+ embed_states = model.encoder_embed.get_init_states(batch_size, device)
+ states.append(embed_states)
+
+ processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
+ states.append(processed_lens)
+
+ return states
+
+
+def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
+ """Stack list of zipformer states that correspond to separate utterances
+    into a single zipformer state, so that it can be used as an input for
+ zipformer when those utterances are formed into a batch.
+
+ Args:
+ state_list:
+        Each element in state_list corresponds to the internal state
+ of the zipformer model for a single utterance. For element-n,
+ state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
+ state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
+ cached_val2, cached_conv1, cached_conv2).
+ state_list[n][-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ state_list[n][-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+
+ Note:
+ It is the inverse of :func:`unstack_states`.
+ """
+ batch_size = len(state_list)
+ assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
+ tot_num_layers = (len(state_list[0]) - 2) // 6
+
+ batch_states = []
+ for layer in range(tot_num_layers):
+ layer_offset = layer * 6
+ # cached_key: (left_context_len, batch_size, key_dim)
+ cached_key = torch.cat(
+ [state_list[i][layer_offset] for i in range(batch_size)], dim=1
+ )
+ # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
+ cached_nonlin_attn = torch.cat(
+ [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
+ )
+ # cached_val1: (left_context_len, batch_size, value_dim)
+ cached_val1 = torch.cat(
+ [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
+ )
+ # cached_val2: (left_context_len, batch_size, value_dim)
+ cached_val2 = torch.cat(
+ [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
+ )
+ # cached_conv1: (#batch, channels, left_pad)
+ cached_conv1 = torch.cat(
+ [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
+ )
+ # cached_conv2: (#batch, channels, left_pad)
+ cached_conv2 = torch.cat(
+ [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
+ )
+ batch_states += [
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ ]
+
+ cached_embed_left_pad = torch.cat(
+ [state_list[i][-2] for i in range(batch_size)], dim=0
+ )
+ batch_states.append(cached_embed_left_pad)
+
+ processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
+ batch_states.append(processed_lens)
+
+ return batch_states
+
+
+def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
+ """Unstack the zipformer state corresponding to a batch of utterances
+ into a list of states, where the i-th entry is the state from the i-th
+ utterance in the batch.
+
+ Note:
+ It is the inverse of :func:`stack_states`.
+
+ Args:
+ batch_states: A list of cached tensors of all encoder layers. For layer-i,
+ states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
+ cached_conv1, cached_conv2).
+        batch_states[-2] is the cached left padding for ConvNeXt module,
+        of shape (batch_size, num_channels, left_pad, num_freqs)
+        batch_states[-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+
+ Returns:
+      state_list: A list of lists. Each element in state_list corresponds to the internal state
+ of the zipformer model for a single utterance.
+ """
+ assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
+ tot_num_layers = (len(batch_states) - 2) // 6
+
+ processed_lens = batch_states[-1]
+ batch_size = processed_lens.shape[0]
+
+ state_list = [[] for _ in range(batch_size)]
+
+ for layer in range(tot_num_layers):
+ layer_offset = layer * 6
+ # cached_key: (left_context_len, batch_size, key_dim)
+ cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
+ # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
+ cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
+ chunks=batch_size, dim=1
+ )
+ # cached_val1: (left_context_len, batch_size, value_dim)
+ cached_val1_list = batch_states[layer_offset + 2].chunk(
+ chunks=batch_size, dim=1
+ )
+ # cached_val2: (left_context_len, batch_size, value_dim)
+ cached_val2_list = batch_states[layer_offset + 3].chunk(
+ chunks=batch_size, dim=1
+ )
+ # cached_conv1: (#batch, channels, left_pad)
+ cached_conv1_list = batch_states[layer_offset + 4].chunk(
+ chunks=batch_size, dim=0
+ )
+ # cached_conv2: (#batch, channels, left_pad)
+ cached_conv2_list = batch_states[layer_offset + 5].chunk(
+ chunks=batch_size, dim=0
+ )
+ for i in range(batch_size):
+ state_list[i] += [
+ cached_key_list[i],
+ cached_nonlin_attn_list[i],
+ cached_val1_list[i],
+ cached_val2_list[i],
+ cached_conv1_list[i],
+ cached_conv2_list[i],
+ ]
+
+ cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
+ for i in range(batch_size):
+ state_list[i].append(cached_embed_left_pad_list[i])
+
+ processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
+ for i in range(batch_size):
+ state_list[i].append(processed_lens_list[i])
+
+ return state_list
+
+
+def streaming_forward(
+ features: Tensor,
+ feature_lens: Tensor,
+ model: nn.Module,
+ states: List[Tensor],
+ chunk_size: int,
+ left_context_len: int,
+) -> Tuple[Tensor, Tensor, List[Tensor]]:
+ """
+ Returns encoder outputs, output lengths, and updated states.
+ """
+ cached_embed_left_pad = states[-2]
+ (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
+ x=features,
+ x_lens=feature_lens,
+ cached_left_pad=cached_embed_left_pad,
+ )
+ assert x.size(1) == chunk_size, (x.size(1), chunk_size)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+
+ # processed_mask is used to mask out initial states
+ processed_lens = states[-1] # (batch,)
+ idx = torch.arange(left_context_len, device=x.device).unsqueeze(0).expand(
+ x.size(0), left_context_len
+ )
+ # True means padding positions (not yet available in cache).
+ processed_mask = idx >= processed_lens.unsqueeze(1)
+ # Update processed lengths
+ new_processed_lens = processed_lens + x_lens
+
+ # (batch, left_context_size + chunk_size)
+ src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
+
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+ encoder_states = states[:-2]
+ (
+ encoder_out,
+ encoder_out_lens,
+ new_encoder_states,
+ ) = model.encoder.streaming_forward(
+ x=x,
+ x_lens=x_lens,
+ states=encoder_states,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ new_states = new_encoder_states + [
+ new_cached_embed_left_pad,
+ new_processed_lens,
+ ]
+ return encoder_out, encoder_out_lens, new_states
+
+
+def decode_one_chunk(
+ params: AttributeDict,
+ model: nn.Module,
+ decode_streams: List[DecodeStream],
+) -> List[int]:
+ """Decode one chunk frames of features for each decode_streams and
+ return the indexes of finished streams in a List.
+
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ decode_streams:
+        A List of DecodeStream, each belonging to an utterance.
+ Returns:
+ Return a List containing which DecodeStreams are finished.
+ """
+ chunk_size = int(params.chunk_size)
+ left_context_len = int(params.left_context_frames)
+
+ features = []
+ feature_lens = []
+ states = []
+ processed_lens = [] # Used in fast-beam-search
+
+ for stream in decode_streams:
+ feat, feat_len = stream.get_feature_frames(chunk_size * 2)
+ features.append(feat)
+ feature_lens.append(feat_len)
+ states.append(stream.states)
+ processed_lens.append(stream.done_frames)
+
+ feature_lens = torch.tensor(feature_lens, device=model.device)
+ features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
+
+ # Make sure the length after encoder_embed is at least 1.
+    # The encoder_embed subsamples features: T -> (T - 7) // 2
+ # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
+ tail_length = chunk_size * 2 + 7 + 2 * 3
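+    # E.g. with chunk_size=32 this is 64 + 7 + 6 = 77 input frames, so that
+    # (77 - 7) // 2 = 35 frames survive subsampling: the 32-frame chunk plus
+    # 3 frames of right padding for the ConvNeXt module.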
+ if features.size(1) < tail_length:
+ pad_length = tail_length - features.size(1)
+ feature_lens += pad_length
+ features = torch.nn.functional.pad(
+ features,
+ (0, 0, 0, pad_length),
+ mode="constant",
+ value=LOG_EPS,
+ )
+
+ states = stack_states(states)
+
+ encoder_out, encoder_out_lens, new_states = streaming_forward(
+ features=features,
+ feature_lens=feature_lens,
+ model=model,
+ states=states,
+ chunk_size=chunk_size,
+ left_context_len=left_context_len,
+ )
+
+ encoder_out = model.joiner.encoder_proj(encoder_out)
+
+ if params.decoding_method == "greedy_search":
+ greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
+ elif params.decoding_method == "fast_beam_search":
+ processed_lens = torch.tensor(processed_lens, device=model.device)
+ processed_lens = processed_lens + encoder_out_lens
+ fast_beam_search_one_best(
+ model=model,
+ encoder_out=encoder_out,
+ processed_lens=processed_lens,
+ streams=decode_streams,
+ beam=params.beam,
+ max_states=params.max_states,
+ max_contexts=params.max_contexts,
+ )
+ elif params.decoding_method == "modified_beam_search":
+ modified_beam_search(
+ model=model,
+ streams=decode_streams,
+ encoder_out=encoder_out,
+ num_active_paths=params.num_active_paths,
+ )
+ else:
+ raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+
+ states = unstack_states(new_states)
+
+ finished_streams = []
+ for i in range(len(decode_streams)):
+ decode_streams[i].states = states[i]
+ decode_streams[i].done_frames += encoder_out_lens[i]
+ # if decode_streams[i].done:
+ # finished_streams.append(i)
+ finished_streams.append(i)
+
+ return finished_streams
+
+
+def decode_dataset(
+ cuts: CutSet,
+ params: AttributeDict,
+ model: nn.Module,
+ tokenizer: Tokenizer,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ cuts:
+ Lhotse Cutset containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ tokenizer:
+ The BPE model.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding-method is fast_beam_search.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+        Its value is a list of tuples. Each tuple contains three elements:
+        the cut ID, the reference transcript, and the
+        predicted result.
+ """
+ device = model.device
+
+ opts = FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = 16000
+ opts.mel_opts.num_bins = 80
+
+ log_interval = 100
+
+ decode_results = []
+ # Contain decode streams currently running.
+ decode_streams = []
+ for num, cut in enumerate(cuts):
+ # each utterance has a DecodeStream.
+ initial_states = get_init_states(model=model, batch_size=1, device=device)
+ decode_stream = DecodeStream(
+ params=params,
+ cut_id=cut.id,
+ initial_states=initial_states,
+ decoding_graph=decoding_graph,
+ device=device,
+ )
+
+ audio: np.ndarray = cut.load_audio()
+ # audio.shape: (1, num_samples)
+ assert len(audio.shape) == 2
+ assert audio.shape[0] == 1, "Should be single channel"
+ assert audio.dtype == np.float32, audio.dtype
+
+ # The trained model is using normalized samples
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
+
+ samples = torch.from_numpy(audio).squeeze(0)
+
+ fbank = Fbank(opts)
+ feature = fbank(samples.to(device))
+ decode_stream.set_features(feature, tail_pad_len=30)
+ decode_stream.ground_truth = cut.supervisions[0].text
+ decode_streams.append(decode_stream)
+
+ while len(decode_streams) >= params.num_decode_streams:
+ finished_streams = decode_one_chunk(
+ params=params, model=model, decode_streams=decode_streams
+ )
+ for i in sorted(finished_streams, reverse=True):
+ decode_results.append(
+ (
+ decode_streams[i].id,
+ decode_streams[i].ground_truth.split(),
+ tokenizer.decode(decode_streams[i].decoding_result()).split(),
+ )
+ )
+ del decode_streams[i]
+
+ if num % log_interval == 0:
+ logging.info(f"Cuts processed until now is {num}.")
+
+ # decode final chunks of last sequences
+ while len(decode_streams):
+ # print("INSIDE LEN DECODE STREAMS")
+ # pdb.set_trace()
+ # print(model.device)
+ # test_device = model.device
+ # print("done")
+ finished_streams = decode_one_chunk(
+ params=params, model=model, decode_streams=decode_streams
+ )
+
+ if not finished_streams:
+ print("No finished streams, breaking the loop")
+ break
+
+ for i in sorted(finished_streams, reverse=True):
+ try:
+ decode_results.append(
+ (
+ decode_streams[i].id,
+ decode_streams[i].ground_truth.split(),
+ tokenizer.decode(decode_streams[i].decoding_result()).split(),
+ )
+ )
+ del decode_streams[i]
+ except IndexError as e:
+ print(f"IndexError: {e}")
+ print(f"decode_streams length: {len(decode_streams)}")
+ print(f"finished_streams: {finished_streams}")
+ print(f"i: {i}")
+ continue
+
+ if params.decoding_method == "greedy_search":
+ key = "greedy_search"
+ elif params.decoding_method == "fast_beam_search":
+ key = (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}"
+ )
+ elif params.decoding_method == "modified_beam_search":
+ key = f"num_active_paths_{params.num_active_paths}"
+ else:
+ raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+ torch.cuda.synchronize()
+ return {key: decode_results}
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+    MLSEnglishHFAsrDataModule.add_arguments(parser)
+ Tokenizer.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ params.res_dir = params.exp_dir / "streaming" / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ assert params.causal, params.causal
+ assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ # for fast_beam_search
+ if params.decoding_method == "fast_beam_search":
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp_token = Tokenizer.load(params.lang, params.lang_type)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp_token.piece_to_id("<blk>")
+    params.unk_id = sp_token.piece_to_id("<unk>")
+ params.vocab_size = sp_token.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if start >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ decoding_graph = None
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+    mls_english_corpus = MLSEnglishHFAsrDataModule(args)
+
+    valid_cuts = mls_english_corpus.valid_cuts()
+    test_cuts = mls_english_corpus.test_cuts()
+
+ test_sets = ["valid", "test"]
+ test_cuts = [valid_cuts, test_cuts]
+
+ for test_set, test_cut in zip(test_sets, test_cuts):
+ results_dict = decode_dataset(
+ cuts=test_cut,
+ params=params,
+ model=model,
+ tokenizer=sp_token,
+ decoding_graph=decoding_graph,
+ )
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/mls_english/ASR/zipformer/subsampling.py b/egs/mls_english/ASR/zipformer/subsampling.py
new file mode 120000
index 000000000..01ae9002c
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/subsampling.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/test_scaling.py b/egs/mls_english/ASR/zipformer/test_scaling.py
new file mode 120000
index 000000000..715798436
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/test_scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/test_scaling.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/test_subsampling.py b/egs/mls_english/ASR/zipformer/test_subsampling.py
new file mode 120000
index 000000000..bf0ee3d11
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/test_subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/test_subsampling.py
\ No newline at end of file
diff --git a/egs/mls_english/ASR/zipformer/tokenizer.py b/egs/mls_english/ASR/zipformer/tokenizer.py
new file mode 100644
index 000000000..ba71cff89
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/tokenizer.py
@@ -0,0 +1,252 @@
+import argparse
+from pathlib import Path
+from typing import Callable, List, Union
+
+import sentencepiece as spm
+from k2 import SymbolTable
+
+
+class Tokenizer:
+ text2word: Callable[[str], List[str]]
+
+ @staticmethod
+ def add_arguments(parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(title="Lang related options")
+ group.add_argument("--lang", type=Path, help="Path to lang directory.")
+
+ group.add_argument(
+ "--lang-type",
+ type=str,
+ default=None,
+ help=(
+ "Either 'bpe' or 'char'. If not provided, it expects lang_dir/lang_type to exists. "
+ "Note: 'bpe' directly loads sentencepiece.SentencePieceProcessor"
+ ),
+ )
+
+ @staticmethod
+    def Load(lang_dir: Path, lang_type="", oov="<unk>"):
+
+ if not lang_type:
+ assert (lang_dir / "lang_type").exists(), "lang_type not specified."
+ lang_type = (lang_dir / "lang_type").read_text().strip()
+
+ tokenizer = None
+
+ if lang_type == "bpe":
+ assert (
+ lang_dir / "bpe.model"
+ ).exists(), f"No BPE .model could be found in {lang_dir}."
+ tokenizer = spm.SentencePieceProcessor()
+ tokenizer.Load(str(lang_dir / "bpe.model"))
+ elif lang_type == "char":
+ tokenizer = CharTokenizer(lang_dir, oov=oov)
+ else:
+ raise NotImplementedError(f"{lang_type} not supported at the moment.")
+
+ return tokenizer
+
+ load = Load
+
+ def PieceToId(self, piece: str) -> int:
+ raise NotImplementedError(
+ "You need to implement this function in the child class."
+ )
+
+ piece_to_id = PieceToId
+
+ def IdToPiece(self, id: int) -> str:
+ raise NotImplementedError(
+ "You need to implement this function in the child class."
+ )
+
+ id_to_piece = IdToPiece
+
+ def GetPieceSize(self) -> int:
+ raise NotImplementedError(
+ "You need to implement this function in the child class."
+ )
+
+ get_piece_size = GetPieceSize
+
+ def __len__(self) -> int:
+ return self.get_piece_size()
+
+ def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
+ raise NotImplementedError(
+ "You need to implement this function in the child class."
+ )
+
+ def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
+ raise NotImplementedError(
+ "You need to implement this function in the child class."
+ )
+
+ def EncodeAsIds(self, input: str) -> List[int]:
+ return self.EncodeAsIdsBatch([input])[0]
+
+ def EncodeAsPieces(self, input: str) -> List[str]:
+ return self.EncodeAsPiecesBatch([input])[0]
+
+ def Encode(
+ self, input: Union[str, List[str]], out_type=int
+ ) -> Union[List, List[List]]:
+ if not input:
+ return []
+
+ if isinstance(input, list):
+ if out_type is int:
+ return self.EncodeAsIdsBatch(input)
+ if out_type is str:
+ return self.EncodeAsPiecesBatch(input)
+
+ if out_type is int:
+ return self.EncodeAsIds(input)
+ if out_type is str:
+ return self.EncodeAsPieces(input)
+
+ encode = Encode
+
+ def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
+ raise NotImplementedError(
+ "You need to implement this function in the child class."
+ )
+
+ def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
+ raise NotImplementedError(
+ "You need to implement this function in the child class."
+ )
+
+ def DecodeIds(self, input: List[int]) -> str:
+ return self.DecodeIdsBatch([input])[0]
+
+ def DecodePieces(self, input: List[str]) -> str:
+ return self.DecodePiecesBatch([input])[0]
+
+ def Decode(
+ self,
+ input: Union[int, List[int], List[str], List[List[int]], List[List[str]]],
+ ) -> Union[List[str], str]:
+
+ if not input:
+ return ""
+
+ if isinstance(input, int):
+ return self.id_to_piece(input)
+ elif isinstance(input, str):
+ raise TypeError(
+ "Unlike spm.SentencePieceProcessor, cannot decode from type str."
+ )
+
+ if isinstance(input[0], list):
+ if not input[0] or isinstance(input[0][0], int):
+ return self.DecodeIdsBatch(input)
+
+ if isinstance(input[0][0], str):
+ return self.DecodePiecesBatch(input)
+
+ if isinstance(input[0], int):
+ return self.DecodeIds(input)
+ if isinstance(input[0], str):
+ return self.DecodePieces(input)
+
+ raise RuntimeError("Unknown input type")
+
+ decode = Decode
+
+ def SplitBatch(self, input: List[str]) -> List[List[str]]:
+ raise NotImplementedError(
+ "You need to implement this function in the child class."
+ )
+
+ def Split(self, input: Union[List[str], str]) -> Union[List[List[str]], List[str]]:
+ if isinstance(input, list):
+ return self.SplitBatch(input)
+ elif isinstance(input, str):
+ return self.SplitBatch([input])[0]
+ raise RuntimeError("Unknown input type")
+
+ split = Split
+
+
+class CharTokenizer(Tokenizer):
+    def __init__(self, lang_dir: Path, oov="<unk>", sep=""):
+ assert (
+ lang_dir / "tokens.txt"
+ ).exists(), f"tokens.txt could not be found in {lang_dir}."
+ token_table = SymbolTable.from_file(lang_dir / "tokens.txt")
+ assert (
+ "#0" not in token_table
+ ), "This tokenizer does not support disambig symbols."
+ self._id2sym = token_table._id2sym
+ self._sym2id = token_table._sym2id
+ self.oov = oov
+ self.oov_id = self._sym2id[oov]
+ self.sep = sep
+ if self.sep:
+ self.text2word = lambda x: x.split(self.sep)
+ else:
+ self.text2word = lambda x: list(x.replace(" ", ""))
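+        # E.g. with the default sep="" the text "分かち 書き" is split into
+        # ["分", "か", "ち", "書", "き"]: spaces are removed, then the string is
+        # split into individual characters.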
+
+ def piece_to_id(self, piece: str) -> int:
+ try:
+ return self._sym2id[piece]
+ except KeyError:
+ return self.oov_id
+
+ def id_to_piece(self, id: int) -> str:
+ return self._id2sym[id]
+
+ def get_piece_size(self) -> int:
+ return len(self._sym2id)
+
+ def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
+ return [[self.piece_to_id(i) for i in self.text2word(text)] for text in input]
+
+ def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
+ return [
+ [i if i in self._sym2id else self.oov for i in self.text2word(text)]
+ for text in input
+ ]
+
+ def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
+ return [self.sep.join(self.id_to_piece(i) for i in text) for text in input]
+
+ def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
+ return [self.sep.join(text) for text in input]
+
+ def SplitBatch(self, input: List[str]) -> List[List[str]]:
+ return [self.text2word(text) for text in input]
+
+
+def test_CharTokenizer():
+ test_single_string = "こんにちは"
+ test_multiple_string = [
+ "今日はいい天気ですよね",
+ "諏訪湖は綺麗でしょう",
+ "这在词表外",
+ "分かち 書き に し た 文章 です",
+ "",
+ ]
+ test_empty_string = ""
+    sp = Tokenizer.load(Path("lang_char"), "char", oov="<unk>")
+ splitter = sp.split
+ print(sp.encode(test_single_string, out_type=str))
+ print(sp.encode(test_single_string, out_type=int))
+ print(sp.encode(test_multiple_string, out_type=str))
+ print(sp.encode(test_multiple_string, out_type=int))
+ print(sp.encode(test_empty_string, out_type=str))
+ print(sp.encode(test_empty_string, out_type=int))
+ print(sp.decode(sp.encode(test_single_string, out_type=str)))
+ print(sp.decode(sp.encode(test_single_string, out_type=int)))
+ print(sp.decode(sp.encode(test_multiple_string, out_type=str)))
+ print(sp.decode(sp.encode(test_multiple_string, out_type=int)))
+ print(sp.decode(sp.encode(test_empty_string, out_type=str)))
+ print(sp.decode(sp.encode(test_empty_string, out_type=int)))
+ print(splitter(test_single_string))
+ print(splitter(test_multiple_string))
+ print(splitter(test_empty_string))
+
+
+if __name__ == "__main__":
+ test_CharTokenizer()
diff --git a/egs/mls_english/ASR/zipformer/train.py b/egs/mls_english/ASR/zipformer/train.py
new file mode 100755
index 000000000..63020abfb
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/train.py
@@ -0,0 +1,1400 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+# For non-streaming model training:
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --max-duration 1000
+
+# For streaming model training:
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --causal 1 \
+ --max-duration 1000
+
+It supports training with:
+ - transducer loss (default), with `--use-transducer True --use-ctc False`
+ - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
+ - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import MLSEnglishHFAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from lhotse import load_manifest
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from tokenizer import Tokenizer
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
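+    # E.g. with --max-duration 1000, --world-size 4 and the default
+    # --ref-duration 600, each real batch advances the adjusted count by
+    # 1000 * 4 / 600 (about 6.7) reference batches.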
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--causal",
+ type=str2bool,
+ default=False,
+ help="If True, use causal version of model.",
+ )
+
+ parser.add_argument(
+ "--chunk-size",
+ type=str,
+ default="16,32,64,-1",
+ help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+ " Must be just -1 if --causal=False",
+ )
+
+ parser.add_argument(
+ "--left-context-frames",
+ type=str,
+ default="64,128,256,-1",
+ help="Maximum left-contexts for causal training, measured in frames which will "
+ "be converted to a number of chunks. If splitting into chunks, "
+ "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+ )
+
+ parser.add_argument(
+ "--use-transducer",
+ type=str2bool,
+ default=True,
+ help="If True, use Transducer head.",
+ )
+
+ parser.add_argument(
+ "--use-ctc",
+ type=str2bool,
+ default=False,
+ help="If True, use CTC head.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="Path to the lang dir with the BPE model (`bpe.model`)",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.015, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)" "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+                       contains the number of batches trained so far across
+ epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+    - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
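+    # E.g. _to_int_tuple("2,2,3,4,3,2") == (2, 2, 3, 4, 3, 2)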
+ return tuple(map(int, s.split(",")))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+ # encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, (T - 7) // 2, encoder_dims).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> (T - 7) // 2
+ # (2) embedding: num_features -> encoder_dims
+ # In the normal configuration, we will downsample once more at the end
+ # by a factor of 2, and most of the encoder stacks will run at a lower
+ # sampling rate.
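+    # E.g. 100 input fbank frames of dim 80 become (100 - 7) // 2 = 46 frames
+    # of dim encoder_dim[0] (192 with the default configuration).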
+ encoder_embed = Conv2dSubsampling(
+ in_channels=params.feature_dim,
+ out_channels=_to_int_tuple(params.encoder_dim)[0],
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ )
+ return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Zipformer2(
+ output_downsampling_factor=2,
+ downsampling_factor=_to_int_tuple(params.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+ encoder_dim=_to_int_tuple(params.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(params.query_head_dim),
+ pos_head_dim=_to_int_tuple(params.pos_head_dim),
+ value_head_dim=_to_int_tuple(params.value_head_dim),
+ pos_dim=params.pos_dim,
+ num_heads=_to_int_tuple(params.num_heads),
+ feedforward_dim=_to_int_tuple(params.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ causal=params.causal,
+ chunk_size=_to_int_tuple(params.chunk_size),
+ left_context_frames=_to_int_tuple(params.left_context_frames),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
+ f"but got params.use_transducer={params.use_transducer}, "
+ f"params.use_ctc={params.use_ctc}"
+ )
+
+ encoder_embed = get_encoder_embed(params)
+ encoder = get_encoder_model(params)
+
+ if params.use_transducer:
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+ else:
+ decoder = None
+ joiner = None
+
+ model = AsrModel(
+ encoder_embed=encoder_embed,
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ use_transducer=params.use_transducer,
+ use_ctc=params.use_ctc,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: Tokenizer,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+      sp:
+        The tokenizer used to encode the transcripts.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y)
+
+ with torch.set_grad_enabled(is_training):
+ losses = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+ simple_loss, pruned_loss, ctc_loss = losses[:3]
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+            # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: Tokenizer,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: Tokenizer,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
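+ # tot_loss is a decayed running sum: old stats are scaled by (1 - 1/reset_interval)
+ # each step, so the logged averages reflect roughly the last reset_interval batches.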
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
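+ # With the checks below, a scale under 8.0 is doubled every 100 batches,
+ # while a scale in [8.0, 32.0) is doubled only every 400 batches.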
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = Tokenizer.load(Path(args.lang_dir), "bpe") # force bpe model
+
+ # <blk> is defined in local/prepare_lang_char.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 30 seconds
+ #
+ # Caution: There is a reason to select 30.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 30.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
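+ # e.g. a cut with num_frames = 1000 (about 10 seconds at a 10 ms frame shift)
+ # gives T = ((1000 - 7) // 2 + 1) // 2 = 248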
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ mls_english_corpus = MLSEnglishHFAsrDataModule(args)
+ train_cuts = mls_english_corpus.train_cuts()
+ # mls_english_corpus.load_dataset(args.dataset_path)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ if args.enable_musan:
+ musan_path = Path(args.manifest_dir) / "musan_cuts.jsonl.gz"
+ if musan_path.exists():
+ cuts_musan = load_manifest(musan_path)
+ logging.info(f"Loaded MUSAN manifest from {musan_path}")
+ else:
+ logging.warning(f"MUSAN manifest not found at {musan_path}, disabling MUSAN augmentation")
+ cuts_musan = None
+ else:
+ cuts_musan = None
+
+ train_dl = mls_english_corpus.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+ valid_cuts = mls_english_corpus.valid_cuts()
+ valid_dl = mls_english_corpus.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: Tokenizer,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: Tokenizer,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ MLSEnglishHFAsrDataModule.add_arguments(parser)
+ Tokenizer.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/mls_english/ASR/zipformer/zipformer.py b/egs/mls_english/ASR/zipformer/zipformer.py
new file mode 120000
index 000000000..23011dda7
--- /dev/null
+++ b/egs/mls_english/ASR/zipformer/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/zipformer.py
\ No newline at end of file
diff --git a/egs/multi_ja_en/ASR/README.md b/egs/multi_ja_en/ASR/README.md
index 09964a4ab..5f734f30c 100644
--- a/egs/multi_ja_en/ASR/README.md
+++ b/egs/multi_ja_en/ASR/README.md
@@ -1,17 +1,36 @@
# Introduction
-A bilingual Japanese-English ASR model that utilizes ReazonSpeech, developed by the developers of ReazonSpeech.
+A bilingual Japanese-English ASR model developed by the developers of ReazonSpeech, utilizing ReazonSpeech and the English subset of Multilingual LibriSpeech (MLS English).
**ReazonSpeech** is an open-source dataset that contains a diverse set of natural Japanese speech, collected from terrestrial television streams. It contains more than 35,000 hours of audio.
+**Multilingual LibriSpeech (MLS)** is a large multilingual corpus suitable for speech research. The dataset is derived from read audiobooks from LibriVox and consists of 8 languages - English, German, Dutch, Spanish, French, Italian, Portuguese, Polish. It includes about 44.5K hours of English and a total of about 6K hours for other languages. This icefall training recipe was created for the restructured version of the English split of the dataset available on Hugging Face from `parler-tts` [here](https://huggingface.co/datasets/parler-tts/mls_eng).
-# Included Training Sets
-1. LibriSpeech (English)
-2. ReazonSpeech (Japanese)
+# Training Sets
+
+1. ReazonSpeech (Japanese)
+2. Multilingual LibriSpeech (English)
|Dataset| Number of hours| URL|
|---|---:|---|
-|**TOTAL**|35,960|---|
-|LibriSpeech|960|https://www.openslr.org/12/|
-|ReazonSpeech (all) |35,000|https://huggingface.co/datasets/reazon-research/reazonspeech|
+|**TOTAL**|79,500|---|
+|MLS English|44,500|https://huggingface.co/datasets/parler-tts/mls_eng|
+|ReazonSpeech (all)|35,000|https://huggingface.co/datasets/reazon-research/reazonspeech|
+
+# Usage
+
+This recipe builds on the `mls_english` and `reazonspeech` recipes.
+
+Before using the `multi_ja_en` recipe, you must first run the `prepare.sh` scripts of both the `mls_english` and `reazonspeech` recipes.
+
+This recipe does not enforce data balance: please ensure that the `mls_english` and `reazonspeech` datasets prepared above are balanced to your liking (you may use the utility script `create_subsets_greedy.py` in the `mls_english` recipe to create a custom-sized MLS English sub-dataset).
+
+Steps for model training (a consolidated sketch follows the list):
+
+0. Run `../../mls_english/ASR/prepare.sh` and `../../reazonspeech/ASR/prepare.sh`
+1. Run `./prepare.sh`
+2. Run `local/utils/update_cutset_paths.py` (also run as Stage 6 of `./prepare.sh`)
+3. Run `zipformer/train.py` (see example arguments inside the file)
+
+
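+A consolidated sketch of these steps (run from `egs/multi_ja_en/ASR`; the arguments shown are illustrative, see each script and RESULTS.md for the full option sets):
+
+```shell
+# 0. Prepare the source recipes
+(cd ../../mls_english/ASR && ./prepare.sh)
+(cd ../../reazonspeech/ASR && ./prepare.sh)
+
+# 1. Link manifests/features and build the byte-level BPE lang dir
+./prepare.sh
+
+# 2. Rewrite feature paths inside the linked cut manifests
+python local/utils/update_cutset_paths.py
+
+# 3. Train the bilingual zipformer
+./zipformer/train.py \
+  --world-size 8 \
+  --num-epochs 10 \
+  --use-fp16 1 \
+  --exp-dir zipformer/exp \
+  --manifest-dir data/manifests
+```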
diff --git a/egs/multi_ja_en/ASR/RESULTS.md b/egs/multi_ja_en/ASR/RESULTS.md
index 0f6996013..24dd42a26 100644
--- a/egs/multi_ja_en/ASR/RESULTS.md
+++ b/egs/multi_ja_en/ASR/RESULTS.md
@@ -2,51 +2,163 @@
### Zipformer
-#### Non-streaming
+#### Non-streaming (Byte-Level BPE vocab_size=2000)
+
+Trained on 15k hours of ReazonSpeech (filtered to only audio segments between 8s and 22s) and 15k hours of MLS English.
The training command is:
```shell
./zipformer/train.py \
- --bilingual 1 \
- --world-size 4 \
- --num-epochs 30 \
+ --world-size 8 \
+ --causal 1 \
+ --num-epochs 10 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
- --max-duration 600
+ --manifest-dir data/manifests \
+ --enable-musan True
```
The decoding command is:
```shell
./zipformer/decode.py \
- --epoch 28 \
- --avg 15 \
+ --epoch 10 \
+ --avg 1 \
--exp-dir ./zipformer/exp \
- --max-duration 600 \
- --decoding-method greedy_search
+ --decoding-method modified_beam_search \
+ --manifest-dir data/manifests
```
To export the model with onnx:
```shell
-./zipformer/export-onnx.py --tokens data/lang_bbpe_2000/tokens.txt --use-averaged-model 0 --epoch 35 --avg 1 --exp-dir zipformer/exp --num-encoder-layers "2,2,3,4,3,2" --downsampling-factor "1,2,4,8,4,2" --feedforward-dim "512,768,1024,1536,1024,768" --num-heads "4,4,4,8,4,4" --encoder-dim "192,256,384,512,384,256" --query-head-dim 32 --value-head-dim 12 --pos-head-dim 4 --pos-dim 48 --encoder-unmasked-dim "192,192,256,256,256,192" --cnn-module-kernel "31,31,15,15,15,31" --decoder-dim 512 --joiner-dim 512 --causal False --chunk-size "16,32,64,-1" --left-context-frames "64,128,256,-1" --fp16 True
+./zipformer/export-onnx.py \
+ --tokens ./data/lang/bbpe_2000/tokens.txt \
+ --use-averaged-model 0 \
+ --epoch 10 \
+ --avg 1 \
+ --exp-dir ./zipformer/exp
```
+
+WER on the combined test set is listed below (calculated with `./zipformer/decode.py`):
+
+| Decoding method      | WER (%), ReazonSpeech + MLS English combined test set |
+|----------------------|--------------------------------------------------------|
+| greedy_search        | 6.33                                                   |
+| modified_beam_search | 6.32                                                   |
+
+
+
+We also include WER% for common English ASR datasets:
+
+| Corpus | WER (%) |
+|-----------------------------|---------|
+| CommonVoice | 29.03 |
+| TED | 16.78 |
+| MLS English (test set) | 8.64 |
+
+
+And CER% for common Japanese datasets:
+
+| Corpus | CER (%) |
+|---------------|---------|
+| JSUT | 8.13 |
+| CommonVoice | 9.82 |
+| TEDx | 11.64 |
+
+
+Pre-trained model can be found here: [https://huggingface.co/reazon-research/reazonspeech-k2-v2-ja-en/tree/multi_ja_en_15k15k](https://huggingface.co/reazon-research/reazonspeech-k2-v2-ja-en/tree/multi_ja_en_15k15k)
+
+(Not yet publicly released)
+
+#### Streaming (Byte-Level BPE vocab_size=2000)
+
+Trained on 15k hours of ReazonSpeech (filtered to only audio segments between 8s and 22s) and 15k hours of MLS English.
+
+The training command is:
+
+```shell
+./zipformer/train.py \
+ --world-size 8 \
+ --causal 1 \
+ --num-epochs 10 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --manifest-dir data/manifests \
+ --enable-musan True
+```
+
+The decoding command is:
+
+```shell
+TODO
+```
+
+To export the model with sherpa onnx:
+
+```shell
+./zipformer/export-onnx-streaming.py \
+ --tokens ./data/lang/bbpe_2000/tokens.txt \
+ --use-averaged-model 0 \
+ --epoch 10 \
+ --avg 1 \
+ --exp-dir ./zipformer/exp-15k15k-streaming \
+ --num-encoder-layers "2,2,3,4,3,2" \
+ --downsampling-factor "1,2,4,8,4,2" \
+ --feedforward-dim "512,768,1024,1536,1024,768" \
+ --num-heads "4,4,4,8,4,4" \
+ --encoder-dim "192,256,384,512,384,256" \
+ --query-head-dim 32 \
+ --value-head-dim 12 \
+ --pos-head-dim 4 \
+ --pos-dim 48 \
+ --encoder-unmasked-dim "192,192,256,256,256,192" \
+ --cnn-module-kernel "31,31,15,15,15,31" \
+ --decoder-dim 512 \
+ --joiner-dim 512 \
+ --causal True \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --fp16 True
+```
+
+(Adjust the `chunk-size` and `left-context-frames` as necessary)
+
+To export the model as Torchscript (`.jit`):
+
+```shell
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp-15k15k-streaming \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --tokens data/lang/bbpe_2000/tokens.txt \
+ --epoch 10 \
+ --avg 1 \
+ --jit 1
+```
+
+You may also use decode chunk sizes `16`, `32`, `64`, `128`.
+
Word Error Rates (WERs) listed below:
-| Datasets | ReazonSpeech | ReazonSpeech | LibriSpeech | LibriSpeech |
-|----------------------|--------------|---------------|--------------------|-------------------|
-| Zipformer WER (%) | dev | test | test-clean | test-other |
-| greedy_search | 5.9 | 4.07 | 3.46 | 8.35 |
-| modified_beam_search | 4.87 | 3.61 | 3.28 | 8.07 |
+*Please let us know which script to use to evaluate the streaming model!*
-Character Error Rates (CERs) for Japanese listed below:
-| Decoding Method | In-Distribution CER | JSUT | CommonVoice | TEDx |
-| :------------------: | :-----------------: | :--: | :---------: | :---: |
-| greedy search | 12.56 | 6.93 | 9.75 | 9.67 |
-| modified beam search | 11.59 | 6.97 | 9.55 | 9.51 |
+We also include WER% for common English ASR datasets:
-Pre-trained model can be found here: https://huggingface.co/reazon-research/reazonspeech-k2-v2-ja-en/tree/main
+*Please let us know which script to use to evaluate the streaming model!*
+
+And CER% for common Japanese datasets:
+
+*Please let us know which script to use to evaluate the streaming model!*
+
+
+Pre-trained model can be found here: [https://huggingface.co/reazon-research/reazonspeech-k2-v2-ja-en/tree/multi_ja_en_15k15k](https://huggingface.co/reazon-research/reazonspeech-k2-v2-ja-en/tree/multi_ja_en_15k15k)
+
+(Not yet publicly released)
diff --git a/egs/multi_ja_en/ASR/local/prepare_lang_bbpe.py b/egs/multi_ja_en/ASR/local/prepare_lang_bbpe.py
index 6134710ad..ad6bd5f40 100755
--- a/egs/multi_ja_en/ASR/local/prepare_lang_bbpe.py
+++ b/egs/multi_ja_en/ASR/local/prepare_lang_bbpe.py
@@ -21,7 +21,7 @@
This script takes as input `lang_dir`, which should contain::
- - lang_dir/bbpe.model,
+ - lang_dir/bbpe_2000/bbpe.model
- lang_dir/words.txt
and generates the following files in the directory `lang_dir`:
@@ -173,7 +173,8 @@ def get_args():
"--lang-dir",
type=str,
help="""Input and output directory.
- It should contain the bpe.model and words.txt
+ It should contain the words.txt file and the
+ bbpe model in a subdirectory (e.g., bbpe_2000/bbpe.model).
""",
)
@@ -184,6 +185,13 @@ def get_args():
help="The out of vocabulary word in lexicon.",
)
+ parser.add_argument(
+ "--vocab-size",
+ type=int,
+ default=2000, # Add a default value for vocab_size for consistency
+ help="Vocabulary size used for BPE training (determines the bbpe model directory).",
+ )
+
parser.add_argument(
"--debug",
type=str2bool,
@@ -206,6 +214,9 @@ def main():
lang_dir = Path(args.lang_dir)
model_file = lang_dir / "bbpe.model"
+ if not model_file.is_file():
+ raise FileNotFoundError(f"BPE model not found at: {model_file}")
+
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
words = word_sym_table.symbols
@@ -216,7 +227,7 @@ def main():
if w in words:
words.remove(w)
- lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov)
+ lexicon, token_sym_table = generate_lexicon(str(model_file), words, args.oov)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
diff --git a/egs/multi_ja_en/ASR/local/prepare_lang_char.py b/egs/multi_ja_en/ASR/local/prepare_lang_char.py
deleted file mode 100644
index 19c5f4a31..000000000
--- a/egs/multi_ja_en/ASR/local/prepare_lang_char.py
+++ /dev/null
@@ -1,75 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import argparse
-import logging
-from pathlib import Path
-
-from lhotse import CutSet
-
-
-def get_args():
- parser = argparse.ArgumentParser(
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- parser.add_argument(
- "train_cut", metavar="train-cut", type=Path, help="Path to the train cut"
- )
-
- parser.add_argument(
- "--lang-dir",
- type=Path,
- default=Path("data/lang_char"),
- help=(
- "Name of lang dir. "
- "If not set, this will default to lang_char_{trans-mode}"
- ),
- )
-
- return parser.parse_args()
-
-
-def main():
- args = get_args()
- logging.basicConfig(
- format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"),
- level=logging.INFO,
- )
-
- sysdef_string = set(["<blk>", "<unk>", "<sos/eos>", " "])
-
- token_set = set()
- logging.info(f"Creating vocabulary from {args.train_cut}.")
- train_cut: CutSet = CutSet.from_file(args.train_cut)
- for cut in train_cut:
- for sup in cut.supervisions:
- token_set.update(sup.text)
-
- token_set = ["<blk>"] + sorted(token_set - sysdef_string) + ["<unk>", "<sos/eos>"]
- args.lang_dir.mkdir(parents=True, exist_ok=True)
- (args.lang_dir / "tokens.txt").write_text(
- "\n".join(f"{t}\t{i}" for i, t in enumerate(token_set))
- )
-
- (args.lang_dir / "lang_type").write_text("char")
- logging.info("Done.")
-
-
-if __name__ == "__main__":
- main()
diff --git a/egs/multi_ja_en/ASR/local/train_bbpe_model.py b/egs/multi_ja_en/ASR/local/train_bbpe_model.py
index d104f2717..b87e6cd28 100755
--- a/egs/multi_ja_en/ASR/local/train_bbpe_model.py
+++ b/egs/multi_ja_en/ASR/local/train_bbpe_model.py
@@ -33,7 +33,7 @@ from pathlib import Path
import sentencepiece as spm
from icefall import byte_encode
-from icefall.utils import tokenize_by_ja_char
+from icefall.utils import str2bool, tokenize_by_ja_char
def get_args():
@@ -41,9 +41,7 @@ def get_args():
parser.add_argument(
"--lang-dir",
type=str,
- help="""Input and output directory.
- The generated bpe.model is saved to this directory.
- """,
+ help="""Input directory.""",
)
parser.add_argument(
@@ -58,6 +56,27 @@ def get_args():
help="Vocabulary size for BPE training",
)
+ parser.add_argument(
+ "--output-model",
+ type=str,
+ help="Path to save the trained BPE model.",
+ required=True,
+ )
+
+ parser.add_argument(
+ "--input-sentence-size",
+ type=int,
+ default=1000000, # Added default value
+ help="Maximum number of sentences to load for BPE training.",
+ )
+
+ parser.add_argument(
+ "--shuffle-input-sentence",
+ type=str2bool,
+ default=True, # Added default value
+ help="Whether to shuffle input sentences.",
+ )
+
return parser.parse_args()
@@ -71,17 +90,20 @@ def main():
args = get_args()
vocab_size = args.vocab_size
lang_dir = Path(args.lang_dir)
+ output_model = Path(args.output_model)
+ input_sentence_size = args.input_sentence_size
+ shuffle_input_sentence = args.shuffle_input_sentence
model_type = "unigram"
- model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
- model_file = Path(model_prefix + ".model")
- if model_file.is_file():
- print(f"{model_file} exists - skipping")
+ model_prefix = str(output_model.parent / f"{model_type}_{vocab_size}")
+ temp_model_file = Path(model_prefix + ".model")
+
+ if output_model.is_file():
+ print(f"{output_model} exists - skipping")
return
character_coverage = 1.0
- input_sentence_size = 100000000
user_defined_symbols = ["<blk>", "<sos/eos>"]
unk_id = len(user_defined_symbols)
@@ -100,6 +122,7 @@ def main():
model_type=model_type,
model_prefix=model_prefix,
input_sentence_size=input_sentence_size,
+ shuffle_input_sentence=shuffle_input_sentence,
character_coverage=character_coverage,
user_defined_symbols=user_defined_symbols,
unk_id=unk_id,
@@ -107,7 +130,7 @@ def main():
eos_id=-1,
)
- shutil.copyfile(model_file, f"{lang_dir}/bbpe.model")
+ shutil.move(str(temp_model_file), str(output_model))
if __name__ == "__main__":
diff --git a/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py b/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py
index be18e65c1..417eb3325 100644
--- a/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py
+++ b/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py
@@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import argparse
import inspect
import logging
@@ -23,6 +22,7 @@ from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional
+import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
@@ -34,12 +34,21 @@ from lhotse.dataset import (
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
-class ReazonSpeechAsrDataModule:
+class _SeedWorkers:
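+ """Used as a DataLoader worker_init_fn: gives each worker a distinct, reproducible seed (seed + worker_id)."""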
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class MultiDatasetAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
@@ -70,7 +79,7 @@ class ReazonSpeechAsrDataModule:
group.add_argument(
"--manifest-dir",
type=Path,
- default=Path("data/fbank"),
+ default=Path("data/manifests"),
help="Path to directory with train/dev/test cuts.",
)
group.add_argument(
@@ -192,6 +201,32 @@ class ReazonSpeechAsrDataModule:
transforms = []
input_transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ logging.info("About to get Musan cuts")
+ cuts_musan = load_manifest(
+ self.args.manifest_dir / "musan/musan_cuts.jsonl.gz"
+ )
+ transforms.append(
+ CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
@@ -250,6 +285,8 @@ class ReazonSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
@@ -265,12 +302,17 @@ class ReazonSpeechAsrDataModule:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
+ seed = 42
+ worker_init_fn = _SeedWorkers(seed)
+
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
+ pin_memory=True,
num_workers=self.args.num_workers,
- persistent_workers=False,
+ persistent_workers=True,
+ worker_init_fn=worker_init_fn,
)
return train_dl
@@ -332,24 +374,3 @@ class ReazonSpeechAsrDataModule:
num_workers=self.args.num_workers,
)
return test_dl
-
- @lru_cache()
- def train_cuts(self) -> CutSet:
- logging.info("About to get train cuts")
- return load_manifest_lazy(
- self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz"
- )
-
- @lru_cache()
- def valid_cuts(self) -> CutSet:
- logging.info("About to get dev cuts")
- return load_manifest_lazy(
- self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz"
- )
-
- @lru_cache()
- def test_cuts(self) -> List[CutSet]:
- logging.info("About to get test cuts")
- return load_manifest_lazy(
- self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz"
- )
diff --git a/egs/multi_ja_en/ASR/local/utils/update_cutset_paths.py b/egs/multi_ja_en/ASR/local/utils/update_cutset_paths.py
new file mode 100644
index 000000000..af0da4364
--- /dev/null
+++ b/egs/multi_ja_en/ASR/local/utils/update_cutset_paths.py
@@ -0,0 +1,156 @@
+import logging
+import os # Import os module to handle symlinks
+from pathlib import Path
+
+from lhotse import CutSet, load_manifest
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> CutSet:
+ """
+ Updates the storage_path in a CutSet's features to reflect the new dataset-specific
+ feature directory structure.
+
+ Args:
+ cuts: The Lhotse CutSet to modify.
+ dataset_name: The name of the dataset (e.g., "reazonspeech", "mls_english")
+ which corresponds to the new subdirectory for features.
+ old_feature_prefix: The prefix that the original feature paths were relative to.
+ This typically corresponds to the root of the manifests dir
+ in the original recipe.
+ """
+ updated_cuts = []
+ for cut in cuts:
+ if cut.features is not None:
+ original_storage_path = Path(cut.features.storage_path)
+ try:
+ relative_path = original_storage_path.relative_to(old_feature_prefix)
+ except ValueError:
+ # If for some reason the path doesn't start with old_feature_prefix,
+ # keep it as is. This can happen if some paths are already absolute or different.
+ logger.warning(
+ f"Feature path '{original_storage_path}' does not start with '{old_feature_prefix}'. Skipping update for this cut."
+ )
+ updated_cuts.append(cut)
+ continue
+
+ # Avoid double-nesting (e.g., reazonspeech/reazonspeech/...)
+ # Construct the new path: data/manifests/<dataset_name>/feats_train/feats-12.lca
+ if relative_path.parts[0] == dataset_name:
+ new_storage_path = Path("data/manifests") / relative_path
+ else:
+ new_storage_path = Path("data/manifests") / dataset_name / relative_path
+
+ logger.info(
+ f"Updating cut {cut.id}: {original_storage_path} → {new_storage_path}"
+ )
+ cut.features.storage_path = new_storage_path.as_posix()
+ updated_cuts.append(cut)
+ else:
+ logger.warning(f"Skipping update for cut {cut.id}: has no features.")
+ updated_cuts.append(cut) # No features, or not a path we need to modify
+
+ return CutSet.from_cuts(updated_cuts)
+
+
+if __name__ == "__main__":
+ # The root where the symlinked manifests are located in the multi_ja_en recipe
+ multi_recipe_manifests_root = Path("data/manifests")
+
+ # Define the datasets and their *specific* manifest file prefixes
+ dataset_manifest_prefixes = {
+ "reazonspeech": "reazonspeech_cuts",
+ "mls_english": "mls_eng_cuts",
+ }
+
+ splits = ["train", "dev", "test"]
+
+ # This is the prefix that feature paths in the original recipes were stored
+ # relative to, e.g. if an original storage_path was
+ # data/manifests/feats_train/file.lca (relative to the source recipe dir),
+ # then this prefix is 'data/manifests'.
+ original_feature_base_path = "data/manifests"
+
+ musan_manifest_path = multi_recipe_manifests_root / "musan" / "musan_cuts.jsonl.gz"
+ if musan_manifest_path.exists():
+ logger.info(f"Processing musan manifest: {musan_manifest_path}")
+ try:
+ musan_cuts = load_manifest(musan_manifest_path)
+ updated_musan_cuts = update_paths(
+ musan_cuts, "musan", old_feature_prefix="data/fbank"
+ )
+ # Make sure we're overwriting the correct path even if it's a symlink
+ if musan_manifest_path.is_symlink() or musan_manifest_path.exists():
+ logger.info(
+ f"Overwriting existing musan manifest at: {musan_manifest_path}"
+ )
+ os.unlink(musan_manifest_path)
+ updated_musan_cuts.to_file(musan_manifest_path)
+ logger.info(f"Updated musan cuts written to: {musan_manifest_path}")
+
+ except Exception as e:
+ logger.error(
+ f"Error processing musan manifest {musan_manifest_path}: {e}",
+ exc_info=True,
+ )
+ else:
+ logger.warning(f"Musan manifest not found at {musan_manifest_path}, skipping.")
+
+ for dataset_name, manifest_prefix in dataset_manifest_prefixes.items():
+ dataset_symlink_dir = multi_recipe_manifests_root / dataset_name
+ if not dataset_symlink_dir.is_dir():
+ logger.warning(
+ f"Dataset symlink directory not found: {dataset_symlink_dir}. Skipping {dataset_name}."
+ )
+ continue
+
+ for split in splits:
+ # Construct the path to the symlinked manifest file
+ manifest_filename = f"{manifest_prefix}_{split}.jsonl.gz"
+ symlink_path = (
+ dataset_symlink_dir / manifest_filename
+ ) # This is the path to the symlink itself
+
+ if symlink_path.is_symlink(): # Check if it's actually a symlink
+ # Get the actual path to the target file that the symlink points to
+ # Lhotse's load_manifest will follow this symlink automatically.
+ target_path = os.path.realpath(symlink_path)
+ logger.info(
+ f"Processing symlink '{symlink_path}' pointing to '{target_path}'"
+ )
+ elif symlink_path.is_file(): # If it's a regular file (not a symlink)
+ logger.info(f"Processing regular file: {symlink_path}")
+ target_path = symlink_path # Use its own path as target
+ else:
+ logger.warning(
+ f"Manifest file not found or neither a file nor a symlink: {symlink_path}"
+ )
+ continue # Skip to next iteration
+
+ try:
+ # Load the manifest. Lhotse will resolve the symlink internally for reading.
+ cuts = load_manifest(
+ symlink_path
+ ) # Use symlink_path here, Lhotse handles resolution for loading
+
+ # Update the storage_path within the loaded cuts (in memory)
+ updated_cuts = update_paths(
+ cuts, dataset_name, old_feature_prefix=original_feature_base_path
+ )
+
+ # --- CRITICAL CHANGE HERE ---
+ # Save the *modified* CutSet to the path of the symlink *itself*.
+ # This will overwrite the symlink with the new file, effectively
+ # breaking the symlink and creating a new file in its place.
+ os.unlink(symlink_path)
+ updated_cuts.to_file(symlink_path)
+ logger.info(
+ f"Updated {dataset_name} {split} cuts saved (overwriting symlink) to: {symlink_path}"
+ )
+
+ except Exception as e:
+ logger.error(f"Error processing {symlink_path}: {e}", exc_info=True)
+
+ logger.info("CutSet path updating complete.")
diff --git a/egs/multi_ja_en/ASR/local/validate_bpe_lexicon.py b/egs/multi_ja_en/ASR/local/validate_bpe_lexicon.py
index 721bb48e7..f17e1cc6d 120000
--- a/egs/multi_ja_en/ASR/local/validate_bpe_lexicon.py
+++ b/egs/multi_ja_en/ASR/local/validate_bpe_lexicon.py
@@ -1 +1 @@
-../../../librispeech/ASR/local/validate_bpe_lexicon.py
\ No newline at end of file
+/root/Github/reazon-icefall/egs/librispeech/ASR/local/validate_bpe_lexicon.py
\ No newline at end of file
diff --git a/egs/multi_ja_en/ASR/prepare.sh b/egs/multi_ja_en/ASR/prepare.sh
index 7a6a63418..495b3a116 100755
--- a/egs/multi_ja_en/ASR/prepare.sh
+++ b/egs/multi_ja_en/ASR/prepare.sh
@@ -19,6 +19,8 @@ vocab_sizes=(
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
+mkdir -p data/lang
+mkdir -p data/manifests
log() {
# This function is from espnet
@@ -31,55 +33,54 @@ log "dl_dir: $dl_dir"
log "Dataset: musan"
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Soft link fbank of musan"
- mkdir -p data/fbank
if [ -e ../../librispeech/ASR/data/fbank/.musan.done ]; then
- cd data/fbank
- ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_feats) .
- ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_cuts.jsonl.gz) .
- cd ../..
+ cd data/manifests
+ mkdir -p musan
+ cd musan
+ ln -svfr $(realpath ../../../../../librispeech/ASR/data/fbank/musan_feats) .
+ ln -svfr $(realpath ../../../../../librispeech/ASR/data/fbank/musan_cuts.jsonl.gz) .
+ cd ../../..
else
log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 4 --stop-stage 4"
exit 1
fi
fi
-log "Dataset: LibriSpeech"
+log "Dataset: MLS English"
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
- log "Stage 1: Soft link fbank of LibriSpeech"
- mkdir -p data/fbank
- if [ -e ../../librispeech/ASR/data/fbank/.librispeech.done ]; then
- cd data/fbank
- ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts*) .
- ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats*) .
- cd ../..
+ log "Stage 2: Soft link manifests (including fbank) of MLS English"
+ if [ -e ../../mls_english/ASR/data/manifests/.mls_english-validated.done ]; then
+ cd data/manifests
+ mkdir -p mls_english
+ cd mls_english
+ ln -svfr $(realpath ../../../../../mls_english/ASR/data/manifests/mls_eng_cuts*) .
+ ln -svfr $(realpath ../../../../../mls_english/ASR/data/manifests/feats*) .
+ cd ../../..
else
- log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 1 --stop-stage 1 and ../../librispeech/ASR/prepare.sh --stage 3 --stop-stage 3"
+ log "Abort! Please run ../../mls_english/ASR/prepare.sh --stage 1 --stop-stage 1"
exit 1
fi
fi
log "Dataset: ReazonSpeech"
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
- log "Stage 2: Soft link fbank of ReazonSpeech"
- mkdir -p data/fbank
+ log "Stage 3: Soft link fbank of ReazonSpeech"
if [ -e ../../reazonspeech/ASR/data/manifests/.reazonspeech.done ]; then
- cd data/fbank
- ln -svf $(realpath ../../../../reazonspeech/ASR/data/manifests/reazonspeech_cuts*) .
- cd ..
- mkdir -p manifests
- cd manifests
- ln -svf $(realpath ../../../../reazonspeech/ASR/data/manifests/feats_*) .
- cd ../..
+ cd data/manifests
+ mkdir -p reazonspeech
+ cd reazonspeech
+ ln -svfr $(realpath ../../../../../reazonspeech/ASR/data/manifests/reazonspeech_cuts*) .
+ ln -svfr $(realpath ../../../../../reazonspeech/ASR/data/manifests/feats*) .
+ cd ../../..
else
log "Abort! Please run ../../reazonspeech/ASR/prepare.sh --stage 0 --stop-stage 2"
exit 1
fi
fi
-# New Stage 3: Prepare char based lang for ReazonSpeech
-if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
lang_char_dir=data/lang_char
- log "Stage 3: Prepare char based lang for ReazonSpeech"
+ log "Stage 4: Prepare char-based lang for ReazonSpeech"
mkdir -p $lang_char_dir
# Prepare text
@@ -89,7 +90,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
| ./local/text2token.py -t "char" > $lang_char_dir/text
fi
- # jp word segmentation for text
+ # Japanese word segmentation
if [ ! -f $lang_char_dir/text_words_segmentation ]; then
python3 ./local/text2segments.py \
--input-file $lang_char_dir/text \
@@ -106,80 +107,96 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
- python3 ./local/prepare_char.py --lang-dir data/lang_char
+ python3 ./local/prepare_char.py --lang-dir $lang_char_dir
fi
fi
-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
- log "Stage 4: Prepare Byte BPE based lang"
- mkdir -p data/fbank
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Prepare Byte BPE based lang in data/lang"
+ lang_dir=data/lang
+
+ # Check if required char-based lang data exists
if [ ! -d ../../reazonspeech/ASR/data/lang_char ] && [ ! -d ./data/lang_char ]; then
log "Abort! Please run ../../reazonspeech/ASR/prepare.sh --stage 3 --stop-stage 3"
exit 1
fi
- if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ] && [ ! -d ./data/lang_bpe_500 ]; then
- log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 5 --stop-stage 5"
+ # Check if BPE data from MLS English exists
+ if [ ! -d ../../mls_english/ASR/data/lang/bpe_2000 ] || [ ! -f ../../mls_english/ASR/data/lang/transcript.txt ]; then
+ log "Abort! Please ensure ../../mls_english/ASR/data/lang/bpe_2000 and ../../mls_english/ASR/data/lang/transcript.txt exist."
+ log "Please run ../../mls_english/ASR/prepare.sh --stage 3 --stop-stage 3 if you haven't already."
exit 1
fi
- cd data/
- # if [ ! -d ./lang_char ]; then
- # ln -svf $(realpath ../../../reazonspeech/ASR/data/lang_char) .
- # fi
- if [ ! -d ./lang_bpe_500 ]; then
- ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) .
- fi
- cd ../
+ # Create the target lang directory if it doesn't exist
+ mkdir -p $lang_dir
+
+ # Combine Japanese char-level text and English BPE transcript
+ cat data/lang_char/text ../../mls_english/ASR/data/lang/transcript.txt \
+ > $lang_dir/text
for vocab_size in ${vocab_sizes[@]}; do
- lang_dir=data/lang_bbpe_${vocab_size}
- mkdir -p $lang_dir
+ bbpe_dir=$lang_dir/bbpe_${vocab_size}
+ mkdir -p $bbpe_dir
- cat data/lang_char/text data/lang_bpe_500/transcript_words.txt \
- > $lang_dir/text
-
- if [ ! -f $lang_dir/transcript_chars.txt ]; then
+ if [ ! -f $bbpe_dir/transcript_chars.txt ]; then
./local/prepare_for_bpe_model.py \
- --lang-dir ./$lang_dir \
+ --lang-dir $bbpe_dir \
--text $lang_dir/text
fi
- if [ ! -f $lang_dir/text_words_segmentation ]; then
+ if [ ! -f $bbpe_dir/text_words_segmentation ]; then
python3 ./local/text2segments.py \
--input-file ./data/lang_char/text \
- --output-file $lang_dir/text_words_segmentation
-
- cat ./data/lang_bpe_500/transcript_words.txt \
- >> $lang_dir/text_words_segmentation
+ --output-file $bbpe_dir/text_words_segmentation
+ cat ../../mls_english/ASR/data/lang/transcript.txt \
+ >> $bbpe_dir/text_words_segmentation
fi
- cat $lang_dir/text_words_segmentation | sed 's/ /\n/g' \
- | sort -u | sed '/^$/d' | uniq > $lang_dir/words_no_ids.txt
+ if [ ! -f $bbpe_dir/words_no_ids.txt ]; then
+ cat $bbpe_dir/text_words_segmentation | sed 's/ /\n/g' \
+ | sort -u | sed '/^$/d' | uniq > $bbpe_dir/words_no_ids.txt
+ fi
- if [ ! -f $lang_dir/words.txt ]; then
+ if [ ! -f $bbpe_dir/words.txt ]; then
python3 ./local/prepare_words.py \
- --input-file $lang_dir/words_no_ids.txt \
- --output-file $lang_dir/words.txt
+ --input-file $bbpe_dir/words_no_ids.txt \
+ --output-file $bbpe_dir/words.txt
fi
- if [ ! -f $lang_dir/bbpe.model ]; then
+ if [ ! -f $bbpe_dir/bbpe.model ]; then
./local/train_bbpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
- --transcript $lang_dir/text
+ --transcript $lang_dir/text \
+ --output-model $bbpe_dir/bbpe.model \
+ --input-sentence-size 2000000 # Example: limit to 2 million sentences
fi
- if [ ! -f $lang_dir/L_disambig.pt ]; then
- ./local/prepare_lang_bbpe.py --lang-dir $lang_dir
+ if [ ! -f $bbpe_dir/L_disambig.pt ]; then
+ ./local/prepare_lang_bbpe.py --lang-dir $bbpe_dir --vocab-size $vocab_size
- log "Validating $lang_dir/lexicon.txt"
- ln -svf $(realpath ../../multi_zh_en/ASR/local/validate_bpe_lexicon.py) local/
+ log "Validating $bbpe_dir/lexicon.txt"
+ ln -svfr $(realpath ../../multi_zh_en/ASR/local/validate_bpe_lexicon.py) local/
./local/validate_bpe_lexicon.py \
- --lexicon $lang_dir/lexicon.txt \
- --bpe-model $lang_dir/bbpe.model
+ --lexicon $bbpe_dir/lexicon.txt \
+ --bpe-model $bbpe_dir/bbpe.model
fi
+
+ # Remove top-level files (if they were created)
+ rm -f $lang_dir/lexicon.txt $lang_dir/L_disambig.pt
done
+
+ # Optional symlink
+ if [ -d $lang_dir/bbpe_2000 ] && [ ! -e $lang_dir/bpe_2000 ]; then
+ ln -sr $lang_dir/bbpe_2000 $lang_dir/bpe_2000
+ fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Update cutset paths"
+ python local/utils/update_cutset_paths.py
fi
log "prepare.sh: PREPARATION DONE"
diff --git a/egs/multi_ja_en/ASR/zipformer/decode.py b/egs/multi_ja_en/ASR/zipformer/decode.py
index 9acccfcf7..b1fd44493 100755
--- a/egs/multi_ja_en/ASR/zipformer/decode.py
+++ b/egs/multi_ja_en/ASR/zipformer/decode.py
@@ -68,7 +68,7 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
-from asr_datamodule import ReazonSpeechAsrDataModule
+from asr_datamodule import MultiDatasetAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
@@ -157,14 +157,14 @@ def get_parser():
parser.add_argument(
"--bpe-model",
type=str,
- default="data/lang_bbpe_2000/bbpe.model",
+ default="data/lang/bbpe_2000/bbpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
- default="data/lang_bbpe_2000",
+ default="data/lang/bbpe_2000",
help="The lang dir containing word table and LG graph",
)
@@ -573,7 +573,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
- ReazonSpeechAsrDataModule.add_arguments(parser)
+ MultiDatasetAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@@ -748,7 +748,7 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
- data_module = ReazonSpeechAsrDataModule(args)
+ multidataset_datamodule = MultiDatasetAsrDataModule(args)
multi_dataset = MultiDataset(args)
def remove_short_utt(c: Cut):
@@ -759,31 +759,42 @@ def main():
)
return T > 0
- test_sets_cuts = multi_dataset.test_cuts()
+ def tokenize_and_encode_text(c: Cut):
+ # Text normalize for each sample
+ text = c.supervisions[0].text
+ text = byte_encode(tokenize_by_ja_char(text))
+ c.supervisions[0].text = text
+ return c
- test_sets = test_sets_cuts.keys()
- test_dl = [
- data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
- for cuts_name in test_sets
- ]
+ test_cuts = multi_dataset.test_cuts()
+ test_cuts = test_cuts.filter(remove_short_utt)
+ # test_cuts = test_cuts.map(tokenize_and_encode_text)
- for test_set, test_dl in zip(test_sets, test_dl):
- logging.info(f"Start decoding test set: {test_set}")
+ test_dl = multidataset_datamodule.test_dataloaders(test_cuts)
- results_dict = decode_dataset(
- dl=test_dl,
- params=params,
- model=model,
- sp=sp,
- word_table=word_table,
- decoding_graph=decoding_graph,
- )
+ # test_sets = test_sets_cuts.keys()
+ # test_dl = [
+ # data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
+ # for cuts_name in test_sets
+ # ]
- save_results(
- params=params,
- test_set_name=test_set,
- results_dict=results_dict,
- )
+ # for test_set, test_dl in zip(test_sets, test_dl):
+ logging.info("Start decoding test set") #: {test_set}")
+
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name="test_set",
+ results_dict=results_dict,
+ )
logging.info("Done!")
diff --git a/egs/multi_ja_en/ASR/zipformer/do_not_use_it_directly.py b/egs/multi_ja_en/ASR/zipformer/do_not_use_it_directly.py
index 072679cfc..32e6380eb 100755
--- a/egs/multi_ja_en/ASR/zipformer/do_not_use_it_directly.py
+++ b/egs/multi_ja_en/ASR/zipformer/do_not_use_it_directly.py
@@ -57,7 +57,7 @@ import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
-from asr_datamodule import ReazonSpeechAsrDataModule
+from asr_datamodule import MultiDatasetAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@@ -1085,8 +1085,8 @@ def run(rank, world_size, args):
return True
- reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
- train_cuts = reazonspeech_corpus.train_cuts()
+ multidataset_datamodule = MultiDatasetAsrDataModule(args)
+ train_cuts = multidataset_datamodule.train_cuts()
train_cuts = train_cuts.filter(remove_short_and_long_utt)
@@ -1097,12 +1097,12 @@ def run(rank, world_size, args):
else:
sampler_state_dict = None
- train_dl = reazonspeech_corpus.train_dataloaders(
+ train_dl = multidataset_datamodule.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
- valid_cuts = reazonspeech_corpus.valid_cuts()
- valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts)
+ valid_cuts = multidataset_datamodule.valid_cuts()
+ valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts)
if params.start_batch <= 0 and not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
@@ -1242,7 +1242,7 @@ def scan_pessimistic_batches_for_oom(
def main():
raise RuntimeError("Please don't use this file directly!")
parser = get_parser()
- ReazonSpeechAsrDataModule.add_arguments(parser)
+ MultiDatasetAsrDataModule.add_arguments(parser)
Tokenizer.add_arguments(parser)
args = parser.parse_args()
diff --git a/egs/multi_ja_en/ASR/zipformer/multi_dataset.py b/egs/multi_ja_en/ASR/zipformer/multi_dataset.py
index b0cdc1f6a..eb1bd5fae 100644
--- a/egs/multi_ja_en/ASR/zipformer/multi_dataset.py
+++ b/egs/multi_ja_en/ASR/zipformer/multi_dataset.py
@@ -13,36 +13,36 @@ class MultiDataset:
Args:
manifest_dir:
It is expected to contain the following files:
- - reazonspeech_cuts_train.jsonl.gz
- - librispeech_cuts_train-clean-100.jsonl.gz
- - librispeech_cuts_train-clean-360.jsonl.gz
- - librispeech_cuts_train-other-500.jsonl.gz
+ - mls_english/
+ - mls_eng_cuts_train.jsonl.gz
+ - mls_eng_cuts_dev.jsonl.gz
+ - mls_eng_cuts_test.jsonl.gz
+ - reazonspeech/
+ - reazonspeech_cuts_train.jsonl.gz
+ - reazonspeech_cuts_dev.jsonl.gz
+ - reazonspeech_cuts_test.jsonl.gz
"""
- self.fbank_dir = Path(args.manifest_dir)
+ self.manifest_dir = Path(args.manifest_dir)
def train_cuts(self) -> CutSet:
logging.info("About to get multidataset train cuts")
- logging.info("Loading Reazonspeech in lazy mode")
- reazonspeech_cuts = load_manifest_lazy(
- self.fbank_dir / "reazonspeech_cuts_train.jsonl.gz"
+ logging.info("Loading Reazonspeech TRAIN set in lazy mode")
+ reazonspeech_train_cuts = load_manifest_lazy(
+ self.manifest_dir / "reazonspeech/reazonspeech_cuts_train.jsonl.gz"
)
- logging.info("Loading LibriSpeech in lazy mode")
- train_clean_100_cuts = self.train_clean_100_cuts()
- train_clean_360_cuts = self.train_clean_360_cuts()
- train_other_500_cuts = self.train_other_500_cuts()
+ logging.info("Loading MLS English TRAIN set in lazy mode")
+ mls_eng_train_cuts = load_manifest_lazy(
+ self.manifest_dir / "mls_english/mls_eng_cuts_train.jsonl.gz"
+ )
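+ # Mux the two corpora with weights equal to their sizes, so training examples
+ # are drawn from each dataset in proportion to its number of cuts.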
return CutSet.mux(
- reazonspeech_cuts,
- train_clean_100_cuts,
- train_clean_360_cuts,
- train_other_500_cuts,
+ reazonspeech_train_cuts,
+ mls_eng_train_cuts,
weights=[
- len(reazonspeech_cuts),
- len(train_clean_100_cuts),
- len(train_clean_360_cuts),
- len(train_other_500_cuts),
+ len(reazonspeech_train_cuts),
+ len(mls_eng_train_cuts),
],
)
@@ -51,93 +51,90 @@ class MultiDataset:
logging.info("Loading Reazonspeech DEV set in lazy mode")
reazonspeech_dev_cuts = load_manifest_lazy(
- self.fbank_dir / "reazonspeech_cuts_dev.jsonl.gz"
+ self.manifest_dir / "reazonspeech/reazonspeech_cuts_dev.jsonl.gz"
)
- logging.info("Loading LibriSpeech DEV set in lazy mode")
- dev_clean_cuts = self.dev_clean_cuts()
- dev_other_cuts = self.dev_other_cuts()
+ logging.info("Loading MLS English DEV set in lazy mode")
+ mls_eng_dev_cuts = load_manifest_lazy(
+ self.manifest_dir / "mls_english/mls_eng_cuts_dev.jsonl.gz"
+ )
return CutSet.mux(
reazonspeech_dev_cuts,
- dev_clean_cuts,
- dev_other_cuts,
+ mls_eng_dev_cuts,
weights=[
len(reazonspeech_dev_cuts),
- len(dev_clean_cuts),
- len(dev_other_cuts),
+ len(mls_eng_dev_cuts),
],
)
- def test_cuts(self) -> Dict[str, CutSet]:
+ def test_cuts(self) -> CutSet:
logging.info("About to get multidataset test cuts")
- logging.info("Loading Reazonspeech set in lazy mode")
+ logging.info("Loading Reazonspeech TEST set in lazy mode")
reazonspeech_test_cuts = load_manifest_lazy(
- self.fbank_dir / "reazonspeech_cuts_test.jsonl.gz"
- )
- reazonspeech_dev_cuts = load_manifest_lazy(
- self.fbank_dir / "reazonspeech_cuts_dev.jsonl.gz"
+ self.manifest_dir / "reazonspeech/reazonspeech_cuts_test.jsonl.gz"
)
- logging.info("Loading LibriSpeech set in lazy mode")
- test_clean_cuts = self.test_clean_cuts()
- test_other_cuts = self.test_other_cuts()
-
- test_cuts = {
- "reazonspeech_test": reazonspeech_test_cuts,
- "reazonspeech_dev": reazonspeech_dev_cuts,
- "librispeech_test_clean": test_clean_cuts,
- "librispeech_test_other": test_other_cuts,
- }
-
- return test_cuts
-
- @lru_cache()
- def train_clean_100_cuts(self) -> CutSet:
- logging.info("About to get train-clean-100 cuts")
- return load_manifest_lazy(
- self.fbank_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
+ logging.info("Loading MLS English TEST set in lazy mode")
+ mls_eng_test_cuts = load_manifest_lazy(
+ self.manifest_dir / "mls_english/mls_eng_cuts_test.jsonl.gz"
)
- @lru_cache()
- def train_clean_360_cuts(self) -> CutSet:
- logging.info("About to get train-clean-360 cuts")
- return load_manifest_lazy(
- self.fbank_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
+ return CutSet.mux(
+ reazonspeech_test_cuts,
+ mls_eng_test_cuts,
+ weights=[
+ len(reazonspeech_test_cuts),
+ len(mls_eng_test_cuts),
+ ],
)
- @lru_cache()
- def train_other_500_cuts(self) -> CutSet:
- logging.info("About to get train-other-500 cuts")
- return load_manifest_lazy(
- self.fbank_dir / "librispeech_cuts_train-other-500.jsonl.gz"
- )
+ # @lru_cache()
+ # def train_clean_100_cuts(self) -> CutSet:
+ # logging.info("About to get train-clean-100 cuts")
+ # return load_manifest_lazy(
+ # self.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
+ # )
- @lru_cache()
- def dev_clean_cuts(self) -> CutSet:
- logging.info("About to get dev-clean cuts")
- return load_manifest_lazy(
- self.fbank_dir / "librispeech_cuts_dev-clean.jsonl.gz"
- )
+ # @lru_cache()
+ # def train_clean_360_cuts(self) -> CutSet:
+ # logging.info("About to get train-clean-360 cuts")
+ # return load_manifest_lazy(
+ # self.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
+ # )
- @lru_cache()
- def dev_other_cuts(self) -> CutSet:
- logging.info("About to get dev-other cuts")
- return load_manifest_lazy(
- self.fbank_dir / "librispeech_cuts_dev-other.jsonl.gz"
- )
+ # @lru_cache()
+ # def train_other_500_cuts(self) -> CutSet:
+ # logging.info("About to get train-other-500 cuts")
+ # return load_manifest_lazy(
+ # self.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
+ # )
- @lru_cache()
- def test_clean_cuts(self) -> CutSet:
- logging.info("About to get test-clean cuts")
- return load_manifest_lazy(
- self.fbank_dir / "librispeech_cuts_test-clean.jsonl.gz"
- )
+ # @lru_cache()
+ # def dev_clean_cuts(self) -> CutSet:
+ # logging.info("About to get dev-clean cuts")
+ # return load_manifest_lazy(
+ # self.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
+ # )
- @lru_cache()
- def test_other_cuts(self) -> CutSet:
- logging.info("About to get test-other cuts")
- return load_manifest_lazy(
- self.fbank_dir / "librispeech_cuts_test-other.jsonl.gz"
- )
+ # @lru_cache()
+ # def dev_other_cuts(self) -> CutSet:
+ # logging.info("About to get dev-other cuts")
+ # return load_manifest_lazy(
+ # self.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
+ # )
+
+ # @lru_cache()
+ # def test_clean_cuts(self) -> CutSet:
+ # logging.info("About to get test-clean cuts")
+ # return load_manifest_lazy(
+ # self.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
+ # )
+
+ # @lru_cache()
+ # def test_other_cuts(self) -> CutSet:
+ # logging.info("About to get test-other cuts")
+ # return load_manifest_lazy(
+ # self.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
+ # )
diff --git a/egs/multi_ja_en/ASR/zipformer/streaming_decode.py b/egs/multi_ja_en/ASR/zipformer/streaming_decode.py
index 935f86de1..e1869d784 100755
--- a/egs/multi_ja_en/ASR/zipformer/streaming_decode.py
+++ b/egs/multi_ja_en/ASR/zipformer/streaming_decode.py
@@ -63,7 +63,7 @@ import k2
import numpy as np
import sentencepiece as spm
import torch
-from asr_datamodule import ReazonSpeechAsrDataModule
+from asr_datamodule import MultiDatasetAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
@@ -740,7 +740,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
- ReazonSpeechAsrDataModule.add_arguments(parser)
+ MultiDatasetAsrDataModule.add_arguments(parser)
Tokenizer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@@ -887,7 +887,7 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
- reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
+ multidataset_datamodule = MultiDatasetAsrDataModule(args)
if params.bilingual:
multi_dataset = MultiDataset(args)
@@ -904,8 +904,8 @@ def main():
test_sets = test_sets_cuts.keys()
test_cuts = [test_sets_cuts[k] for k in test_sets]
- valid_cuts = reazonspeech_corpus.valid_cuts()
- test_cuts = reazonspeech_corpus.test_cuts()
+ valid_cuts = multidataset_datamodule.valid_cuts()
+ test_cuts = multidataset_datamodule.test_cuts()
test_sets = ["valid", "test"]
test_cuts = [valid_cuts, test_cuts]
diff --git a/egs/multi_ja_en/ASR/zipformer/train.py b/egs/multi_ja_en/ASR/zipformer/train.py
index bfb037f50..1c14b4aa4 100755
--- a/egs/multi_ja_en/ASR/zipformer/train.py
+++ b/egs/multi_ja_en/ASR/zipformer/train.py
@@ -25,7 +25,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
# For non-streaming model training:
./zipformer/train.py \
- --bilingual 1 \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
@@ -35,7 +34,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
# For streaming model training:
./zipformer/train.py \
- --bilingual 1 \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
@@ -50,6 +48,7 @@ It supports training with:
- transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
"""
+
import argparse
import copy
import logging
@@ -66,7 +65,7 @@ import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
-from asr_datamodule import ReazonSpeechAsrDataModule
+from asr_datamodule import MultiDatasetAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@@ -77,7 +76,6 @@ from multi_dataset import MultiDataset
from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
-from tokenizer import Tokenizer
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -269,13 +267,6 @@ def get_parser():
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
- parser.add_argument(
- "--bilingual",
- type=str2bool,
- default=False,
- help="Whether the model is bilingual or not. 1 = bilingual.",
- )
-
parser.add_argument(
"--world-size",
type=int,
@@ -333,11 +324,10 @@ def get_parser():
""",
)
- # changed - not used in monolingual streaming
parser.add_argument(
"--bpe-model",
type=str,
- default="data/lang_bbpe_2000/bbpe.model",
+ default="data/lang/bbpe_2000/bbpe.model",
help="Path to the BPE model",
)
@@ -763,11 +753,9 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename)
-# fix implementation for sentencepiece_processor: spm.SentencePieceProcessor, stuff
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
- tokenizer: Tokenizer,
sentencepiece_processor: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
@@ -803,10 +791,7 @@ def compute_loss(
warm_step = params.warm_step
texts = batch["supervisions"]["text"]
- if not params.bilingual:
- y = tokenizer.encode(texts, out_type=int)
- else:
- y = sentencepiece_processor.encode(texts, out_type=int)
+ y = sentencepiece_processor.encode(texts, out_type=int)
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
@@ -862,7 +847,6 @@ def compute_loss(
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
- tokenizer: Tokenizer,
sentencepiece_processor: spm.SentencePieceProcessor,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
@@ -876,7 +860,6 @@ def compute_validation_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
- tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
batch=batch,
is_training=False,
@@ -900,7 +883,6 @@ def train_one_epoch(
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
- tokenizer: Tokenizer,
sentencepiece_processor: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
@@ -972,7 +954,6 @@ def train_one_epoch(
loss, loss_info = compute_loss(
params=params,
model=model,
- tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
batch=batch,
is_training=True,
@@ -993,7 +974,6 @@ def train_one_epoch(
display_and_save_batch(
batch,
params=params,
- tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
)
raise
@@ -1082,7 +1062,6 @@ def train_one_epoch(
valid_info = compute_validation_loss(
params=params,
model=model,
- tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
valid_dl=valid_dl,
world_size=world_size,
@@ -1136,25 +1115,12 @@ def run(rank, world_size, args):
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
- # Use lang_dir for further operations
- # tokenizer = Tokenizer.load(args.lang, args.lang_type)
-
- # sentencepiece_processor = spm.SentencePieceProcessor()
- # sentencepiece_processor.load(params.bpe_model)
- tokenizer = None
- sentencepiece_processor = None
+ sentencepiece_processor = spm.SentencePieceProcessor()
+ sentencepiece_processor.load(params.bpe_model)
    # <blk> is defined in local/prepare_lang_char.py
-
- if not params.bilingual:
- tokenizer = Tokenizer.load(args.lang, args.lang_type)
-        params.blank_id = tokenizer.piece_to_id("<blk>")
- params.vocab_size = tokenizer.get_piece_size()
- else:
- sentencepiece_processor = spm.SentencePieceProcessor()
- sentencepiece_processor.load(params.bpe_model)
-        params.blank_id = sentencepiece_processor.piece_to_id("<blk>")
- params.vocab_size = sentencepiece_processor.get_piece_size()
+    params.blank_id = sentencepiece_processor.piece_to_id("<blk>")
+ params.vocab_size = sentencepiece_processor.get_piece_size()
if not params.use_transducer:
params.ctc_loss_scale = 1.0
@@ -1212,27 +1178,24 @@ def run(rank, world_size, args):
if params.inf_check:
register_inf_check_hooks(model)
- reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
- if params.bilingual:
- multi_dataset = MultiDataset(args)
- train_cuts = multi_dataset.train_cuts()
- else:
- train_cuts = reazonspeech_corpus.train_cuts()
+ multidataset_datamodule = MultiDatasetAsrDataModule(args)
+
+ multi_dataset = MultiDataset(args)
+
+ train_cuts = multi_dataset.train_cuts()
def remove_short_and_long_utt(c: Cut):
- # Keep only utterances with duration between 1 second and 20 seconds
- #
- # Caution: There is a reason to select 20.0 here. Please see
- # ../local/display_manifest_statistics.py
+
+        # Keep only utterances longer than 1 second
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
- # the threshold
- # if c.duration < 1.0 or c.duration > 30.0:
- # logging.warning(
- # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
- # )
- # return False
+        # the threshold, as it depends on which datasets you choose
+ if c.duration < 1.0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ )
+ return False
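+        # NOTE: no upper duration bound is applied here; add one if your data
+        # contains very long cuts (see ../local/display_manifest_statistics.py).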
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
@@ -1240,18 +1203,13 @@ def run(rank, world_size, args):
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
- T = ((c.num_samples - 7) // 2 + 1) // 2
- if not params.bilingual:
- tokens = tokenizer.encode(c.supervisions[0].text, out_type=str)
- else:
- tokens = sentencepiece_processor.encode(
- c.supervisions[0].text, out_type=str
- )
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ tokens = sentencepiece_processor.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
- f"Number of frames (before subsampling): {c.num_samples}. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
@@ -1270,8 +1228,7 @@ def run(rank, world_size, args):
train_cuts = train_cuts.filter(remove_short_and_long_utt)
- if params.bilingual:
- train_cuts = train_cuts.map(tokenize_and_encode_text)
+ train_cuts = train_cuts.map(tokenize_and_encode_text)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
@@ -1280,22 +1237,19 @@ def run(rank, world_size, args):
else:
sampler_state_dict = None
- train_dl = reazonspeech_corpus.train_dataloaders(
+ train_dl = multidataset_datamodule.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
- if params.bilingual:
- valid_cuts = reazonspeech_corpus.valid_cuts()
- else:
- valid_cuts = multi_dataset.dev_cuts()
- valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts)
+ valid_cuts = multi_dataset.dev_cuts()
+
+ valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
- tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
params=params,
)
@@ -1321,7 +1275,6 @@ def run(rank, world_size, args):
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
- tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
train_dl=train_dl,
valid_dl=valid_dl,
@@ -1356,7 +1309,6 @@ def run(rank, world_size, args):
def display_and_save_batch(
batch: dict,
params: AttributeDict,
- tokenizer: Tokenizer,
sentencepiece_processor: spm.SentencePieceProcessor,
) -> None:
"""Display the batch statistics and save the batch into disk.
@@ -1367,10 +1319,8 @@ def display_and_save_batch(
for the content in it.
params:
Parameters for training. See :func:`get_params`.
- tokenizer:
- The BPE Tokenizer model.
sentencepiece_processor:
- The BPE SentencePieceProcessor model.
+            The SentencePiece BPE model.
"""
from lhotse.utils import uuid4
@@ -1382,11 +1332,7 @@ def display_and_save_batch(
features = batch["inputs"]
logging.info(f"features shape: {features.shape}")
-
- if params.bilingual:
- y = sentencepiece_processor.encode(supervisions["text"], out_type=int)
- else:
- y = tokenizer.encode(supervisions["text"], out_type=int)
+ y = sentencepiece_processor.encode(supervisions["text"], out_type=int)
num_tokens = sum(len(i) for i in y)
logging.info(f"num tokens: {num_tokens}")
@@ -1395,7 +1341,6 @@ def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
- tokenizer: Tokenizer,
sentencepiece_processor: spm.SentencePieceProcessor,
params: AttributeDict,
):
@@ -1412,7 +1357,6 @@ def scan_pessimistic_batches_for_oom(
loss, _ = compute_loss(
params=params,
model=model,
- tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
batch=batch,
is_training=True,
@@ -1431,7 +1375,6 @@ def scan_pessimistic_batches_for_oom(
display_and_save_batch(
batch,
params=params,
- tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
)
raise
@@ -1442,8 +1385,7 @@ def scan_pessimistic_batches_for_oom(
def main():
parser = get_parser()
- ReazonSpeechAsrDataModule.add_arguments(parser)
- Tokenizer.add_arguments(parser)
+ MultiDatasetAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
diff --git a/egs/reazonspeech/ASR/local/compute_fbank_musan.py b/egs/reazonspeech/ASR/local/compute_fbank_musan.py
index ac9d80720..7bd4878ae 100755
--- a/egs/reazonspeech/ASR/local/compute_fbank_musan.py
+++ b/egs/reazonspeech/ASR/local/compute_fbank_musan.py
@@ -94,12 +94,14 @@ def compute_fbank_musan(
logging.info("Extracting features for Musan")
if whisper_fbank:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ if device == "cpu":
+ logging.warning("CUDA not available; using WhisperFbank on CPU.")
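+            # Falling back to CPU keeps feature extraction working, just much slower.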
extractor = WhisperFbank(
- WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+ WhisperFbankConfig(num_filters=num_mel_bins, device=device)
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
-
with get_executor() as ex: # Initialize the executor only once.
# create chunks of Musan with duration 5 - 10 seconds
musan_cuts = (