Merge d74e2322e0d84f8d78f8594dbb6b8b4dc8b1b563 into 0904e490c5fb424dc5cb4d14ae468e4d32a07dc4

This commit is contained in:
Kinan Martin 2025-11-28 11:47:38 +08:00 committed by GitHub
commit 8beea2fbfb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
55 changed files with 6792 additions and 463 deletions

View File

@ -0,0 +1,19 @@
# Introduction
**Multilingual LibriSpeech (MLS)** is a large multilingual corpus suitable for speech research. The dataset is derived from read audiobooks from LibriVox and covers 8 languages: English, German, Dutch, Spanish, French, Italian, Portuguese, and Polish. It includes about 44.5K hours of English and a total of about 6K hours for the other languages. This icefall training recipe was created for the restructured version of the English split of the dataset hosted on Hugging Face (linked below).
For more details, please visit:
- Dataset: https://huggingface.co/datasets/parler-tts/mls_eng
- Original MLS dataset link: https://www.openslr.org/94
## On-the-fly feature computation
This recipe currently supports only on-the-fly feature (fbank) computation; `lhotse` manifests and feature banks are not pre-computed. In principle this means the dataset can be streamed from Hugging Face, but we have not tested this yet. We may add a version that pre-computes features to better match existing recipes.
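As a rough, untested sketch of what the on-the-fly path looks like (it mirrors `local/utils/asr_datamodule.py` in this recipe; passing `streaming=True` to `load_dataset` would be the streaming variant we have not verified):

```
from datasets import load_dataset
from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import K2SpeechRecognitionDataset
from lhotse.dataset.input_strategies import OnTheFlyFeatures

# Standard (non-streaming) load; streaming=True would stream from the Hub instead.
dataset = load_dataset("parler-tts/mls_eng")
cuts = CutSet.from_huggingface_dataset(dataset["train"], text_key="transcript")

# Features are computed per batch, so no precomputed fbank manifests are needed.
asr_dataset = K2SpeechRecognitionDataset(
    input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
)
```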
[./RESULTS.md](./RESULTS.md) contains the latest results. This MLS English recipe was primarily developed for use in the `multi_ja_en` Japanese-English bilingual pipeline, which is based on MLS English and ReazonSpeech.

View File

@ -0,0 +1,41 @@
## Results
### MLS English training results with the zipformer model
#### Non-streaming
**WER (%) on Test Set (Epoch 20)**
| Type | Greedy | Beam search |
|---------------|--------|-------------|
| Non-streaming | 6.65 | 6.57 |
The training command:
```
./zipformer/train.py \
--world-size 8 \
--num-epochs 20 \
--start-epoch 9 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--lang-dir data/lang/bpe_2000/
```
The decoding command:
```
./zipformer/decode.py \
--epoch 20 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang/bpe_2000/ \
--decoding-method greedy_search
```
The pre-trained model is available at [reazon-research/mls-english](https://huggingface.co/reazon-research/mls-english).
Please note that this recipe was developed primarily as the source of English input in the bilingual Japanese-English recipe `multi_ja_en`, which uses ReazonSpeech and MLS English.
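One way to fetch that checkpoint locally is via `huggingface_hub` (a minimal sketch, assuming the package is installed):
```
from huggingface_hub import snapshot_download

# Downloads the model repository into the local Hugging Face cache and returns its path.
model_dir = snapshot_download(repo_id="reazon-research/mls-english")
print(model_dir)
```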

View File

@ -33,6 +33,7 @@ from lhotse import ( # See the following for why LilcomChunkyWriter is preferre
RecordingSet,
SupervisionSet,
)
from lhotse.utils import is_module_available
# fmt: on
@ -48,55 +49,54 @@ concat_params = {"gap": 1.0, "maxlen": 10.0}
def make_cutset_blueprints(
manifest_dir: Path,
mls_eng_hf_dataset_path: str = "parler-tts/mls_eng",
) -> List[Tuple[str, CutSet]]:
cut_sets = []
if not is_module_available("datasets"):
raise ImportError(
"To process the MLS English HF corpus, please install optional dependency: pip install datasets"
)
from datasets import load_dataset
print(f"{mls_eng_hf_dataset_path=}")
dataset = load_dataset(str(mls_eng_hf_dataset_path))
# Create test dataset
logging.info("Creating test cuts.")
cut_sets.append(
(
"test",
CutSet.from_manifests(
recordings=RecordingSet.from_file(
manifest_dir / "reazonspeech_recordings_test.jsonl.gz"
),
supervisions=SupervisionSet.from_file(
manifest_dir / "reazonspeech_supervisions_test.jsonl.gz"
),
),
CutSet.from_huggingface_dataset(dataset["test"], text_key="transcript"),
)
)
# Create dev dataset
logging.info("Creating dev cuts.")
cut_sets.append(
(
"dev",
CutSet.from_manifests(
recordings=RecordingSet.from_file(
manifest_dir / "reazonspeech_recordings_dev.jsonl.gz"
),
supervisions=SupervisionSet.from_file(
manifest_dir / "reazonspeech_supervisions_dev.jsonl.gz"
),
),
try:
cut_sets.append(
(
"dev",
CutSet.from_huggingface_dataset(dataset["dev"], text_key="transcript"),
)
)
except KeyError:
cut_sets.append(
(
"dev",
CutSet.from_huggingface_dataset(
dataset["validation"], text_key="transcript"
),
)
)
)
# Create train dataset
logging.info("Creating train cuts.")
cut_sets.append(
(
"train",
CutSet.from_manifests(
recordings=RecordingSet.from_file(
manifest_dir / "reazonspeech_recordings_train.jsonl.gz"
),
supervisions=SupervisionSet.from_file(
manifest_dir / "reazonspeech_supervisions_train.jsonl.gz"
),
),
CutSet.from_huggingface_dataset(dataset["train"], text_key="transcript"),
)
)
return cut_sets
@ -107,6 +107,8 @@ def get_args():
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("-m", "--manifest-dir", type=Path)
parser.add_argument("-a", "--audio-dir", type=Path)
parser.add_argument("-d", "--dl-dir", type=Path)
return parser.parse_args()
@ -120,26 +122,33 @@ def main():
logging.basicConfig(format=formatter, level=logging.INFO)
if (args.manifest_dir / ".reazonspeech-fbank.done").exists():
if (args.manifest_dir / ".mls-eng-fbank.done").exists():
logging.info(
"Previous fbank computed for ReazonSpeech found. "
f"Delete {args.manifest_dir / '.reazonspeech-fbank.done'} to allow recomputing fbank."
"Previous fbank computed for MLS English found. "
f"Delete {args.manifest_dir / '.mls-eng-fbank.done'} to allow recomputing fbank."
)
return
else:
cut_sets = make_cutset_blueprints(args.manifest_dir)
mls_eng_hf_dataset_path = args.dl_dir # "/root/datasets/parler-tts--mls_eng"
cut_sets = make_cutset_blueprints(mls_eng_hf_dataset_path)
for part, cut_set in cut_sets:
logging.info(f"Processing {part}")
cut_set = cut_set.save_audios(
num_jobs=num_jobs,
storage_path=(args.audio_dir / part).as_posix(),
) # makes new cutset that loads audio from paths to actual audio files
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
num_jobs=num_jobs,
storage_path=(args.manifest_dir / f"feats_{part}").as_posix(),
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(args.manifest_dir / f"reazonspeech_cuts_{part}.jsonl.gz")
logging.info("All fbank computed for ReazonSpeech.")
(args.manifest_dir / ".reazonspeech-fbank.done").touch()
cut_set.to_file(args.manifest_dir / f"mls_eng_cuts_{part}.jsonl.gz")
logging.info("All fbank computed for MLS English.")
(args.manifest_dir / ".mls-eng-fbank.done").touch()
if __name__ == "__main__":

View File

@ -0,0 +1 @@
../../../librispeech/ASR/local/compute_fbank_musan.py

View File

@ -45,8 +45,8 @@ def get_parser():
def main():
args = get_parser()
for part in ["train", "dev"]:
path = args.manifest_dir / f"reazonspeech_cuts_{part}.jsonl.gz"
for part in ["dev", "test", "train"]:
path = args.manifest_dir / f"mls_eng_cuts_{part}.jsonl.gz"
cuts: CutSet = load_manifest(path)
print("\n---------------------------------\n")

View File

@ -0,0 +1,114 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2024 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# You can install sentencepiece via:
#
# pip install sentencepiece
#
# Due to an issue reported in
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
#
# Please install a version >=0.1.96
import argparse
import shutil
from pathlib import Path
import sentencepiece as spm
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
The generated bpe.model is saved to this directory.
""",
)
parser.add_argument(
"--byte-fallback",
action="store_true",
help="""Whether to enable byte_fallback when training bpe.""",
)
parser.add_argument(
"--character-coverage",
type=float,
default=1.0,
help="Character coverage in vocabulary.",
)
parser.add_argument(
"--transcript",
type=str,
help="Training transcript.",
)
parser.add_argument(
"--vocab-size",
type=int,
help="Vocabulary size for BPE training",
)
return parser.parse_args()
def main():
args = get_args()
vocab_size = args.vocab_size
lang_dir = Path(args.lang_dir)
model_type = "bpe"
model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
train_text = args.transcript
input_sentence_size = 100000000
user_defined_symbols = ["<blk>", "<sos/eos>"]
unk_id = len(user_defined_symbols)
# Note: unk_id is fixed to 2.
# If you change it, you should also change other
# places that are using it.
model_file = Path(model_prefix + ".model")
if not model_file.is_file():
spm.SentencePieceTrainer.train(
input=train_text,
vocab_size=vocab_size,
model_type=model_type,
model_prefix=model_prefix,
input_sentence_size=input_sentence_size,
character_coverage=args.character_coverage,
user_defined_symbols=user_defined_symbols,
byte_fallback=args.byte_fallback,
unk_id=unk_id,
bos_id=-1,
eos_id=-1,
)
else:
print(f"{model_file} exists - skipping")
return
shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
if __name__ == "__main__":
main()
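# Example invocation, matching how prepare.sh (stage 5) calls this script:
#   python local/train_bpe_model.py \
#     --lang-dir data/lang/bpe_2000 \
#     --vocab-size 2000 \
#     --transcript data/lang/transcript.txt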

View File

@ -0,0 +1,365 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class MLSEnglishHFAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/manifests"),
help="Path to directory with train/dev/test cuts.",
)
group.add_argument(
"--max-duration",
type=float,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=False,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=False,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
cuts_musan: Optional[CutSet] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if cuts_musan is not None:
logging.info("Enable MUSAN")
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10,20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "mls_eng_cuts_train.jsonl.gz"
)
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
return load_manifest_lazy(
self.args.manifest_dir / "mls_eng_cuts_dev.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "mls_eng_cuts_test.jsonl.gz"
)

View File

@ -0,0 +1,341 @@
import argparse
import glob
import os
import random
import re
import sys
from datasets import Audio, DatasetDict, load_dataset
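# Example invocation (the script filename below is illustrative; the flags match the
# argparse definitions at the bottom of this file):
#   python local/utils/make_mls_english_subset.py \
#     --full-dataset-path ./download/mls_english \
#     --output-base-dir ./download/subsets \
#     --target-train-hours 1000 --target-dev-hours 10 --target-test-hours 10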
def create_subset_by_hours(
full_dataset_path,
output_base_dir,
target_train_hours,
target_dev_hours, # New parameter
target_test_hours, # New parameter
random_seed=42,
duration_column_name="audio_duration",
):
random.seed(random_seed)
output_subset_dir = os.path.join(
output_base_dir,
f"mls_english_subset_train{int(target_train_hours)}h_dev{int(target_dev_hours)}h_test{int(target_test_hours)}h",
)
os.makedirs(output_subset_dir, exist_ok=True)
output_subset_data_dir = os.path.join(output_subset_dir, "data")
os.makedirs(output_subset_data_dir, exist_ok=True)
print(
f"Attempting to load full dataset from '{full_dataset_path}' using load_dataset..."
)
full_data_dir = os.path.join(full_dataset_path, "data")
if not os.path.isdir(full_data_dir):
print(
f"Error: Expected a 'data' subdirectory at '{full_data_dir}' containing parquet files. "
"Please ensure 'full_dataset_path' points to the root of your MLS English download "
"(e.g., /path/to/mls_english_downloaded_dir) where 'data' is a direct child.",
file=sys.stderr,
)
sys.exit(1)
all_parquet_files = glob.glob(os.path.join(full_data_dir, "*.parquet"))
if not all_parquet_files:
print(f"Error: No parquet files found in '{full_data_dir}'.", file=sys.stderr)
sys.exit(1)
data_files = {}
# Expanded pattern to also detect 'validation' if it's in filenames
split_pattern = re.compile(r"^(train|dev|test|validation)-\d{5}-of-\d{5}\.parquet$")
print(f" Discovering splits from filenames in '{full_data_dir}'...")
for fpath in all_parquet_files:
fname = os.path.basename(fpath)
match = split_pattern.match(fname)
if match:
split_name = match.group(1)
if split_name not in data_files:
data_files[split_name] = []
data_files[split_name].append(fpath)
else:
print(
f"Warning: Skipping unrecognized parquet file: {fname}", file=sys.stderr
)
if not data_files:
print(
"Error: No recognized train, dev, test, or validation parquet files found.",
file=sys.stderr,
)
sys.exit(1)
print(f"Found splits and their parquet files: {list(data_files.keys())}")
try:
full_dataset = load_dataset("parquet", data_files=data_files)
except Exception as e:
print(
f"Error loading dataset from '{full_data_dir}' with load_dataset: {e}",
file=sys.stderr,
)
sys.exit(1)
if not isinstance(full_dataset, DatasetDict):
print(
"Error: The loaded dataset is not a DatasetDict. Expected a DatasetDict structure.",
file=sys.stderr,
)
sys.exit(1)
# --- Renaming 'validation' split to 'dev' if necessary ---
if "validation" in full_dataset:
if "dev" in full_dataset:
print(
"Warning: Both 'dev' and 'validation' splits found in the original dataset. Keeping 'dev' and skipping rename of 'validation'.",
file=sys.stderr,
)
else:
print("Renaming 'validation' split to 'dev' for consistent keying.")
full_dataset["dev"] = full_dataset.pop("validation")
# --- End Renaming ---
subset_dataset = DatasetDict()
total_final_duration_ms = 0
def get_duration_from_column(example):
"""Helper to safely get duration from the specified column, in milliseconds."""
if duration_column_name in example:
return float(example[duration_column_name]) * 1000
else:
print(
f"Warning: Duration column '{duration_column_name}' not found in example. Returning 0.",
file=sys.stderr,
)
return 0
# --- NEW: Generalized sampling function ---
def sample_split_by_hours(split_name, original_split, target_hours):
"""
Samples a dataset split to reach a target number of hours.
Returns the sampled Dataset object and its actual duration in milliseconds.
"""
target_duration_ms = target_hours * 3600 * 1000
current_duration_ms = 0
indices_to_include = []
if original_split is None or len(original_split) == 0:
print(
f" Warning: Original '{split_name}' split is empty or not found. Cannot sample.",
file=sys.stderr,
)
return None, 0
print(
f"\n Processing '{split_name}' split to reach approximately {target_hours} hours..."
)
print(
f" Total samples in original '{split_name}' split: {len(original_split)}"
)
all_original_indices = list(range(len(original_split)))
random.shuffle(all_original_indices) # Shuffle indices for random sampling
num_samples_processed = 0
for original_idx in all_original_indices:
if current_duration_ms >= target_duration_ms and target_hours > 0:
print(
f" Target {split_name} hours reached ({target_hours}h). Stopping processing."
)
break
example = original_split[original_idx]
duration_ms = get_duration_from_column(example)
if duration_ms > 0:
indices_to_include.append(original_idx)
current_duration_ms += duration_ms
num_samples_processed += 1
if num_samples_processed % 10000 == 0: # Print progress periodically
print(
f" Processed {num_samples_processed} samples for '{split_name}'. Current duration: {current_duration_ms / (3600*1000):.2f} hours"
)
# If target_hours was 0, but there were samples, we should include none.
# Otherwise, select the chosen indices.
if target_hours == 0:
sampled_split = original_split.select([]) # Select an empty dataset
else:
sampled_split = original_split.select(
sorted(indices_to_include)
) # Sort to preserve order
# Ensure the 'audio' column is correctly typed as Audio feature before saving
if "audio" in sampled_split.features and not isinstance(
sampled_split.features["audio"], Audio
):
sampling_rate = (
sampled_split.features["audio"].sampling_rate
if isinstance(sampled_split.features["audio"], Audio)
else 16000
)
new_features = sampled_split.features
new_features["audio"] = Audio(sampling_rate=sampling_rate)
sampled_split = sampled_split.cast(new_features)
print(
f" Final '{split_name}' split duration: {current_duration_ms / (3600*1000):.2f} hours ({len(sampled_split)} samples)"
)
return sampled_split, current_duration_ms
# --- END NEW: Generalized sampling function ---
# --- Apply sampling for train, dev, and test splits ---
splits_to_process = {
"train": target_train_hours,
"dev": target_dev_hours,
"test": target_test_hours,
}
for split_name, target_hours in splits_to_process.items():
if split_name in full_dataset:
original_split = full_dataset[split_name]
sampled_split, actual_duration_ms = sample_split_by_hours(
split_name, original_split, target_hours
)
if sampled_split is not None:
subset_dataset[split_name] = sampled_split
total_final_duration_ms += actual_duration_ms
else:
print(
f"Warning: '{split_name}' split not found in original dataset. Skipping sampling.",
file=sys.stderr,
)
# --- Handle other splits if any, just copy them ---
# This loop now excludes 'validation' since it's handled by renaming to 'dev'
for split_name in full_dataset.keys():
if split_name not in [
"train",
"dev",
"test",
"validation",
]: # Ensure 'validation' is not re-copied if not renamed
print(f"Copying unrecognized split '{split_name}' directly.")
other_split = full_dataset[split_name]
subset_dataset[split_name] = other_split
other_duration_ms = sum(get_duration_from_column(ex) for ex in other_split)
total_final_duration_ms += other_duration_ms
print(
f" Copied '{split_name}' split: {len(other_split)} samples ({other_duration_ms / (3600*1000):.2f} hours)"
)
final_total_hours = total_final_duration_ms / (3600 * 1000)
print(
f"\nOverall subset dataset duration (train + dev + test + others): {final_total_hours:.2f} hours"
)
print(
f"Saving subset dataset to '{output_subset_dir}' in Parquet format, matching original 'data' structure..."
)
try:
for split_name, ds_split in subset_dataset.items():
ds_split.to_parquet(
os.path.join(output_subset_data_dir, f"{split_name}.parquet")
)
print(f" Saved split '{split_name}' to '{output_subset_data_dir}'")
print(f"Successfully created and saved subset dataset to '{output_subset_dir}'")
except Exception as e:
print(
f"Error saving subset dataset to '{output_subset_dir}': {e}",
file=sys.stderr,
)
sys.exit(1)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Create a smaller subset of a downloaded Hugging Face audio dataset. "
"Samples train, dev, and test splits to target durations using pre-existing duration column. "
"Ensures 'validation' split is renamed to 'dev'."
)
parser.add_argument(
"--full-dataset-path",
type=str,
required=True,
help="The local path to the already downloaded Hugging Face dataset. "
"This should be the root directory containing the 'data' subdirectory "
"(e.g., /path/to/mls_english_download).",
)
parser.add_argument(
"--output-base-dir",
type=str,
required=True,
help="The base directory where the new subset dataset(s) will be saved. "
"A subdirectory 'mls_english_subset_trainXh_devYh_testZh' will be created within it.",
)
parser.add_argument(
"--target-train-hours",
type=float,
required=True,
help="The approximate total duration of the 'train' split in hours (e.g., 1000 for 1000 hours).",
)
parser.add_argument(
"--target-dev-hours",
type=float,
default=0.0,
help="The approximate total duration of the 'dev' split in hours (e.g., 10 for 10 hours). Set to 0 to exclude this split.",
)
parser.add_argument(
"--target-test-hours",
type=float,
default=0.0,
help="The approximate total duration of the 'test' split in hours (e.g., 10 for 10 hours). Set to 0 to exclude this split.",
)
parser.add_argument(
"--random-seed",
type=int,
default=42,
help="Seed for random number generation to ensure reproducibility (default: 42).",
)
parser.add_argument(
"--duration-column-name",
type=str,
default="audio_duration",
help="The name of the column in the dataset that contains the audio duration (assumed to be in seconds). Default: 'audio_duration'.",
)
args = parser.parse_args()
create_subset_by_hours(
args.full_dataset_path,
args.output_base_dir,
args.target_train_hours,
args.target_dev_hours,
args.target_test_hours,
args.random_seed,
args.duration_column_name,
)
# Simplified load path message for clarity
output_subset_full_path_name = f"mls_english_subset_train{int(args.target_train_hours)}h_dev{int(args.target_dev_hours)}h_test{int(args.target_test_hours)}h"
output_subset_data_path = os.path.join(
args.output_base_dir, output_subset_full_path_name, "data"
)
print(f"\nTo use your new subset dataset, you can load it like this:")
print(f"from datasets import load_dataset")
print(f"import os, glob")
print(f"data_files = {{}}")
print(
f"for split_name in ['train', 'dev', 'test']: # Or iterate through actual splits created"
)
print(
f" split_path = os.path.join('{output_subset_data_path}', f'{{split_name}}*.parquet')"
)
print(f" files = glob.glob(split_path)")
print(f" if files: data_files[split_name] = files")
print(f"subset = load_dataset('parquet', data_files=data_files)")
print(f"print(subset)")

View File

@ -0,0 +1,48 @@
import argparse
import os
import sys
from huggingface_hub import snapshot_download
def download_dataset(dl_dir):
"""
Downloads the MLS English dataset from Hugging Face to `$dl_dir/mls_english`.
"""
repo_id = "parler-tts/mls_eng"
local_dataset_dir = os.path.join(dl_dir, "mls_english")
print(f"Attempting to download '{repo_id}' to '{local_dataset_dir}'...")
# Ensure the parent directory exists
os.makedirs(dl_dir, exist_ok=True)
try:
# snapshot_download handles LFS and large files robustly
# local_dir_use_symlinks=False is generally safer for datasets,
# especially on network file systems or if you intend to move the data
snapshot_download(
repo_id=repo_id,
repo_type="dataset",
local_dir=local_dataset_dir,
local_dir_use_symlinks=False,
)
print(f"Successfully downloaded '{repo_id}' to '{local_dataset_dir}'")
except Exception as e:
print(f"Error downloading dataset '{repo_id}': {e}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Download MLS English dataset from Hugging Face."
)
parser.add_argument(
"--dl-dir",
type=str,
required=True,
help="The base directory where the 'mls_english' dataset will be downloaded.",
)
args = parser.parse_args()
download_dataset(args.dl_dir)

View File

@ -0,0 +1,91 @@
#!/usr/bin/env python3
# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from typing import Optional
from lhotse import CutSet
from tqdm import tqdm
def get_args():
parser = argparse.ArgumentParser(
description="Generate transcripts for BPE training from MLS English dataset",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--dataset-path",
type=str,
default="parler-tts/mls_eng",
help="Path to HuggingFace MLS English dataset (name or local path)",
)
parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang"),
help="Directory to store output transcripts",
)
parser.add_argument(
"--split",
type=str,
default="train",
help="Dataset split to use for generating transcripts (train/dev/test)",
)
return parser.parse_args()
def generate_transcript_from_cuts(cuts: CutSet, output_file: Path) -> None:
"""Generate transcript text file from Lhotse CutSet."""
with open(output_file, "w") as f:
for cut in tqdm(cuts, desc="Processing cuts"):
for sup in cut.supervisions:
f.write(f"{sup.text}\n")
def main():
args = get_args()
logging.basicConfig(
format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
level=logging.INFO,
)
args.lang_dir.mkdir(parents=True, exist_ok=True)
output_file = args.lang_dir / "transcript.txt"
logging.info(f"Loading {args.split} split from dataset: {args.dataset_path}")
try:
cuts = CutSet.from_huggingface_dataset(
args.dataset_path, split=args.split, text_key="transcript"
)
except Exception as e:
logging.error(f"Failed to load dataset: {e}")
raise
logging.info(f"Generating transcript to {output_file}")
generate_transcript_from_cuts(cuts, output_file)
logging.info("Transcript generation completed")
if __name__ == "__main__":
main()
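# Example invocation, matching prepare.sh (stage 4):
#   python local/utils/generate_transcript.py \
#     --dataset-path ./download/mls_english \
#     --lang-dir data/lang \
#     --split train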

143
egs/mls_english/ASR/prepare.sh Executable file
View File

@ -0,0 +1,143 @@
#!/usr/bin/env bash
# Prepare script for MLS English ASR recipe in icefall
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=-1
stop_stage=100
# Configuration for BPE tokenizer
vocab_sizes=(2000) # You can add more sizes like (500 1000 2000) for comparison
# Directory where dataset will be downloaded
dl_dir=$PWD/download
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
. shared/parse_options.sh || exit 1
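# parse_options.sh lets you run a subset of stages, e.g.:
#   ./prepare.sh --stage 0 --stop-stage 1   # download data and compute fbank only
#   ./prepare.sh --stage 4                  # resume from transcript generation onward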
# All files generated by this script are saved in "data".
mkdir -p data
mkdir -p data/audio
mkdir -p data/manifests
mkdir -p data/lang
log() {
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "Starting MLS English data preparation"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# Check if huggingface_hub is installed
if ! python -c "import huggingface_hub" &> /dev/null; then
log "huggingface_hub Python library not found. Installing it now..."
# Using --break-system-packages for Debian/Ubuntu environments where pip install might fail without it
python -m pip install huggingface_hub || \
python -m pip install huggingface_hub --break-system-packages || { \
log "Failed to install huggingface_hub. Please install it manually: pip install huggingface_hub"; \
exit 1; \
}
log "huggingface_hub installed successfully."
fi
# Check if the dataset already exists to avoid re-downloading
if [ ! -d "$dl_dir/mls_english" ]; then
log "Dataset not found at $dl_dir/mls_english. Starting download..."
if ! python ./local/utils/download_mls_english.py --dl-dir "$dl_dir"; then
log "Failed to download MLS English dataset via download_mls_english.py"
exit 1
fi
else
log "Dataset already exists at $dl_dir/mls_english. Skipping download."
fi
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan $dl_dir/
#
if [ ! -d $dl_dir/musan ] ; then
log "Downloading musan."
lhotse download musan $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Compute MLS English fbank"
if [ ! -e data/manifests/.mls_english-validated.done ]; then
python local/compute_fbank_mls_english.py \
--manifest-dir data/manifests \
--audio-dir data/audio \
--dl-dir $dl_dir/mls_english
# --dl-dir /root/datasets/parler-tts--mls_eng
python local/validate_manifest.py --manifest data/manifests/mls_eng_cuts_train.jsonl.gz
python local/validate_manifest.py --manifest data/manifests/mls_eng_cuts_dev.jsonl.gz
python local/validate_manifest.py --manifest data/manifests/mls_eng_cuts_test.jsonl.gz
touch data/manifests/.mls_english-validated.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare musan manifest"
# We assume that you have downloaded the musan corpus
# to $dl_dir/musan
if [ ! -e data/manifests/.musan_prep.done ]; then
lhotse prepare musan $dl_dir/musan data/manifests
touch data/manifests/.musan_prep.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute fbank for musan"
if [ ! -e data/manifests/.musan_fbank.done ]; then
./local/compute_fbank_musan.py
touch data/manifests/.musan_fbank.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Prepare transcript for BPE training"
if [ ! -f data/lang/transcript.txt ]; then
log "Generating transcripts for BPE training"
python local/utils/generate_transcript.py \
--dataset-path $dl_dir/mls_english \
--lang-dir data/lang \
--split train
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare BPE tokenizer"
for vocab_size in ${vocab_sizes[@]}; do
log "Training BPE model with vocab_size=${vocab_size}"
bpe_dir=data/lang/bpe_${vocab_size}
mkdir -p $bpe_dir
if [ ! -f $bpe_dir/bpe.model ]; then
python local/train_bpe_model.py \
--lang-dir $bpe_dir \
--vocab-size $vocab_size \
--transcript data/lang/transcript.txt
fi
done
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Show manifest statistics"
python local/display_manifest_statistics.py --manifest-dir data/manifests > data/manifests/manifest_statistics.txt
cat data/manifests/manifest_statistics.txt
fi
log "MLS English data preparation completed successfully"

1
egs/mls_english/ASR/shared Symbolic link
View File

@ -0,0 +1 @@
../../librispeech/ASR/shared

View File

@ -0,0 +1 @@
../local/utils/asr_datamodule.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/beam_search.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/ctc_decode.py

File diff suppressed because it is too large.

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decode_stream.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decoder.py

File diff suppressed because it is too large.

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/encoder_interface.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/generate_averaged_model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/my_profile.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_pretrained.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/pretrained.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling_converter.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/streaming_beam_search.py

View File

@ -0,0 +1,900 @@
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang,
# Fangjun Kuang,
# Zengwei Yao)
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./zipformer/streaming_decode.py \
  --epoch 28 \
  --avg 15 \
  --causal 1 \
  --chunk-size 32 \
  --left-context-frames 256 \
  --exp-dir ./zipformer/exp-large \
  --lang data/lang_char \
  --num-encoder-layers 2,2,4,5,4,2 \
  --feedforward-dim 512,768,1536,2048,1536,768 \
  --encoder-dim 192,256,512,768,512,256 \
  --encoder-unmasked-dim 192,192,256,320,256,192
"""
import argparse
import logging
import math
import os
import pdb
import subprocess as sp
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import torch
from asr_datamodule import MLSEnglishHFAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from streaming_beam_search import (
fast_beam_search_one_best,
greedy_search,
modified_beam_search,
)
from tokenizer import Tokenizer
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_char",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Supported decoding methods are:
greedy_search
modified_beam_search
fast_beam_search
""",
)
parser.add_argument(
"--num_active_paths",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=32,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel.",
)
add_model_arguments(parser)
return parser
def get_init_states(
model: nn.Module,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
"""
states = model.encoder.get_init_states(batch_size, device)
embed_states = model.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
states.append(processed_lens)
return states
def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
"""Stack list of zipformer states that correspond to separate utterances
into a single emformer state, so that it can be used as an input for
zipformer when those utterances are formed into a batch.
Args:
state_list:
Each element in state_list corresponding to the internal state
of the zipformer model for a single utterance. For element-n,
state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
cached_val2, cached_conv1, cached_conv2).
state_list[n][-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
state_list[n][-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
Note:
It is the inverse of :func:`unstack_states`.
"""
batch_size = len(state_list)
assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
tot_num_layers = (len(state_list[0]) - 2) // 6
batch_states = []
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key = torch.cat(
[state_list[i][layer_offset] for i in range(batch_size)], dim=1
)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn = torch.cat(
[state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
)
# cached_val1: (left_context_len, batch_size, value_dim)
cached_val1 = torch.cat(
[state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
)
# cached_val2: (left_context_len, batch_size, value_dim)
cached_val2 = torch.cat(
[state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
)
# cached_conv1: (#batch, channels, left_pad)
cached_conv1 = torch.cat(
[state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
)
# cached_conv2: (#batch, channels, left_pad)
cached_conv2 = torch.cat(
[state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
)
batch_states += [
cached_key,
cached_nonlin_attn,
cached_val1,
cached_val2,
cached_conv1,
cached_conv2,
]
cached_embed_left_pad = torch.cat(
[state_list[i][-2] for i in range(batch_size)], dim=0
)
batch_states.append(cached_embed_left_pad)
processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
batch_states.append(processed_lens)
return batch_states
def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
"""Unstack the zipformer state corresponding to a batch of utterances
into a list of states, where the i-th entry is the state from the i-th
utterance in the batch.
Note:
It is the inverse of :func:`stack_states`.
Args:
batch_states: A list of cached tensors of all encoder layers. For layer-i,
states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
cached_conv1, cached_conv2).
state_list[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
Returns:
state_list: A list of list. Each element in state_list corresponding to the internal state
of the zipformer model for a single utterance.
"""
assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
tot_num_layers = (len(batch_states) - 2) // 6
processed_lens = batch_states[-1]
batch_size = processed_lens.shape[0]
state_list = [[] for _ in range(batch_size)]
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
chunks=batch_size, dim=1
)
# cached_val1: (left_context_len, batch_size, value_dim)
cached_val1_list = batch_states[layer_offset + 2].chunk(
chunks=batch_size, dim=1
)
# cached_val2: (left_context_len, batch_size, value_dim)
cached_val2_list = batch_states[layer_offset + 3].chunk(
chunks=batch_size, dim=1
)
# cached_conv1: (#batch, channels, left_pad)
cached_conv1_list = batch_states[layer_offset + 4].chunk(
chunks=batch_size, dim=0
)
# cached_conv2: (#batch, channels, left_pad)
cached_conv2_list = batch_states[layer_offset + 5].chunk(
chunks=batch_size, dim=0
)
for i in range(batch_size):
state_list[i] += [
cached_key_list[i],
cached_nonlin_attn_list[i],
cached_val1_list[i],
cached_val2_list[i],
cached_conv1_list[i],
cached_conv2_list[i],
]
cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
for i in range(batch_size):
state_list[i].append(cached_embed_left_pad_list[i])
processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
for i in range(batch_size):
state_list[i].append(processed_lens_list[i])
return state_list
def streaming_forward(
features: Tensor,
feature_lens: Tensor,
model: nn.Module,
states: List[Tensor],
chunk_size: int,
left_context_len: int,
) -> Tuple[Tensor, Tensor, List[Tensor]]:
"""
Returns encoder outputs, output lengths, and updated states.
"""
cached_embed_left_pad = states[-2]
(x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lens,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_lens = states[-1] # (batch,)
idx = torch.arange(left_context_len, device=x.device).unsqueeze(0).expand(
x.size(0), left_context_len
)
# True means padding positions (not yet available in cache).
processed_mask = idx >= processed_lens.unsqueeze(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = model.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, encoder_out_lens, new_states
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
decode_streams: List[DecodeStream],
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
decode_streams:
A List of DecodeStream, each belonging to a utterance.
Returns:
Return a List containing which DecodeStreams are finished.
"""
chunk_size = int(params.chunk_size)
left_context_len = int(params.left_context_frames)
features = []
feature_lens = []
states = []
processed_lens = [] # Used in fast-beam-search
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(chunk_size * 2)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
feature_lens = torch.tensor(feature_lens, device=model.device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# Make sure the length after encoder_embed is at least 1.
# The encoder_embed subsample features (T - 7) // 2
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
tail_length = chunk_size * 2 + 7 + 2 * 3
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
features = torch.nn.functional.pad(
features,
(0, 0, 0, pad_length),
mode="constant",
value=LOG_EPS,
)
states = stack_states(states)
encoder_out, encoder_out_lens, new_states = streaming_forward(
features=features,
feature_lens=feature_lens,
model=model,
states=states,
chunk_size=chunk_size,
left_context_len=left_context_len,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
elif params.decoding_method == "fast_beam_search":
processed_lens = torch.tensor(processed_lens, device=model.device)
processed_lens = processed_lens + encoder_out_lens
fast_beam_search_one_best(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
streams=decode_streams,
beam=params.beam,
max_states=params.max_states,
max_contexts=params.max_contexts,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
model=model,
streams=decode_streams,
encoder_out=encoder_out,
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
states = unstack_states(new_states)
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = states[i]
decode_streams[i].done_frames += encoder_out_lens[i]
# if decode_streams[i].done:
# finished_streams.append(i)
finished_streams.append(i)
return finished_streams
def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
tokenizer: Tokenizer,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
cuts:
Lhotse Cutset containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
tokenizer:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
device = model.device
opts = FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 100
decode_results = []
# Contain decode streams currently running.
decode_streams = []
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
initial_states = get_init_states(model=model, batch_size=1, device=device)
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
# - this is to avoid sending [-32k,+32k] signal in...
# - some lhotse AudioTransform classes can make the signal
# be out of range [-1, 1], hence the tolerance 10
assert (
np.abs(audio).max() <= 10
), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
feature = fbank(samples.to(device))
decode_stream.set_features(feature, tail_pad_len=30)
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
tokenizer.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
# decode final chunks of last sequences
while len(decode_streams):
# print("INSIDE LEN DECODE STREAMS")
# pdb.set_trace()
# print(model.device)
# test_device = model.device
# print("done")
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
if not finished_streams:
print("No finished streams, breaking the loop")
break
for i in sorted(finished_streams, reverse=True):
try:
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
tokenizer.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
except IndexError as e:
print(f"IndexError: {e}")
print(f"decode_streams length: {len(decode_streams)}")
print(f"finished_streams: {finished_streams}")
print(f"i: {i}")
continue
if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
elif params.decoding_method == "modified_beam_search":
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
torch.cuda.synchronize()
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
ReazonSpeechAsrDataModule.add_arguments(parser)
Tokenizer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
assert params.causal, params.causal
assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp_token = Tokenizer.load(params.lang, params.lang_type)
# <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp_token.piece_to_id("<blk>")
params.unk_id = sp_token.piece_to_id("<unk>")
params.vocab_size = sp_token.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
model.device = device
decoding_graph = None
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
valid_cuts = reazonspeech_corpus.valid_cuts()
test_cuts = reazonspeech_corpus.test_cuts()
test_sets = ["valid", "test"]
test_cuts = [valid_cuts, test_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
params=params,
model=model,
tokenizer=sp_token,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/subsampling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/test_scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/test_subsampling.py

View File

@ -0,0 +1,252 @@
import argparse
from pathlib import Path
from typing import Callable, List, Union
import sentencepiece as spm
from k2 import SymbolTable
class Tokenizer:
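    """Common interface over SentencePiece BPE models and char-level token
    tables, so callers can encode/decode text without knowing the lang type."""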
text2word: Callable[[str], List[str]]
@staticmethod
def add_arguments(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="Lang related options")
group.add_argument("--lang", type=Path, help="Path to lang directory.")
group.add_argument(
"--lang-type",
type=str,
default=None,
help=(
"Either 'bpe' or 'char'. If not provided, it expects lang_dir/lang_type to exists. "
"Note: 'bpe' directly loads sentencepiece.SentencePieceProcessor"
),
)
@staticmethod
def Load(lang_dir: Path, lang_type="", oov="<unk>"):
if not lang_type:
assert (lang_dir / "lang_type").exists(), "lang_type not specified."
lang_type = (lang_dir / "lang_type").read_text().strip()
tokenizer = None
if lang_type == "bpe":
assert (
lang_dir / "bpe.model"
).exists(), f"No BPE .model could be found in {lang_dir}."
tokenizer = spm.SentencePieceProcessor()
tokenizer.Load(str(lang_dir / "bpe.model"))
elif lang_type == "char":
tokenizer = CharTokenizer(lang_dir, oov=oov)
else:
raise NotImplementedError(f"{lang_type} not supported at the moment.")
return tokenizer
load = Load
def PieceToId(self, piece: str) -> int:
raise NotImplementedError(
"You need to implement this function in the child class."
)
piece_to_id = PieceToId
def IdToPiece(self, id: int) -> str:
raise NotImplementedError(
"You need to implement this function in the child class."
)
id_to_piece = IdToPiece
def GetPieceSize(self) -> int:
raise NotImplementedError(
"You need to implement this function in the child class."
)
get_piece_size = GetPieceSize
def __len__(self) -> int:
return self.get_piece_size()
def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
raise NotImplementedError(
"You need to implement this function in the child class."
)
def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
raise NotImplementedError(
"You need to implement this function in the child class."
)
def EncodeAsIds(self, input: str) -> List[int]:
return self.EncodeAsIdsBatch([input])[0]
def EncodeAsPieces(self, input: str) -> List[str]:
return self.EncodeAsPiecesBatch([input])[0]
def Encode(
self, input: Union[str, List[str]], out_type=int
) -> Union[List, List[List]]:
if not input:
return []
if isinstance(input, list):
if out_type is int:
return self.EncodeAsIdsBatch(input)
if out_type is str:
return self.EncodeAsPiecesBatch(input)
if out_type is int:
return self.EncodeAsIds(input)
if out_type is str:
return self.EncodeAsPieces(input)
encode = Encode
def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
raise NotImplementedError(
"You need to implement this function in the child class."
)
def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
raise NotImplementedError(
"You need to implement this function in the child class."
)
def DecodeIds(self, input: List[int]) -> str:
return self.DecodeIdsBatch([input])[0]
def DecodePieces(self, input: List[str]) -> str:
return self.DecodePiecesBatch([input])[0]
def Decode(
self,
input: Union[int, List[int], List[str], List[List[int]], List[List[str]]],
) -> Union[List[str], str]:
if not input:
return ""
if isinstance(input, int):
return self.id_to_piece(input)
elif isinstance(input, str):
raise TypeError(
"Unlike spm.SentencePieceProcessor, cannot decode from type str."
)
if isinstance(input[0], list):
if not input[0] or isinstance(input[0][0], int):
return self.DecodeIdsBatch(input)
if isinstance(input[0][0], str):
return self.DecodePiecesBatch(input)
if isinstance(input[0], int):
return self.DecodeIds(input)
if isinstance(input[0], str):
return self.DecodePieces(input)
raise RuntimeError("Unknown input type")
decode = Decode
def SplitBatch(self, input: List[str]) -> List[List[str]]:
raise NotImplementedError(
"You need to implement this function in the child class."
)
def Split(self, input: Union[List[str], str]) -> Union[List[List[str]], List[str]]:
if isinstance(input, list):
return self.SplitBatch(input)
elif isinstance(input, str):
return self.SplitBatch([input])[0]
raise RuntimeError("Unknown input type")
split = Split
class CharTokenizer(Tokenizer):
def __init__(self, lang_dir: Path, oov="<unk>", sep=""):
assert (
lang_dir / "tokens.txt"
).exists(), f"tokens.txt could not be found in {lang_dir}."
token_table = SymbolTable.from_file(lang_dir / "tokens.txt")
assert (
"#0" not in token_table
), "This tokenizer does not support disambig symbols."
self._id2sym = token_table._id2sym
self._sym2id = token_table._sym2id
self.oov = oov
self.oov_id = self._sym2id[oov]
self.sep = sep
if self.sep:
self.text2word = lambda x: x.split(self.sep)
else:
self.text2word = lambda x: list(x.replace(" ", ""))
def piece_to_id(self, piece: str) -> int:
try:
return self._sym2id[piece]
except KeyError:
return self.oov_id
def id_to_piece(self, id: int) -> str:
return self._id2sym[id]
def get_piece_size(self) -> int:
return len(self._sym2id)
def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
return [[self.piece_to_id(i) for i in self.text2word(text)] for text in input]
def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
return [
[i if i in self._sym2id else self.oov for i in self.text2word(text)]
for text in input
]
def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
return [self.sep.join(self.id_to_piece(i) for i in text) for text in input]
def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
return [self.sep.join(text) for text in input]
def SplitBatch(self, input: List[str]) -> List[List[str]]:
return [self.text2word(text) for text in input]
def test_CharTokenizer():
test_single_string = "こんにちは"
test_multiple_string = [
"今日はいい天気ですよね",
"諏訪湖は綺麗でしょう",
"这在词表外",
"分かち 書き に し た 文章 です",
"",
]
test_empty_string = ""
sp = Tokenizer.load(Path("lang_char"), "char", oov="<unk>")
splitter = sp.split
print(sp.encode(test_single_string, out_type=str))
print(sp.encode(test_single_string, out_type=int))
print(sp.encode(test_multiple_string, out_type=str))
print(sp.encode(test_multiple_string, out_type=int))
print(sp.encode(test_empty_string, out_type=str))
print(sp.encode(test_empty_string, out_type=int))
print(sp.decode(sp.encode(test_single_string, out_type=str)))
print(sp.decode(sp.encode(test_single_string, out_type=int)))
print(sp.decode(sp.encode(test_multiple_string, out_type=str)))
print(sp.decode(sp.encode(test_multiple_string, out_type=int)))
print(sp.decode(sp.encode(test_empty_string, out_type=str)))
print(sp.decode(sp.encode(test_empty_string, out_type=int)))
print(splitter(test_single_string))
print(splitter(test_multiple_string))
print(splitter(test_empty_string))
if __name__ == "__main__":
test_CharTokenizer()

File diff suppressed because it is too large

View File

@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/zipformer.py

View File

@ -1,17 +1,36 @@
# Introduction
A bilingual Japanese-English ASR model that utilizes ReazonSpeech, developed by the developers of ReazonSpeech.
A bilingual Japanese-English ASR model developed by the developers of ReazonSpeech that utilizes ReazonSpeech and the English subset of Multilingual LibriSpeech (MLS English).
**ReazonSpeech** is an open-source dataset that contains a diverse set of natural Japanese speech, collected from terrestrial television streams. It contains more than 35,000 hours of audio.
**Multilingual LibriSpeech (MLS)** is a large multilingual corpus suitable for speech research. The dataset is derived from read audiobooks from LibriVox and consists of 8 languages - English, German, Dutch, Spanish, French, Italian, Portuguese, Polish. It includes about 44.5K hours of English and a total of about 6K hours for other languages. This icefall training recipe was created for the restructured version of the English split of the dataset available on Hugging Face from `parler-tts` [here](https://huggingface.co/datasets/parler-tts/mls_eng).
# Included Training Sets
1. LibriSpeech (English)
2. ReazonSpeech (Japanese)
# Training Sets
1. ReazonSpeech (Japanese)
2. Multilingual LibriSpeech (English)
|Dataset| Number of hours| URL|
|---|---:|---|
|**TOTAL**|35,960|---|
|LibriSpeech|960|https://www.openslr.org/12/|
|ReazonSpeech (all) |35,000|https://huggingface.co/datasets/reazon-research/reazonspeech|
|**TOTAL**|79,500|---|
|MLS English|44,500|https://huggingface.co/datasets/parler-tts/mls_eng|
|ReazonSpeech (all)|35,000|https://huggingface.co/datasets/reazon-research/reazonspeech|
# Usage
This recipe relies on the `mls_english` recipe and the `reazonspeech` recipe.
To be able to use the `multi_ja_en` recipe, you must first run the `prepare.sh` scripts in both the `mls_english` recipe and the `reazonspeech` recipe.
This recipe does not enforce data balance: please ensure that the `mls_english` and `reazonspeech` datasets prepared above are balanced to your liking (you may use the utility script `create_subsets_greedy.py` in the `mls_english` recipe to create a custom-sized MLS English sub-dataset).
Steps for model training:
0. Run `../../mls_english/ASR/prepare.sh` and `../../reazonspeech/ASR/prepare.sh`
1. Run `./prepare.sh`
2. Run `update_cutset_paths.py` (we will soon add this to `./prepare.sh`)
3. Run `zipformer/train.py` (see example arguments inside the file, and the consolidated sketch below)
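For convenience, the steps above can be strung together as follows. This is only a sketch: the training arguments are copied from the non-streaming example in [./RESULTS.md](./RESULTS.md) and should be adapted to your setup.

```shell
# Sketch only: adapt paths and arguments to your setup.
(cd ../../mls_english/ASR && ./prepare.sh)      # step 0
(cd ../../reazonspeech/ASR && ./prepare.sh)     # step 0
./prepare.sh                                    # step 1
python local/utils/update_cutset_paths.py       # step 2
./zipformer/train.py \
  --world-size 8 \
  --num-epochs 10 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp \
  --manifest-dir data/manifests                 # step 3
```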

View File

@ -2,51 +2,163 @@
### Zipformer
#### Non-streaming
#### Non-streaming (Byte-Level BPE vocab_size=2000)
Trained on 15k hours of ReazonSpeech (filtered to only audio segments between 8s and 22s) and 15k hours of MLS English.
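The 8s-22s duration filter mentioned above is applied to the ReazonSpeech cuts; a minimal `lhotse` sketch of such a filter is shown below (the manifest paths are assumptions based on this recipe's layout).

```python
from lhotse import load_manifest_lazy

# Keep only cuts whose duration falls in the 8-22 second range described above.
cuts = load_manifest_lazy(
    "data/manifests/reazonspeech/reazonspeech_cuts_train.jsonl.gz"
)
filtered = cuts.filter(lambda c: 8.0 <= c.duration <= 22.0)
filtered.to_file(
    "data/manifests/reazonspeech/reazonspeech_cuts_train_8to22s.jsonl.gz"
)
```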
The training command is:
```shell
./zipformer/train.py \
--bilingual 1 \
--world-size 4 \
--num-epochs 30 \
--world-size 8 \
--causal 1 \
--num-epochs 10 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--max-duration 600
--manifest-dir data/manifests \
--enable-musan True
```
The decoding command is:
```shell
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--epoch 10 \
--avg 1 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method greedy_search
--decoding-method modified_beam_search \
--manifest-dir data/manifests
```
To export the model with onnx:
```shell
./zipformer/export-onnx.py --tokens data/lang_bbpe_2000/tokens.txt --use-averaged-model 0 --epoch 35 --avg 1 --exp-dir zipformer/exp --num-encoder-layers "2,2,3,4,3,2" --downsampling-factor "1,2,4,8,4,2" --feedforward-dim "512,768,1024,1536,1024,768" --num-heads "4,4,4,8,4,4" --encoder-dim "192,256,384,512,384,256" --query-head-dim 32 --value-head-dim 12 --pos-head-dim 4 --pos-dim 48 --encoder-unmasked-dim "192,192,256,256,256,192" --cnn-module-kernel "31,31,15,15,15,31" --decoder-dim 512 --joiner-dim 512 --causal False --chunk-size "16,32,64,-1" --left-context-frames "64,128,256,-1" --fp16 True
./zipformer/export-onnx.py \
--tokens ./data/lang/bbpe_2000/tokens.txt \
--use-averaged-model 0 \
--epoch 10 \
--avg 1 \
--exp-dir ./zipformer/exp
```
WER and CER on test set listed below (calculated with `./zipformer/decode.py`):
| Datasets | ReazonSpeech + MLS English (combined test set) |
|----------------------|------------------------------------------------|
| Zipformer WER (%) | test |
| greedy_search | 6.33 |
| modified_beam_search | 6.32 |
We also include WER% for common English ASR datasets:
| Corpus | WER (%) |
|-----------------------------|---------|
| CommonVoice | 29.03 |
| TED | 16.78 |
| MLS English (test set) | 8.64 |
And CER% for common Japanese datasets:
| Corpus | CER (%) |
|---------------|---------|
| JSUT | 8.13 |
| CommonVoice | 9.82 |
| TEDx | 11.64 |
Pre-trained model can be found here: [https://huggingface.co/reazon-research/reazonspeech-k2-v2-ja-en/tree/multi_ja_en_15k15k](https://huggingface.co/reazon-research/reazonspeech-k2-v2-ja-en/tree/multi_ja_en_15k15k)
(Not yet publicly released)
#### Streaming (Byte-Level BPE vocab_size=2000)
Trained on 15k hours of ReazonSpeech (filtered to only audio segments between 8s and 22s) and 15k hours of MLS English.
The training command is:
```shell
./zipformer/train.py \
--world-size 8 \
--causal 1 \
--num-epochs 10 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--manifest-dir data/manifests \
--enable-musan True
```
The decoding command is:
```shell
TODO
```
To export the model with sherpa onnx:
```shell
./zipformer/export-onnx-streaming.py \
--tokens ./data/lang/bbpe_2000/tokens.txt \
--use-averaged-model 0 \
--epoch 10 \
--avg 1 \
--exp-dir ./zipformer/exp-15k15k-streaming \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal True \
--chunk-size 16 \
--left-context-frames 128 \
--fp16 True
```
(Adjust the `chunk-size` and `left-context-frames` as necessary)
To export the model as Torchscript (`.jit`):
```shell
./zipformer/export.py \
--exp-dir ./zipformer/exp-15k15k-streaming \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens data/lang/bbpe_2000/tokens.txt \
--epoch 10 \
--avg 1 \
--jit 1
```
You may also use decode chunk sizes `16`, `32`, `64`, `128`.
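To sanity-check the Torchscript export, the `.pt` file can be loaded with plain PyTorch. A minimal sketch follows; the output filename is an assumption, so check what `export.py` actually wrote into the experiment directory.

```python
import torch

# Load the exported Torchscript model on CPU and switch to inference mode.
model = torch.jit.load(
    "zipformer/exp-15k15k-streaming/jit_script.pt",  # assumed filename
    map_location="cpu",
)
model.eval()
print(model)  # inspect the scripted module structure
```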
Word Error Rates (WERs) listed below:
| Datasets | ReazonSpeech | ReazonSpeech | LibriSpeech | LibriSpeech |
|----------------------|--------------|---------------|--------------------|-------------------|
| Zipformer WER (%) | dev | test | test-clean | test-other |
| greedy_search | 5.9 | 4.07 | 3.46 | 8.35 |
| modified_beam_search | 4.87 | 3.61 | 3.28 | 8.07 |
*Please let us know which script to use to evaluate the streaming model!*
Character Error Rates (CERs) for Japanese listed below:
| Decoding Method | In-Distribution CER | JSUT | CommonVoice | TEDx |
| :------------------: | :-----------------: | :--: | :---------: | :---: |
| greedy search | 12.56 | 6.93 | 9.75 | 9.67 |
| modified beam search | 11.59 | 6.97 | 9.55 | 9.51 |
We also include WER% for common English ASR datasets:
Pre-trained model can be found here: https://huggingface.co/reazon-research/reazonspeech-k2-v2-ja-en/tree/main
*Please let us know which script to use to evaluate the streaming model!*
And CER% for common Japanese datasets:
*Please let us know which script to use to evaluate the streaming model!*
Pre-trained model can be found here: [https://huggingface.co/reazon-research/reazonspeech-k2-v2-ja-en/tree/multi_ja_en_15k15k](https://huggingface.co/reazon-research/reazonspeech-k2-v2-ja-en/tree/multi_ja_en_15k15k)
(Not yet publicly released)

View File

@ -21,7 +21,7 @@
This script takes as input `lang_dir`, which should contain::
- lang_dir/bbpe.model,
- lang_dir/bbpe_2000/bbpe.model
- lang_dir/words.txt
and generates the following files in the directory `lang_dir`:
@ -173,7 +173,8 @@ def get_args():
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain the bpe.model and words.txt
It should contain the words.txt file and the
bbpe model in a subdirectory (e.g., bbpe_2000/bbpe.model).
""",
)
@ -184,6 +185,13 @@ def get_args():
help="The out of vocabulary word in lexicon.",
)
parser.add_argument(
"--vocab-size",
type=int,
        default=2000,
help="Vocabulary size used for BPE training (determines the bbpe model directory).",
)
parser.add_argument(
"--debug",
type=str2bool,
@ -206,6 +214,9 @@ def main():
lang_dir = Path(args.lang_dir)
model_file = lang_dir / "bbpe.model"
if not model_file.is_file():
raise FileNotFoundError(f"BPE model not found at: {model_file}")
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
words = word_sym_table.symbols
@ -216,7 +227,7 @@ def main():
if w in words:
words.remove(w)
lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov)
lexicon, token_sym_table = generate_lexicon(str(model_file), words, args.oov)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)

View File

@ -1,75 +0,0 @@
#!/usr/bin/env python3
# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from lhotse import CutSet
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"train_cut", metavar="train-cut", type=Path, help="Path to the train cut"
)
parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang_char"),
help=(
"Name of lang dir. "
"If not set, this will default to lang_char_{trans-mode}"
),
)
return parser.parse_args()
def main():
args = get_args()
logging.basicConfig(
format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"),
level=logging.INFO,
)
sysdef_string = set(["<blk>", "<unk>", "<sos/eos>", " "])
token_set = set()
logging.info(f"Creating vocabulary from {args.train_cut}.")
train_cut: CutSet = CutSet.from_file(args.train_cut)
for cut in train_cut:
for sup in cut.supervisions:
token_set.update(sup.text)
token_set = ["<blk>"] + sorted(token_set - sysdef_string) + ["<unk>", "<sos/eos>"]
args.lang_dir.mkdir(parents=True, exist_ok=True)
(args.lang_dir / "tokens.txt").write_text(
"\n".join(f"{t}\t{i}" for i, t in enumerate(token_set))
)
(args.lang_dir / "lang_type").write_text("char")
logging.info("Done.")
if __name__ == "__main__":
main()

View File

@ -33,7 +33,7 @@ from pathlib import Path
import sentencepiece as spm
from icefall import byte_encode
from icefall.utils import tokenize_by_ja_char
from icefall.utils import str2bool, tokenize_by_ja_char
def get_args():
@ -41,9 +41,7 @@ def get_args():
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
The generated bpe.model is saved to this directory.
""",
help="""Input directory.""",
)
parser.add_argument(
@ -58,6 +56,27 @@ def get_args():
help="Vocabulary size for BPE training",
)
parser.add_argument(
"--output-model",
type=str,
help="Path to save the trained BPE model.",
required=True,
)
parser.add_argument(
"--input-sentence-size",
type=int,
        default=1000000,
help="Maximum number of sentences to load for BPE training.",
)
parser.add_argument(
"--shuffle-input-sentence",
type=str2bool,
        default=True,
help="Whether to shuffle input sentences.",
)
return parser.parse_args()
@ -71,17 +90,20 @@ def main():
args = get_args()
vocab_size = args.vocab_size
lang_dir = Path(args.lang_dir)
output_model = Path(args.output_model)
input_sentence_size = args.input_sentence_size
shuffle_input_sentence = args.shuffle_input_sentence
model_type = "unigram"
model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
model_file = Path(model_prefix + ".model")
if model_file.is_file():
print(f"{model_file} exists - skipping")
model_prefix = str(output_model.parent / f"{model_type}_{vocab_size}")
temp_model_file = Path(model_prefix + ".model")
if output_model.is_file():
print(f"{output_model} exists - skipping")
return
character_coverage = 1.0
input_sentence_size = 100000000
user_defined_symbols = ["<blk>", "<sos/eos>"]
unk_id = len(user_defined_symbols)
@ -100,6 +122,7 @@ def main():
model_type=model_type,
model_prefix=model_prefix,
input_sentence_size=input_sentence_size,
shuffle_input_sentence=shuffle_input_sentence,
character_coverage=character_coverage,
user_defined_symbols=user_defined_symbols,
unk_id=unk_id,
@ -107,7 +130,7 @@ def main():
eos_id=-1,
)
shutil.copyfile(model_file, f"{lang_dir}/bbpe.model")
shutil.move(str(temp_model_file), str(output_model))
if __name__ == "__main__":

View File

@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
@ -23,6 +22,7 @@ from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
@ -34,12 +34,21 @@ from lhotse.dataset import (
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class ReazonSpeechAsrDataModule:
class _SeedWorkers:
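    """Picklable callable used as a DataLoader worker_init_fn to give each worker a distinct RNG seed."""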
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class MultiDatasetAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
@ -70,7 +79,7 @@ class ReazonSpeechAsrDataModule:
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
default=Path("data/manifests"),
help="Path to directory with train/dev/test cuts.",
)
group.add_argument(
@ -192,6 +201,32 @@ class ReazonSpeechAsrDataModule:
transforms = []
input_transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(
self.args.manifest_dir / "musan/musan_cuts.jsonl.gz"
)
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
@ -250,6 +285,8 @@ class ReazonSpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
@ -265,12 +302,17 @@ class ReazonSpeechAsrDataModule:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
seed = 42
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
pin_memory=True,
num_workers=self.args.num_workers,
persistent_workers=False,
persistent_workers=True,
worker_init_fn=worker_init_fn,
)
return train_dl
@ -332,24 +374,3 @@ class ReazonSpeechAsrDataModule:
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz"
)
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
return load_manifest_lazy(
self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz"
)

View File

@ -0,0 +1,156 @@
import logging
import os # Import os module to handle symlinks
from pathlib import Path
from lhotse import CutSet, load_manifest
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def update_paths(cuts: CutSet, dataset_name: str, old_feature_prefix: str) -> CutSet:
"""
Updates the storage_path in a CutSet's features to reflect the new dataset-specific
feature directory structure.
Args:
cuts: The Lhotse CutSet to modify.
dataset_name: The name of the dataset (e.g., "reazonspeech", "mls_english")
which corresponds to the new subdirectory for features.
old_feature_prefix: The prefix that the original feature paths were relative to.
This typically corresponds to the root of the manifests dir
in the original recipe.
"""
updated_cuts = []
for cut in cuts:
if cut.features is not None:
original_storage_path = Path(cut.features.storage_path)
try:
relative_path = original_storage_path.relative_to(old_feature_prefix)
except ValueError:
# If for some reason the path doesn't start with old_feature_prefix,
# keep it as is. This can happen if some paths are already absolute or different.
logger.warning(
f"Feature path '{original_storage_path}' does not start with '{old_feature_prefix}'. Skipping update for this cut."
)
updated_cuts.append(cut)
continue
# Avoid double-nesting (e.g., reazonspeech/reazonspeech/...)
# Construct the new path: data/manifests/<dataset_name>/feats_train/feats-12.lca
if relative_path.parts[0] == dataset_name:
new_storage_path = Path("data/manifests") / relative_path
else:
new_storage_path = Path("data/manifests") / dataset_name / relative_path
            logger.info(
                f"Updating cut {cut.id}: {original_storage_path} -> {new_storage_path}"
            )
            cut.features.storage_path = new_storage_path.as_posix()
            updated_cuts.append(cut)
else:
logger.warning(f"Skipping update for cut {cut.id}: has no features.")
updated_cuts.append(cut) # No features, or not a path we need to modify
return CutSet.from_cuts(updated_cuts)
if __name__ == "__main__":
# The root where the symlinked manifests are located in the multi_ja_en recipe
multi_recipe_manifests_root = Path("data/manifests")
# Define the datasets and their *specific* manifest file prefixes
dataset_manifest_prefixes = {
"reazonspeech": "reazonspeech_cuts",
"mls_english": "mls_eng_cuts",
}
splits = ["train", "dev", "test"]
# This is the path segment *inside* the original recipe's data/manifests
# that your features were stored under.
# e.g., if original path was /original/recipe/data/manifests/feats_train/file.lca
# then this is 'data/manifests'
original_feature_base_path = "data/manifests"
musan_manifest_path = multi_recipe_manifests_root / "musan" / "musan_cuts.jsonl.gz"
if musan_manifest_path.exists():
logger.info(f"Processing musan manifest: {musan_manifest_path}")
try:
musan_cuts = load_manifest(musan_manifest_path)
updated_musan_cuts = update_paths(
musan_cuts, "musan", old_feature_prefix="data/fbank"
)
# Make sure we're overwriting the correct path even if it's a symlink
if musan_manifest_path.is_symlink() or musan_manifest_path.exists():
logger.info(
f"Overwriting existing musan manifest at: {musan_manifest_path}"
)
os.unlink(musan_manifest_path)
updated_musan_cuts.to_file(musan_manifest_path)
logger.info(f"Updated musan cuts written to: {musan_manifest_path}")
except Exception as e:
logger.error(
f"Error processing musan manifest {musan_manifest_path}: {e}",
exc_info=True,
)
else:
logger.warning(f"Musan manifest not found at {musan_manifest_path}, skipping.")
for dataset_name, manifest_prefix in dataset_manifest_prefixes.items():
dataset_symlink_dir = multi_recipe_manifests_root / dataset_name
if not dataset_symlink_dir.is_dir():
logger.warning(
f"Dataset symlink directory not found: {dataset_symlink_dir}. Skipping {dataset_name}."
)
continue
for split in splits:
# Construct the path to the symlinked manifest file
manifest_filename = f"{manifest_prefix}_{split}.jsonl.gz"
symlink_path = (
dataset_symlink_dir / manifest_filename
) # This is the path to the symlink itself
if symlink_path.is_symlink(): # Check if it's actually a symlink
# Get the actual path to the target file that the symlink points to
# Lhotse's load_manifest will follow this symlink automatically.
target_path = os.path.realpath(symlink_path)
logger.info(
f"Processing symlink '{symlink_path}' pointing to '{target_path}'"
)
elif symlink_path.is_file(): # If it's a regular file (not a symlink)
logger.info(f"Processing regular file: {symlink_path}")
target_path = symlink_path # Use its own path as target
else:
logger.warning(
f"Manifest file not found or neither a file nor a symlink: {symlink_path}"
)
continue # Skip to next iteration
try:
# Load the manifest. Lhotse will resolve the symlink internally for reading.
cuts = load_manifest(
symlink_path
) # Use symlink_path here, Lhotse handles resolution for loading
# Update the storage_path within the loaded cuts (in memory)
updated_cuts = update_paths(
cuts, dataset_name, old_feature_prefix=original_feature_base_path
)
# --- CRITICAL CHANGE HERE ---
# Save the *modified* CutSet to the path of the symlink *itself*.
# This will overwrite the symlink with the new file, effectively
# breaking the symlink and creating a new file in its place.
os.unlink(symlink_path)
updated_cuts.to_file(symlink_path)
logger.info(
f"Updated {dataset_name} {split} cuts saved (overwriting symlink) to: {symlink_path}"
)
except Exception as e:
logger.error(f"Error processing {symlink_path}: {e}", exc_info=True)
logger.info("CutSet path updating complete.")

View File

@ -1 +1 @@
../../../librispeech/ASR/local/validate_bpe_lexicon.py
/root/Github/reazon-icefall/egs/librispeech/ASR/local/validate_bpe_lexicon.py

View File

@ -19,6 +19,8 @@ vocab_sizes=(
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
mkdir -p data/lang
mkdir -p data/manifests
log() {
# This function is from espnet
@ -31,55 +33,54 @@ log "dl_dir: $dl_dir"
log "Dataset: musan"
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Soft link fbank of musan"
mkdir -p data/fbank
if [ -e ../../librispeech/ASR/data/fbank/.musan.done ]; then
cd data/fbank
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_feats) .
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_cuts.jsonl.gz) .
cd ../..
cd data/manifests
mkdir -p musan
cd musan
ln -svfr $(realpath ../../../../../librispeech/ASR/data/fbank/musan_feats) .
ln -svfr $(realpath ../../../../../librispeech/ASR/data/fbank/musan_cuts.jsonl.gz) .
cd ../../..
else
log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 4 --stop-stage 4"
exit 1
fi
fi
log "Dataset: LibriSpeech"
log "Dataset: MLS English"
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 1: Soft link fbank of LibriSpeech"
mkdir -p data/fbank
if [ -e ../../librispeech/ASR/data/fbank/.librispeech.done ]; then
cd data/fbank
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts*) .
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats*) .
cd ../..
log "Stage 2: Soft link manifests (including fbank) of MLS English"
if [ -e ../../mls_english/ASR/data/manifests/.mls_english-validated.done ]; then
cd data/manifests
mkdir -p mls_english
cd mls_english
ln -svfr $(realpath ../../../../../mls_english/ASR/data/manifests/mls_eng_cuts*) .
ln -svfr $(realpath ../../../../../mls_english/ASR/data/manifests/feats*) .
cd ../../..
else
log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 1 --stop-stage 1 and ../../librispeech/ASR/prepare.sh --stage 3 --stop-stage 3"
log "Abort! Please run ../../mls_english/ASR/prepare.sh --stage 1 --stop-stage 1"
exit 1
fi
fi
log "Dataset: ReazonSpeech"
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 2: Soft link fbank of ReazonSpeech"
mkdir -p data/fbank
log "Stage 3: Soft link fbank of ReazonSpeech"
if [ -e ../../reazonspeech/ASR/data/manifests/.reazonspeech.done ]; then
cd data/fbank
ln -svf $(realpath ../../../../reazonspeech/ASR/data/manifests/reazonspeech_cuts*) .
cd ..
mkdir -p manifests
cd manifests
ln -svf $(realpath ../../../../reazonspeech/ASR/data/manifests/feats_*) .
cd ../..
cd data/manifests
mkdir -p reazonspeech
cd reazonspeech
ln -svfr $(realpath ../../../../../reazonspeech/ASR/data/manifests/reazonspeech_cuts*) .
ln -svfr $(realpath ../../../../../reazonspeech/ASR/data/manifests/feats*) .
cd ../../..
else
log "Abort! Please run ../../reazonspeech/ASR/prepare.sh --stage 0 --stop-stage 2"
exit 1
fi
fi
# New Stage 3: Prepare char based lang for ReazonSpeech
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
lang_char_dir=data/lang_char
log "Stage 3: Prepare char based lang for ReazonSpeech"
log "Stage 4: Prepare char-based lang for ReazonSpeech"
mkdir -p $lang_char_dir
# Prepare text
@ -89,7 +90,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
| ./local/text2token.py -t "char" > $lang_char_dir/text
fi
# jp word segmentation for text
# Japanese word segmentation
if [ ! -f $lang_char_dir/text_words_segmentation ]; then
python3 ./local/text2segments.py \
--input-file $lang_char_dir/text \
@ -106,80 +107,96 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi
if [ ! -f $lang_char_dir/L_disambig.pt ]; then
python3 ./local/prepare_char.py --lang-dir data/lang_char
python3 ./local/prepare_char.py --lang-dir $lang_char_dir
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Prepare Byte BPE based lang"
mkdir -p data/fbank
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare Byte BPE based lang in data/lang"
lang_dir=data/lang
# Check if required char-based lang data exists
if [ ! -d ../../reazonspeech/ASR/data/lang_char ] && [ ! -d ./data/lang_char ]; then
log "Abort! Please run ../../reazonspeech/ASR/prepare.sh --stage 3 --stop-stage 3"
exit 1
fi
if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ] && [ ! -d ./data/lang_bpe_500 ]; then
log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 5 --stop-stage 5"
# Check if BPE data from MLS English exists
if [ ! -d ../../mls_english/ASR/data/lang/bpe_2000 ] || [ ! -f ../../mls_english/ASR/data/lang/transcript.txt ]; then
log "Abort! Please ensure ../../mls_english/ASR/data/lang/bpe_2000 and ../../mls_english/ASR/data/lang/transcript.txt exist."
log "Please run ../../mls_english/ASR/prepare.sh --stage 3 --stop-stage 3 if you haven't already."
exit 1
fi
cd data/
# if [ ! -d ./lang_char ]; then
# ln -svf $(realpath ../../../reazonspeech/ASR/data/lang_char) .
# fi
if [ ! -d ./lang_bpe_500 ]; then
ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) .
fi
cd ../
# Create the target lang directory if it doesn't exist
mkdir -p $lang_dir
# Combine Japanese char-level text and English BPE transcript
cat data/lang_char/text ../../mls_english/ASR/data/lang/transcript.txt \
> $lang_dir/text
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bbpe_${vocab_size}
mkdir -p $lang_dir
bbpe_dir=$lang_dir/bbpe_${vocab_size}
mkdir -p $bbpe_dir
cat data/lang_char/text data/lang_bpe_500/transcript_words.txt \
> $lang_dir/text
if [ ! -f $lang_dir/transcript_chars.txt ]; then
if [ ! -f $bbpe_dir/transcript_chars.txt ]; then
./local/prepare_for_bpe_model.py \
--lang-dir ./$lang_dir \
--lang-dir $bbpe_dir \
--text $lang_dir/text
fi
if [ ! -f $lang_dir/text_words_segmentation ]; then
if [ ! -f $bbpe_dir/text_words_segmentation ]; then
python3 ./local/text2segments.py \
--input-file ./data/lang_char/text \
--output-file $lang_dir/text_words_segmentation
cat ./data/lang_bpe_500/transcript_words.txt \
>> $lang_dir/text_words_segmentation
--output-file $bbpe_dir/text_words_segmentation
cat ../../mls_english/ASR/data/lang/transcript.txt \
>> $bbpe_dir/text_words_segmentation
fi
cat $lang_dir/text_words_segmentation | sed 's/ /\n/g' \
| sort -u | sed '/^$/d' | uniq > $lang_dir/words_no_ids.txt
if [ ! -f $bbpe_dir/words_no_ids.txt ]; then
cat $bbpe_dir/text_words_segmentation | sed 's/ /\n/g' \
| sort -u | sed '/^$/d' | uniq > $bbpe_dir/words_no_ids.txt
fi
if [ ! -f $lang_dir/words.txt ]; then
if [ ! -f $bbpe_dir/words.txt ]; then
python3 ./local/prepare_words.py \
--input-file $lang_dir/words_no_ids.txt \
--output-file $lang_dir/words.txt
--input-file $bbpe_dir/words_no_ids.txt \
--output-file $bbpe_dir/words.txt
fi
if [ ! -f $lang_dir/bbpe.model ]; then
if [ ! -f $bbpe_dir/bbpe.model ]; then
./local/train_bbpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/text
--transcript $lang_dir/text \
--output-model $bbpe_dir/bbpe.model \
--input-sentence-size 2000000 # Example: limit to 2 million sentences
fi
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bbpe.py --lang-dir $lang_dir
if [ ! -f $bbpe_dir/L_disambig.pt ]; then
./local/prepare_lang_bbpe.py --lang-dir $bbpe_dir --vocab-size $vocab_size
log "Validating $lang_dir/lexicon.txt"
ln -svf $(realpath ../../multi_zh_en/ASR/local/validate_bpe_lexicon.py) local/
log "Validating $bbpe_dir/lexicon.txt"
ln -svfr $(realpath ../../multi_zh_en/ASR/local/validate_bpe_lexicon.py) local/
./local/validate_bpe_lexicon.py \
--lexicon $lang_dir/lexicon.txt \
--bpe-model $lang_dir/bbpe.model
--lexicon $bbpe_dir/lexicon.txt \
--bpe-model $bbpe_dir/bbpe.model
fi
# Remove top-level files (if they were created)
rm -f $lang_dir/lexicon.txt $lang_dir/L_disambig.pt
done
# Optional symlink
if [ -d $lang_dir/bbpe_2000 ] && [ ! -e $lang_dir/bpe_2000 ]; then
ln -sr bbpe_2000 $lang_dir/bpe_2000
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Update cutset paths"
python local/utils/update_cutset_paths.py
fi
log "prepare.sh: PREPARATION DONE"

View File

@ -68,7 +68,7 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import ReazonSpeechAsrDataModule
from asr_datamodule import MultiDatasetAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
@ -157,14 +157,14 @@ def get_parser():
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bbpe_2000/bbpe.model",
default="data/lang/bbpe_2000/bbpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bbpe_2000",
default="data/lang/bbpe_2000",
help="The lang dir containing word table and LG graph",
)
@ -573,7 +573,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
ReazonSpeechAsrDataModule.add_arguments(parser)
MultiDatasetAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@ -748,7 +748,7 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
data_module = ReazonSpeechAsrDataModule(args)
multidataset_datamodule = MultiDatasetAsrDataModule(args)
multi_dataset = MultiDataset(args)
def remove_short_utt(c: Cut):
@ -759,31 +759,42 @@ def main():
)
return T > 0
test_sets_cuts = multi_dataset.test_cuts()
def tokenize_and_encode_text(c: Cut):
# Text normalize for each sample
text = c.supervisions[0].text
text = byte_encode(tokenize_by_ja_char(text))
c.supervisions[0].text = text
return c
test_sets = test_sets_cuts.keys()
test_dl = [
data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
for cuts_name in test_sets
]
test_cuts = multi_dataset.test_cuts()
test_cuts = test_cuts.filter(remove_short_utt)
# test_cuts = test_cuts.map(tokenize_and_encode_text)
for test_set, test_dl in zip(test_sets, test_dl):
logging.info(f"Start decoding test set: {test_set}")
test_dl = multidataset_datamodule.test_dataloaders(test_cuts)
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
# test_sets = test_sets_cuts.keys()
# test_dl = [
# data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
# for cuts_name in test_sets
# ]
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
# for test_set, test_dl in zip(test_sets, test_dl):
logging.info("Start decoding test set") #: {test_set}")
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name="test_set",
results_dict=results_dict,
)
logging.info("Done!")

View File

@ -57,7 +57,7 @@ import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import ReazonSpeechAsrDataModule
from asr_datamodule import MultiDatasetAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@ -1085,8 +1085,8 @@ def run(rank, world_size, args):
return True
reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
train_cuts = reazonspeech_corpus.train_cuts()
multidataset_datamodule = MultiDatasetAsrDataModule(args)
train_cuts = multidataset_datamodule.train_cuts()
train_cuts = train_cuts.filter(remove_short_and_long_utt)
@ -1097,12 +1097,12 @@ def run(rank, world_size, args):
else:
sampler_state_dict = None
train_dl = reazonspeech_corpus.train_dataloaders(
train_dl = multidataset_datamodule.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = reazonspeech_corpus.valid_cuts()
valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts)
valid_cuts = multidataset_datamodule.valid_cuts()
valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts)
if params.start_batch <= 0 and not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
@ -1242,7 +1242,7 @@ def scan_pessimistic_batches_for_oom(
def main():
raise RuntimeError("Please don't use this file directly!")
parser = get_parser()
ReazonSpeechAsrDataModule.add_arguments(parser)
MultiDatasetAsrDataModule.add_arguments(parser)
Tokenizer.add_arguments(parser)
args = parser.parse_args()

View File

@ -13,36 +13,36 @@ class MultiDataset:
Args:
manifest_dir:
It is expected to contain the following files:
- reazonspeech_cuts_train.jsonl.gz
- librispeech_cuts_train-clean-100.jsonl.gz
- librispeech_cuts_train-clean-360.jsonl.gz
- librispeech_cuts_train-other-500.jsonl.gz
- mls_english/
- mls_eng_cuts_train.jsonl.gz
- mls_eng_cuts_dev.jsonl.gz
- mls_eng_cuts_test.jsonl.gz
- reazonspeech/
- reazonspeech_cuts_train.jsonl.gz
- reazonspeech_cuts_dev.jsonl.gz
- reazonspeech_cuts_test.jsonl.gz
"""
self.fbank_dir = Path(args.manifest_dir)
self.manifest_dir = Path(args.manifest_dir)
def train_cuts(self) -> CutSet:
logging.info("About to get multidataset train cuts")
logging.info("Loading Reazonspeech in lazy mode")
reazonspeech_cuts = load_manifest_lazy(
self.fbank_dir / "reazonspeech_cuts_train.jsonl.gz"
logging.info("Loading Reazonspeech TRAIN set in lazy mode")
reazonspeech_train_cuts = load_manifest_lazy(
self.manifest_dir / "reazonspeech/reazonspeech_cuts_train.jsonl.gz"
)
logging.info("Loading LibriSpeech in lazy mode")
train_clean_100_cuts = self.train_clean_100_cuts()
train_clean_360_cuts = self.train_clean_360_cuts()
train_other_500_cuts = self.train_other_500_cuts()
logging.info("Loading MLS English TRAIN set in lazy mode")
mls_eng_train_cuts = load_manifest_lazy(
self.manifest_dir / "mls_english/mls_eng_cuts_train.jsonl.gz"
)
return CutSet.mux(
reazonspeech_cuts,
train_clean_100_cuts,
train_clean_360_cuts,
train_other_500_cuts,
reazonspeech_train_cuts,
mls_eng_train_cuts,
weights=[
len(reazonspeech_cuts),
len(train_clean_100_cuts),
len(train_clean_360_cuts),
len(train_other_500_cuts),
len(reazonspeech_train_cuts),
len(mls_eng_train_cuts),
],
)
@ -51,93 +51,90 @@ class MultiDataset:
logging.info("Loading Reazonspeech DEV set in lazy mode")
reazonspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "reazonspeech_cuts_dev.jsonl.gz"
self.manifest_dir / "reazonspeech/reazonspeech_cuts_dev.jsonl.gz"
)
logging.info("Loading LibriSpeech DEV set in lazy mode")
dev_clean_cuts = self.dev_clean_cuts()
dev_other_cuts = self.dev_other_cuts()
logging.info("Loading MLS English DEV set in lazy mode")
mls_eng_dev_cuts = load_manifest_lazy(
self.manifest_dir / "mls_english/mls_eng_cuts_dev.jsonl.gz"
)
return CutSet.mux(
reazonspeech_dev_cuts,
dev_clean_cuts,
dev_other_cuts,
mls_eng_dev_cuts,
weights=[
len(reazonspeech_dev_cuts),
len(dev_clean_cuts),
len(dev_other_cuts),
len(mls_eng_dev_cuts),
],
)
def test_cuts(self) -> Dict[str, CutSet]:
def test_cuts(self) -> CutSet:
logging.info("About to get multidataset test cuts")
logging.info("Loading Reazonspeech set in lazy mode")
logging.info("Loading Reazonspeech TEST set in lazy mode")
reazonspeech_test_cuts = load_manifest_lazy(
self.fbank_dir / "reazonspeech_cuts_test.jsonl.gz"
)
reazonspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "reazonspeech_cuts_dev.jsonl.gz"
self.manifest_dir / "reazonspeech/reazonspeech_cuts_test.jsonl.gz"
)
logging.info("Loading LibriSpeech set in lazy mode")
test_clean_cuts = self.test_clean_cuts()
test_other_cuts = self.test_other_cuts()
test_cuts = {
"reazonspeech_test": reazonspeech_test_cuts,
"reazonspeech_dev": reazonspeech_dev_cuts,
"librispeech_test_clean": test_clean_cuts,
"librispeech_test_other": test_other_cuts,
}
return test_cuts
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
return load_manifest_lazy(
self.fbank_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
logging.info("Loading MLS English TEST set in lazy mode")
mls_eng_test_cuts = load_manifest_lazy(
self.manifest_dir / "mls_english/mls_eng_cuts_test.jsonl.gz"
)
@lru_cache()
def train_clean_360_cuts(self) -> CutSet:
logging.info("About to get train-clean-360 cuts")
return load_manifest_lazy(
self.fbank_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
return CutSet.mux(
reazonspeech_test_cuts,
mls_eng_test_cuts,
weights=[
len(reazonspeech_test_cuts),
len(mls_eng_test_cuts),
],
)
@lru_cache()
def train_other_500_cuts(self) -> CutSet:
logging.info("About to get train-other-500 cuts")
return load_manifest_lazy(
self.fbank_dir / "librispeech_cuts_train-other-500.jsonl.gz"
)
# @lru_cache()
# def train_clean_100_cuts(self) -> CutSet:
# logging.info("About to get train-clean-100 cuts")
# return load_manifest_lazy(
# self.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
# )
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest_lazy(
self.fbank_dir / "librispeech_cuts_dev-clean.jsonl.gz"
)
# @lru_cache()
# def train_clean_360_cuts(self) -> CutSet:
# logging.info("About to get train-clean-360 cuts")
# return load_manifest_lazy(
# self.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
# )
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest_lazy(
self.fbank_dir / "librispeech_cuts_dev-other.jsonl.gz"
)
# @lru_cache()
# def train_other_500_cuts(self) -> CutSet:
# logging.info("About to get train-other-500 cuts")
# return load_manifest_lazy(
# self.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
# )
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest_lazy(
self.fbank_dir / "librispeech_cuts_test-clean.jsonl.gz"
)
# @lru_cache()
# def dev_clean_cuts(self) -> CutSet:
# logging.info("About to get dev-clean cuts")
# return load_manifest_lazy(
# self.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
# )
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest_lazy(
self.fbank_dir / "librispeech_cuts_test-other.jsonl.gz"
)
# @lru_cache()
# def dev_other_cuts(self) -> CutSet:
# logging.info("About to get dev-other cuts")
# return load_manifest_lazy(
# self.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
# )
# @lru_cache()
# def test_clean_cuts(self) -> CutSet:
# logging.info("About to get test-clean cuts")
# return load_manifest_lazy(
# self.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
# )
# @lru_cache()
# def test_other_cuts(self) -> CutSet:
# logging.info("About to get test-other cuts")
# return load_manifest_lazy(
# self.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
# )

View File

@ -63,7 +63,7 @@ import k2
import numpy as np
import sentencepiece as spm
import torch
from asr_datamodule import ReazonSpeechAsrDataModule
from asr_datamodule import MultiDatasetAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
@ -740,7 +740,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
ReazonSpeechAsrDataModule.add_arguments(parser)
MultiDatasetAsrDataModule.add_arguments(parser)
Tokenizer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@ -887,7 +887,7 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
multidataset_datamodule = MultiDatasetAsrDataModule(args)
if params.bilingual:
multi_dataset = MultiDataset(args)
@ -904,8 +904,8 @@ def main():
test_sets = test_sets_cuts.keys()
test_cuts = [test_sets_cuts[k] for k in test_sets]
valid_cuts = reazonspeech_corpus.valid_cuts()
test_cuts = reazonspeech_corpus.test_cuts()
valid_cuts = multidataset_datamodule.valid_cuts()
test_cuts = multidataset_datamodule.test_cuts()
test_sets = ["valid", "test"]
test_cuts = [valid_cuts, test_cuts]

View File

@ -25,7 +25,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
# For non-streaming model training:
./zipformer/train.py \
--bilingual 1 \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
@ -35,7 +34,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
# For streaming model training:
./zipformer/train.py \
--bilingual 1 \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
@ -50,6 +48,7 @@ It supports training with:
- transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
"""
import argparse
import copy
import logging
@ -66,7 +65,7 @@ import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import ReazonSpeechAsrDataModule
from asr_datamodule import MultiDatasetAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@ -77,7 +76,6 @@ from multi_dataset import MultiDataset
from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from tokenizer import Tokenizer
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@ -269,13 +267,6 @@ def get_parser():
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--bilingual",
type=str2bool,
default=False,
help="Whether the model is bilingual or not. 1 = bilingual.",
)
parser.add_argument(
"--world-size",
type=int,
@ -333,11 +324,10 @@ def get_parser():
""",
)
# changed - not used in monolingual streaming
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bbpe_2000/bbpe.model",
default="data/lang/bbpe_2000/bbpe.model",
help="Path to the BPE model",
)
@ -763,11 +753,9 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename)
# fix implementation for sentencepiece_processor: spm.SentencePieceProcessor, stuff
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
tokenizer: Tokenizer,
sentencepiece_processor: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
@ -803,10 +791,7 @@ def compute_loss(
warm_step = params.warm_step
texts = batch["supervisions"]["text"]
if not params.bilingual:
y = tokenizer.encode(texts, out_type=int)
else:
y = sentencepiece_processor.encode(texts, out_type=int)
y = sentencepiece_processor.encode(texts, out_type=int)
y = k2.RaggedTensor(y)
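# Shape note (illustrative): with out_type=int, encode() maps the list of
# transcripts to a list of token-id lists, e.g. ["hello world"] -> [[21, 87, 5]]
# (ids made up); k2.RaggedTensor then wraps that ragged list-of-lists so the
# per-utterance token lengths are preserved when the loss is computed.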
with torch.set_grad_enabled(is_training):
@ -862,7 +847,6 @@ def compute_loss(
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
tokenizer: Tokenizer,
sentencepiece_processor: spm.SentencePieceProcessor,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
@ -876,7 +860,6 @@ def compute_validation_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
batch=batch,
is_training=False,
@ -900,7 +883,6 @@ def train_one_epoch(
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
tokenizer: Tokenizer,
sentencepiece_processor: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
@ -972,7 +954,6 @@ def train_one_epoch(
loss, loss_info = compute_loss(
params=params,
model=model,
tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
batch=batch,
is_training=True,
@ -993,7 +974,6 @@ def train_one_epoch(
display_and_save_batch(
batch,
params=params,
tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
)
raise
@ -1082,7 +1062,6 @@ def train_one_epoch(
valid_info = compute_validation_loss(
params=params,
model=model,
tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
valid_dl=valid_dl,
world_size=world_size,
@ -1136,25 +1115,12 @@ def run(rank, world_size, args):
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
# Use lang_dir for further operations
# tokenizer = Tokenizer.load(args.lang, args.lang_type)
# sentencepiece_processor = spm.SentencePieceProcessor()
# sentencepiece_processor.load(params.bpe_model)
tokenizer = None
sentencepiece_processor = None
sentencepiece_processor = spm.SentencePieceProcessor()
sentencepiece_processor.load(params.bpe_model)
# <blk> is defined in local/prepare_lang_char.py
if not params.bilingual:
tokenizer = Tokenizer.load(args.lang, args.lang_type)
params.blank_id = tokenizer.piece_to_id("<blk>")
params.vocab_size = tokenizer.get_piece_size()
else:
sentencepiece_processor = spm.SentencePieceProcessor()
sentencepiece_processor.load(params.bpe_model)
params.blank_id = sentencepiece_processor.piece_to_id("<blk>")
params.vocab_size = sentencepiece_processor.get_piece_size()
params.blank_id = sentencepiece_processor.piece_to_id("<blk>")
params.vocab_size = sentencepiece_processor.get_piece_size()
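# Illustrative values only: with the bbpe_2000 model referenced by the default
# --bpe-model path above, get_piece_size() would return 2000, and "<blk>" is
# conventionally id 0 in icefall-trained BPE models; the actual values depend
# on the trained model.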
if not params.use_transducer:
params.ctc_loss_scale = 1.0
@ -1212,27 +1178,24 @@ def run(rank, world_size, args):
if params.inf_check:
register_inf_check_hooks(model)
reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
if params.bilingual:
multi_dataset = MultiDataset(args)
train_cuts = multi_dataset.train_cuts()
else:
train_cuts = reazonspeech_corpus.train_cuts()
multidataset_datamodule = MultiDatasetAsrDataModule(args)
multi_dataset = MultiDataset(args)
train_cuts = multi_dataset.train_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
# Keep only utterances longer than 1 second
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
# if c.duration < 1.0 or c.duration > 30.0:
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
# return False
# the threshold, as it depends on which datasets you choose
if c.duration < 1.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
)
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
@ -1240,18 +1203,13 @@ def run(rank, world_size, args):
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_samples - 7) // 2 + 1) // 2
if not params.bilingual:
tokens = tokenizer.encode(c.supervisions[0].text, out_type=str)
else:
tokens = sentencepiece_processor.encode(
c.supervisions[0].text, out_type=str
)
T = ((c.num_frames - 7) // 2 + 1) // 2
tokens = sentencepiece_processor.encode(c.supervisions[0].text, out_type=str)
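# Worked example (illustrative numbers, assuming the default 10 ms frame
# shift, i.e. 100 frames/s): a 2.3 s cut has c.num_frames == 230, so
# T = ((230 - 7) // 2 + 1) // 2 == 56, and the cut is kept only if its
# token sequence is at most 56 tokens long (i.e. T >= S).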
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_samples}. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
@ -1270,8 +1228,7 @@ def run(rank, world_size, args):
train_cuts = train_cuts.filter(remove_short_and_long_utt)
if params.bilingual:
train_cuts = train_cuts.map(tokenize_and_encode_text)
train_cuts = train_cuts.map(tokenize_and_encode_text)
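# Note: assuming the manifests are loaded lazily (as elsewhere in this recipe),
# both .filter() and .map() above are applied on the fly as the sampler draws
# cuts, rather than materializing the filtered/encoded training set up front.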
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
@ -1280,22 +1237,19 @@ def run(rank, world_size, args):
else:
sampler_state_dict = None
train_dl = reazonspeech_corpus.train_dataloaders(
train_dl = multidataset_datamodule.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
if params.bilingual:
valid_cuts = reazonspeech_corpus.valid_cuts()
else:
valid_cuts = multi_dataset.dev_cuts()
valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts)
valid_cuts = multi_dataset.dev_cuts()
valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
params=params,
)
@ -1321,7 +1275,6 @@ def run(rank, world_size, args):
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
train_dl=train_dl,
valid_dl=valid_dl,
@ -1356,7 +1309,6 @@ def run(rank, world_size, args):
def display_and_save_batch(
batch: dict,
params: AttributeDict,
tokenizer: Tokenizer,
sentencepiece_processor: spm.SentencePieceProcessor,
) -> None:
"""Display the batch statistics and save the batch into disk.
@ -1367,10 +1319,8 @@ def display_and_save_batch(
for the content in it.
params:
Parameters for training. See :func:`get_params`.
tokenizer:
The BPE Tokenizer model.
sentencepiece_processor:
The BPE SentencePieceProcessor model.
The BPE model.
"""
from lhotse.utils import uuid4
@ -1382,11 +1332,7 @@ def display_and_save_batch(
features = batch["inputs"]
logging.info(f"features shape: {features.shape}")
if params.bilingual:
y = sentencepiece_processor.encode(supervisions["text"], out_type=int)
else:
y = tokenizer.encode(supervisions["text"], out_type=int)
y = sentencepiece_processor.encode(supervisions["text"], out_type=int)
num_tokens = sum(len(i) for i in y)
logging.info(f"num tokens: {num_tokens}")
@ -1395,7 +1341,6 @@ def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
tokenizer: Tokenizer,
sentencepiece_processor: spm.SentencePieceProcessor,
params: AttributeDict,
):
@ -1412,7 +1357,6 @@ def scan_pessimistic_batches_for_oom(
loss, _ = compute_loss(
params=params,
model=model,
tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
batch=batch,
is_training=True,
@ -1431,7 +1375,6 @@ def scan_pessimistic_batches_for_oom(
display_and_save_batch(
batch,
params=params,
tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
)
raise
@ -1442,8 +1385,7 @@ def scan_pessimistic_batches_for_oom(
def main():
parser = get_parser()
ReazonSpeechAsrDataModule.add_arguments(parser)
Tokenizer.add_arguments(parser)
MultiDatasetAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)

View File

@ -94,12 +94,14 @@ def compute_fbank_musan(
logging.info("Extracting features for Musan")
if whisper_fbank:
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
logging.warning("CUDA not available; using WhisperFbank on CPU.")
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
WhisperFbankConfig(num_filters=num_mel_bins, device=device)
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
# create chunks of Musan with duration 5 - 10 seconds
musan_cuts = (