From 7cbd6d11bab050da6f832ba4a81c63c2551542f5 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 16 Feb 2022 12:27:48 +0800
Subject: [PATCH] Finish preparing training datasets.

---
 egs/librispeech/ASR/prepare_giga_speech.sh | 109 ++++++++++++++++++
 .../{dataset.py => asr_datamodule.py}      |  84 ++++++++++----
 .../gigaspeech.py                          |  44 ++++---
 .../librispeech.py                         |   2 +-
 .../test_asr_datamodule.py                 | 103 +++++++++++++++++
 5 files changed, 305 insertions(+), 37 deletions(-)
 create mode 100755 egs/librispeech/ASR/prepare_giga_speech.sh
 rename egs/librispeech/ASR/transducer_stateless_multi_datasets/{dataset.py => asr_datamodule.py} (75%)
 create mode 100755 egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py

diff --git a/egs/librispeech/ASR/prepare_giga_speech.sh b/egs/librispeech/ASR/prepare_giga_speech.sh
new file mode 100755
index 000000000..49124c4d7
--- /dev/null
+++ b/egs/librispeech/ASR/prepare_giga_speech.sh
@@ -0,0 +1,109 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+nj=15
+stage=-1
+stop_stage=100
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+#  - $dl_dir/GigaSpeech
+#      You can find audio, dict, GigaSpeech.json inside it.
+#      You can apply for the download credentials by following
+#      https://github.com/SpeechColab/GigaSpeech#download
+
+# Number of hours for GigaSpeech subsets
+# XL   10k hours
+# L    2.5k hours
+# M    1k hours
+# S    250 hours
+# XS   10 hours
+# DEV  12 hours
+# TEST 40 hours
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download data"
+
+  [ ! -e $dl_dir/GigaSpeech ] && mkdir -p $dl_dir/GigaSpeech
+
+  # If you have pre-downloaded it to /path/to/GigaSpeech,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/GigaSpeech $dl_dir/GigaSpeech
+  #
+  if [ ! -d $dl_dir/GigaSpeech/audio ] && [ ! -f $dl_dir/GigaSpeech/GigaSpeech.json ]; then
+    # Check credentials.
+    if [ ! -f $dl_dir/password ]; then
+      echo -n "$0: Please apply for the download credentials by following "
+      echo -n "https://github.com/SpeechColab/GigaSpeech#dataset-download"
+      echo " and save the password to $dl_dir/password."
+      exit 1;
+    fi
+    PASSWORD=`cat $dl_dir/password 2>/dev/null`
+    if [ -z "$PASSWORD" ]; then
+      echo "$0: Error, $dl_dir/password is empty."
+      exit 1;
+    fi
+    PASSWORD_MD5=`echo $PASSWORD | md5sum | cut -d ' ' -f 1`
+    if [[ $PASSWORD_MD5 != "dfbf0cde1a3ce23749d8d81e492741b8" ]]; then
+      echo "$0: Error, invalid $dl_dir/password."
+      exit 1;
+    fi
+    # Download all subsets: XL, L, M, S, XS, DEV and TEST.
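+    # (XL alone is 10k hours of audio, per the table at the top of this
+    # script, so make sure $dl_dir has ample free disk space first.)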
+    lhotse download gigaspeech \
+      --subset XL \
+      --subset L \
+      --subset M \
+      --subset S \
+      --subset XS \
+      --subset DEV \
+      --subset TEST \
+      --host tsinghua \
+      $dl_dir/password $dl_dir/GigaSpeech
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare GigaSpeech manifests (may take 30 minutes)"
+  # We assume that you have downloaded the GigaSpeech corpus
+  # to $dl_dir/GigaSpeech
+  mkdir -p data/manifests
+  lhotse prepare gigaspeech \
+    --subset XL \
+    --subset L \
+    --subset M \
+    --subset S \
+    --subset XS \
+    --subset DEV \
+    --subset TEST \
+    -j $nj \
+    $dl_dir/GigaSpeech data/manifests
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Preprocess GigaSpeech manifest"
+  if [ ! -f data/fbank/.preprocess_complete ]; then
+    log "This stage may take about 2 hours"
+    python3 ./local/preprocess_gigaspeech.py
+    touch data/fbank/.preprocess_complete
+  fi
+fi
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/dataset.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py
similarity index 75%
rename from egs/librispeech/ASR/transducer_stateless_multi_datasets/dataset.py
rename to egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py
index 59da11027..16daf2f1b 100644
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/dataset.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py
@@ -16,12 +16,28 @@
 # limitations under the License.
 
 import argparse
+import logging
+from pathlib import Path
+from typing import Optional
+
+from lhotse import CutSet, Fbank, FbankConfig
+from lhotse.dataset import (
+    BucketingSampler,
+    CutMix,
+    DynamicBucketingSampler,
+    K2SpeechRecognitionDataset,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import (
+    OnTheFlyFeatures,
+    PrecomputedFeatures,
+)
+from torch.utils.data import DataLoader
 
-from lhotse import CutSet
 from icefall.utils import str2bool
 
 
-class AsrDataset:
+class AsrDataModule:
     def __init__(self, args: argparse.Namespace):
         self.args = args
 
@@ -55,19 +71,11 @@ class AsrDataset:
             "--num-buckets",
             type=int,
             default=30,
-            help="The number of buckets for the BucketingSampler"
+            help="The number of buckets for the BucketingSampler "
+            "and DynamicBucketingSampler. "
             "(you might want to increase it for larger datasets).",
         )
 
-        group.add_argument(
-            "--on-the-fly-feats",
-            type=str2bool,
-            default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
-        )
-
         group.add_argument(
             "--shuffle",
             type=str2bool,
@@ -126,8 +134,25 @@ class AsrDataset:
         )
 
     def train_dataloaders(
-        self, cuts_train: CutSet, cuts_musan: Optional[CutSet] = None
+        self,
+        cuts_train: CutSet,
+        dynamic_bucketing: bool,
+        on_the_fly_feats: bool,
+        cuts_musan: Optional[CutSet] = None,
     ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            Cuts for training.
+          dynamic_bucketing:
+            True to use DynamicBucketingSampler;
+            False to use BucketingSampler.
+          on_the_fly_feats:
+            True to use OnTheFlyFeatures;
+            False to use PrecomputedFeatures.
+          cuts_musan:
+            If not None, it is the cuts for MUSAN noise mixing.
+        """
         transforms = []
         if cuts_musan is not None:
             logging.info("Enable MUSAN")
@@ -177,21 +202,34 @@
             # Drop feats to be on the safe side.
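+        # Note: PrecomputedFeatures reads fbank features stored with the
+        # manifests, while OnTheFlyFeatures extracts 80-dim fbank from the
+        # audio at training time (the raw GigaSpeech cuts are consumed this
+        # way; see test_asr_datamodule.py).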
         train = K2SpeechRecognitionDataset(
             cut_transforms=transforms,
-            input_strategy=OnTheFlyFeatures(
-                Fbank(FbankConfig(num_mel_bins=80))
+            input_strategy=(
+                OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+                if on_the_fly_feats
+                else PrecomputedFeatures()
             ),
             input_transforms=input_transforms,
             return_cuts=self.args.return_cuts,
         )
 
-        logging.info("Using DynamicBucketingSampler.")
-        train_sampler = DynamicBucketingSampler(
-            cuts_train,
-            max_duration=self.args.max_duration,
-            shuffle=self.args.shuffle,
-            num_buckets=self.args.num_buckets,
-            drop_last=True,
-        )
+        if dynamic_bucketing:
+            logging.info("Using DynamicBucketingSampler.")
+            train_sampler = DynamicBucketingSampler(
+                cuts_train,
+                max_duration=self.args.max_duration,
+                shuffle=self.args.shuffle,
+                num_buckets=self.args.num_buckets,
+                drop_last=True,
+            )
+        else:
+            logging.info("Using BucketingSampler.")
+            train_sampler = BucketingSampler(
+                cuts_train,
+                max_duration=self.args.max_duration,
+                shuffle=self.args.shuffle,
+                num_buckets=self.args.num_buckets,
+                bucket_method="equal_duration",
+                drop_last=True,
+            )
 
         logging.info("About to create train dataloader")
         train_dl = DataLoader(
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/gigaspeech.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/gigaspeech.py
index cd358c416..286771d7d 100644
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/gigaspeech.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/gigaspeech.py
@@ -17,7 +17,7 @@
 
 import logging
-from typing import Path
+from pathlib import Path
 
 from lhotse import CutSet, load_manifest
 
@@ -29,29 +29,47 @@ class GigaSpeech:
           manifest_dir:
             It is expected to contain the following files::
 
-                - cuts_L.jsonl.gz
-                - cuts_XL.jsonl.gz
-                - cuts_TEST.jsonl.gz
-                - cuts_DEV.jsonl.gz
+                - cuts_XL_raw.jsonl.gz
+                - cuts_L_raw.jsonl.gz
+                - cuts_M_raw.jsonl.gz
+                - cuts_S_raw.jsonl.gz
+                - cuts_XS_raw.jsonl.gz
+                - cuts_DEV.jsonl.gz
+                - cuts_TEST.jsonl.gz
         """
         self.manifest_dir = Path(manifest_dir)
 
-    def train_L_cuts(self) -> CutSet:
-        f = self.manifest_dir / "cuts_L.json.gz"
-        logging.info(f"About to get train-L cuts from {f}")
-        return CutSet.from_jsonl_lazy(f)
-
     def train_XL_cuts(self) -> CutSet:
-        f = self.manifest_dir / "cuts_XL.json.gz"
+        f = self.manifest_dir / "cuts_XL_raw.jsonl.gz"
         logging.info(f"About to get train-XL cuts from {f}")
         return CutSet.from_jsonl_lazy(f)
 
+    def train_L_cuts(self) -> CutSet:
+        f = self.manifest_dir / "cuts_L_raw.jsonl.gz"
+        logging.info(f"About to get train-L cuts from {f}")
+        return CutSet.from_jsonl_lazy(f)
+
+    def train_M_cuts(self) -> CutSet:
+        f = self.manifest_dir / "cuts_M_raw.jsonl.gz"
+        logging.info(f"About to get train-M cuts from {f}")
+        return CutSet.from_jsonl_lazy(f)
+
+    def train_S_cuts(self) -> CutSet:
+        f = self.manifest_dir / "cuts_S_raw.jsonl.gz"
+        logging.info(f"About to get train-S cuts from {f}")
+        return CutSet.from_jsonl_lazy(f)
+
+    def train_XS_cuts(self) -> CutSet:
+        f = self.manifest_dir / "cuts_XS_raw.jsonl.gz"
+        logging.info(f"About to get train-XS cuts from {f}")
+        return CutSet.from_jsonl_lazy(f)
+
     def test_cuts(self) -> CutSet:
-        f = self.manifest_dir / "cuts_TEST.json.gz"
+        f = self.manifest_dir / "cuts_TEST.jsonl.gz"
         logging.info(f"About to get TEST cuts from {f}")
         return load_manifest(f)
 
     def dev_cuts(self) -> CutSet:
-        f = self.manifest_dir / "cuts_DEV.json.gz"
+        f = self.manifest_dir / "cuts_DEV.jsonl.gz"
         logging.info(f"About to get DEV cuts from {f}")
         return load_manifest(f)
diff --git 
a/egs/librispeech/ASR/transducer_stateless_multi_datasets/librispeech.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/librispeech.py index f6f30aa0c..ecffcf9ff 100644 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/librispeech.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/librispeech.py @@ -16,7 +16,7 @@ # limitations under the License. import logging -from typing import Path +from pathlib import Path from lhotse import CutSet, load_manifest diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py new file mode 100755 index 000000000..54f152a88 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./transducer_stateless_multi_datasets/test_asr_datamodule.py +""" + +import argparse +import random +from pathlib import Path + +from asr_datamodule import AsrDataModule +from gigaspeech import GigaSpeech +from lhotse import load_manifest +from librispeech import LibriSpeech + + +def test_dataset(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + print(args) + + if args.enable_musan: + cuts_musan = load_manifest( + Path(args.manifest_dir) / "cuts_musan.json.gz" + ) + else: + cuts_musan = None + + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) + gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir) + + train_clean_100 = librispeech.train_clean_100_cuts() + train_S = gigaspeech.train_S_cuts() + + asr_datamodule = AsrDataModule(args) + + libri_train_dl = asr_datamodule.train_dataloaders( + train_clean_100, + dynamic_bucketing=False, + on_the_fly_feats=False, + cuts_musan=cuts_musan, + ) + + giga_train_dl = asr_datamodule.train_dataloaders( + train_S, + dynamic_bucketing=True, + on_the_fly_feats=True, + cuts_musan=cuts_musan, + ) + + seed = 20220216 + rng = random.Random(seed) + + for epoch in range(2): + print("epoch", epoch) + batch_idx = 0 + libri_train_dl.sampler.set_epoch(epoch) + giga_train_dl.sampler.set_epoch(epoch) + + iter_libri = iter(libri_train_dl) + iter_giga = iter(giga_train_dl) + while True: + idx = rng.choices((0, 1), weights=[0.8, 0.2], k=1)[0] + dl = iter_libri if idx == 0 else iter_giga + batch_idx += 1 + + print("dl idx", idx, "batch_idx", batch_idx) + batch = next(dl) + cuts = batch["supervisions"]["cut"] + for c in cuts: + print(c.id) + + if batch_idx > 10: + break + + +def main(): + test_dataset() + + +if __name__ == "__main__": + main()
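
A note on the interleaving loop in test_asr_datamodule.py: next() raises
StopIteration once a dataloader is exhausted. The test stops after 11
batches, so it never hits that case, but a full training loop that mixes
the two corpora the same way needs a guard. Below is a minimal sketch of
such a guard (the helper name mixed_batches and its arguments are
illustrative, not part of this patch):

    import random
    from typing import Iterator

    def mixed_batches(
        iter_libri: Iterator, iter_giga: Iterator, rng: random.Random
    ):
        """Yield batches, drawing from LibriSpeech with probability 0.8
        and from GigaSpeech with probability 0.2, until either corpus
        is exhausted."""
        while True:
            idx = rng.choices((0, 1), weights=(0.8, 0.2), k=1)[0]
            source = iter_libri if idx == 0 else iter_giga
            try:
                yield next(source)
            except StopIteration:
                # One corpus ran out; end the epoch.
                return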