diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index b26034eb2..5c33ff8be 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -28,7 +28,7 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer +from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor @@ -85,7 +85,7 @@ def compute_fbank_librispeech(): # when an executor is specified, make more partitions num_jobs=num_jobs if ex is None else 80, executor=ex, - storage_type=LilcomHdf5Writer, + storage_type=ChunkedLilcomHdf5Writer, ) cut_set.to_json(output_dir / f"cuts_{partition}.json.gz") diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index d44524e70..f5911746b 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -28,7 +28,7 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer, combine +from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig, combine from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor @@ -82,7 +82,7 @@ def compute_fbank_musan(): storage_path=f"{output_dir}/feats_musan", num_jobs=num_jobs if ex is None else 80, executor=ex, - storage_type=LilcomHdf5Writer, + storage_type=ChunkedLilcomHdf5Writer, ) ) musan_cuts.to_json(musan_cuts_path) diff --git a/egs/librispeech/ASR/local/preprocess_gigaspeech.py b/egs/librispeech/ASR/local/preprocess_gigaspeech.py new file mode 100644 index 000000000..4168a7185 --- /dev/null +++ b/egs/librispeech/ASR/local/preprocess_gigaspeech.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +# Copyright 2021 Johns Hopkins University (Piotr Żelasko) +# Copyright 2021 Xiaomi Corp. (Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +from pathlib import Path + +from lhotse import CutSet, SupervisionSegment +from lhotse.recipes.utils import read_manifests_if_cached + +# Similar text filtering and normalization procedure as in: +# https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh + + +def normalize_text( + utt: str, + punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"), + whitespace_pattern=re.compile(r"\s\s+"), +) -> str: + return whitespace_pattern.sub(" ", punct_pattern.sub("", utt)) + + +def has_no_oov( + sup: SupervisionSegment, + oov_pattern=re.compile(r"<(SIL|MUSIC|NOISE|OTHER)>"), +) -> bool: + return oov_pattern.search(sup.text) is None + + +def preprocess_giga_speech(): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + output_dir.mkdir(exist_ok=True) + + dataset_parts = ( + "DEV", + "TEST", + "XS", + "S", + "M", + "L", + "XL", + ) + + logging.info("Loading manifest (may take 4 minutes)") + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix="gigaspeech", + suffix="jsonl.gz", + ) + assert manifests is not None + + for partition, m in manifests.items(): + logging.info(f"Processing {partition}") + raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz" + if raw_cuts_path.is_file(): + logging.info(f"{partition} already exists - skipping") + continue + + # Note this step makes the recipe different than LibriSpeech: + # We must filter out some utterances and remove punctuation + # to be consistent with Kaldi. + logging.info("Filtering OOV utterances from supervisions") + m["supervisions"] = m["supervisions"].filter(has_no_oov) + logging.info(f"Normalizing text in {partition}") + for sup in m["supervisions"]: + sup.text = normalize_text(sup.text) + sup.custom = {"origin": "giga"} + + # Create long-recording cut manifests. + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + # Run data augmentation that needs to be done in the + # time domain. + if partition not in ["DEV", "TEST"]: + logging.info( + f"Speed perturb for {partition} with factors 0.9 and 1.1 " + "(Perturbing may take 8 minutes and saving may take 20 minutes)" + ) + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) + + logging.info("About to split cuts into smaller chunks.") + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, min_duration=None + ) + logging.info(f"Saving to {raw_cuts_path}") + cut_set.to_file(raw_cuts_path) + + +def main(): + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + logging.basicConfig(format=formatter, level=logging.INFO) + + preprocess_giga_speech() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/__init__.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/dataset.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/dataset.py new file mode 100644 index 000000000..59da11027 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/dataset.py @@ -0,0 +1,204 @@ +# Copyright 2021 Piotr Żelasko +# 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from lhotse import CutSet +from icefall.utils import str2bool + + +class AsrDataset: + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + + def train_dataloaders( + self, cuts_train: CutSet, cuts_musan: Optional[CutSet] = None + ) -> DataLoader: + transforms = [] + if cuts_musan is not None: + logging.info("Enable MUSAN") + transforms.append( + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) + ) + else: + logging.info("Disable MUSAN") + + input_transforms = [] + + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=2, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=True, + ) + + logging.info("About to create train dataloader") + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + return train_dl diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/gigaspeech.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/gigaspeech.py new file mode 100644 index 000000000..cd358c416 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/gigaspeech.py @@ -0,0 +1,57 @@ +# Copyright 2021 Piotr Żelasko +# 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from typing import Path + +from lhotse import CutSet, load_manifest + + +class GigaSpeech: + def __init__(self, manifest_dir: str): + """ + Args: + manifest_dir: + It is expected to contain the following files:: + + - cuts_L.jsonl.gz + - cuts_XL.jsonl.gz + - cuts_TEST.jsonl.gz + - cuts_DEV.jsonl.gz + """ + self.manifest_dir = Path(manifest_dir) + + def train_L_cuts(self) -> CutSet: + f = self.manifest_dir / "cuts_L.json.gz" + logging.info(f"About to get train-L cuts from {f}") + return CutSet.from_jsonl_lazy(f) + + def train_XL_cuts(self) -> CutSet: + f = self.manifest_dir / "cuts_XL.json.gz" + logging.info(f"About to get train-XL cuts from {f}") + return CutSet.from_jsonl_lazy(f) + + def test_cuts(self) -> CutSet: + f = self.manifest_dir / "cuts_TEST.json.gz" + logging.info(f"About to get TEST cuts from {f}") + return load_manifest(f) + + def dev_cuts(self) -> CutSet: + f = self.manifest_dir / "cuts_DEV.json.gz" + logging.info(f"About to get DEV cuts from {f}") + return load_manifest(f) diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/librispeech.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/librispeech.py new file mode 100644 index 000000000..f6f30aa0c --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/librispeech.py @@ -0,0 +1,74 @@ +# Copyright 2021 Piotr Żelasko +# 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Path + +from lhotse import CutSet, load_manifest + + +class LibriSpeech: + def __init__(self, manifest_dir: str): + """ + Args: + manifest_dir: + It is expected to contain the following files:: + + - cuts_dev-clean.json.gz + - cuts_dev-other.json.gz + - cuts_test-clean.json.gz + - cuts_test-other.json.gz + - cuts_train-clean-100.json.gz + - cuts_train-clean-360.json.gz + - cuts_train-other-500.json.gz + """ + self.manifest_dir = Path(manifest_dir) + + def train_clean_100_cuts(self) -> CutSet: + f = self.manifest_dir / "cuts_train-clean-100.json.gz" + logging.info(f"About to get train-clean-100 cuts from {f}") + return load_manifest(f) + + def train_clean_360_cuts(self) -> CutSet: + f = self.manifest_dir / "cuts_train-clean-360.json.gz" + logging.info(f"About to get train-clean-360 cuts from {f}") + return load_manifest(f) + + def train_other_500_cuts(self) -> CutSet: + f = self.args.manifest_dir / "cuts_train-other-500.json.gz" + logging.info(f"About to get train-other-500 cuts from {f}") + return load_manifest(f) + + def test_clean_cuts(self) -> CutSet: + f = self.manifest_dir / "cuts_test-clean.json.gz" + logging.info(f"About to get test-clean cuts from {f}") + return load_manifest(f) + + def test_other_cuts(self) -> CutSet: + f = self.manifest_dir / "cuts_test-other.json.gz" + logging.info(f"About to get test-other cuts from {f}") + return load_manifest(f) + + def dev_clean_cuts(self) -> CutSet: + f = self.manifest_dir / "cuts_dev-clean.json.gz" + logging.info(f"About to get dev-clean cuts from {f}") + return load_manifest(f) + + def dev_other_cuts(self) -> CutSet: + f = self.manifest_dir / "cuts_dev-other.json.gz" + logging.info(f"About to get dev-other cuts from {f}") + return load_manifest(f)