# Copyright      2021  Piotr Żelasko
# Copyright      2022  Xiaomi Corp.  (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
from pathlib import Path
from typing import Optional

from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import (
    CutMix,
    DynamicBucketingSampler,
    K2SpeechRecognitionDataset,
    SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
from torch.utils.data import DataLoader

from icefall.utils import str2bool


class AsrDataModule:
    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="ASR data related options",
            description=(
                "These options are used for the preparation of "
                "PyTorch DataLoaders from Lhotse CutSets -- they control the "
                "effective batch sizes, sampling strategies, applied data "
                "augmentations, etc."
            ),
        )
        group.add_argument(
            "--max-duration",
            type=int,
            default=200,
            help=(
                "Maximum pooled recordings duration (seconds) in a "
                "single batch. You can reduce it if it causes CUDA OOM."
            ),
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help=(
                "When enabled, the batches will come from buckets of "
                "similar duration (saves padding frames)."
            ),
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help=(
                "The number of buckets for the DynamicBucketingSampler "
                "(you might want to increase it for larger datasets)."
            ),
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help=(
                "When enabled (=default), the examples will be "
                "shuffled for each epoch."
            ),
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help=(
                "When enabled, each batch will have the "
                "field: batch['supervisions']['cut'] with the cuts that "
                "were used to construct it."
            ),
        )
        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help=(
                "The number of training dataloader workers that "
                "collect the batches."
            ),
        )
        group.add_argument(
            "--on-the-fly-num-workers",
            type=int,
            default=0,
            help="The number of workers for on-the-fly feature extraction.",
        )
        group.add_argument(
            "--enable-spec-aug",
            type=str2bool,
            default=True,
            help="When enabled, use SpecAugment for the training dataset.",
        )
        group.add_argument(
            "--spec-aug-time-warp-factor",
            type=int,
            default=80,
            help=(
                "Used only when --enable-spec-aug is True. "
                "It specifies the factor for time warping in SpecAugment. "
                "Larger values mean more warping. "
                "A value less than 1 means to disable time warp."
            ),
        )
        group.add_argument(
            "--enable-musan",
            type=str2bool,
            default=True,
            help=(
                "When enabled, select noise from MUSAN and mix it "
                "with the training dataset."
            ),
        )
        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/fbank"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help=(
                "When enabled, use on-the-fly cut mixing and feature "
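    # A minimal wiring sketch for the options above (the flag values below
    # are illustrative, not recommendations):
    #
    #   parser = argparse.ArgumentParser()
    #   AsrDataModule.add_arguments(parser)
    #   args = parser.parse_args(["--max-duration", "300", "--num-buckets", "50"])
    #   data_module = AsrDataModule(args)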
                "extraction. Will drop existing precomputed feature manifests "
                "if available. Used only for the dev/test CutSets."
            ),
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        on_the_fly_feats: bool,
        cuts_musan: Optional[CutSet] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            Cuts for training.
          on_the_fly_feats:
            True to use OnTheFlyFeatures;
            False to use PrecomputedFeatures.
          cuts_musan:
            If not None, it is the cuts for MUSAN noise mixing.
        """
        transforms = []
        if cuts_musan is not None:
            logging.info("Enable MUSAN")
            transforms.append(
                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
            )
        else:
            logging.info("Disable MUSAN")

        input_transforms = []
        if self.args.enable_spec_aug:
            logging.info("Enable SpecAugment")
            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
            input_transforms.append(
                SpecAugment(
                    time_warp_factor=self.args.spec_aug_time_warp_factor,
                    num_frame_masks=2,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100,
                )
            )
        else:
            logging.info("Disable SpecAugment")

        logging.info("About to create train dataset")

        # NOTE: the PerturbSpeed transform should be added only if we
        # remove it from the data prep stage.
        # Add on-the-fly speed perturbation; since originally it would
        # have increased epoch size by 3, we will apply prob 2/3 and use
        # 3x more epochs.
        # Speed perturbation probably should come first before
        # concatenation, but in principle the transforms order doesn't
        # have to be strict (e.g. could be randomized).
        # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms  # noqa
        train = K2SpeechRecognitionDataset(
            cut_transforms=transforms,
            input_strategy=(
                OnTheFlyFeatures(
                    extractor=Fbank(FbankConfig(num_mel_bins=80)),
                    num_workers=self.args.on_the_fly_num_workers,
                )
                if on_the_fly_feats
                else PrecomputedFeatures()
            ),
            input_transforms=input_transforms,
            return_cuts=self.args.return_cuts,
        )

        logging.info("Using DynamicBucketingSampler.")
        train_sampler = DynamicBucketingSampler(
            cuts_train,
            max_duration=self.args.max_duration,
            shuffle=self.args.shuffle,
            num_buckets=self.args.num_buckets,
            drop_last=True,
        )

        logging.info("About to create train dataloader")
        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
        )
        return train_dl
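    # Usage sketch for the training loader; the manifest filenames below are
    # assumptions -- adjust them to the outputs of your data prep stage:
    #
    #   from lhotse import load_manifest_lazy
    #
    #   cuts_train = load_manifest_lazy(args.manifest_dir / "cuts_train.jsonl.gz")
    #   cuts_musan = (
    #       load_manifest_lazy(args.manifest_dir / "musan_cuts.jsonl.gz")
    #       if args.enable_musan
    #       else None
    #   )
    #   train_dl = data_module.train_dataloaders(
    #       cuts_train,
    #       on_the_fly_feats=args.on_the_fly_feats,
    #       cuts_musan=cuts_musan,
    #   )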
    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        transforms = []

        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                return_cuts=self.args.return_cuts,
            )
        else:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                return_cuts=self.args.return_cuts,
            )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
            num_buckets=self.args.num_buckets,
            drop_last=False,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )
        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.debug("About to create test dataset")
        test = K2SpeechRecognitionDataset(
            input_strategy=(
                OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
                if self.args.on_the_fly_feats
                else PrecomputedFeatures()
            ),
            return_cuts=self.args.return_cuts,
        )
        sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
            num_buckets=self.args.num_buckets,
        )
        logging.debug("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl
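

# A minimal end-to-end sketch, runnable once manifests have been prepared;
# the manifest filename below is an assumption -- adjust it to the outputs
# of your data prep stage.
if __name__ == "__main__":
    from lhotse import load_manifest_lazy

    parser = argparse.ArgumentParser()
    AsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    data_module = AsrDataModule(args)
    cuts_dev = load_manifest_lazy(args.manifest_dir / "cuts_dev.jsonl.gz")
    valid_dl = data_module.valid_dataloaders(cuts_dev)

    # Inspect one batch: K2SpeechRecognitionDataset yields a dict with
    # "inputs" (padded features, shape (N, T, num_mel_bins)) and
    # "supervisions" (which includes "text", and "cut" when
    # --return-cuts=True).
    batch = next(iter(valid_dl))
    print(batch["inputs"].shape, len(batch["supervisions"]["text"]))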