# Copyright 2021 Piotr Żelasko
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    load_manifest,
    load_manifest_lazy,
    validate,
)
from lhotse.dataset import (
    CutConcatenate,
    CutMix,
    DynamicBucketingSampler,
    K2SpeechRecognitionDataset,
    SingleCutSampler,
    SpecAugment,
)
from lhotse.dataset.input_strategies import (
    BatchIO,
    OnTheFlyFeatures,
    PrecomputedFeatures,
)
from lhotse.utils import fix_random_seed, ifnone
from text_normalization import replace_full_width_symbol, simple_normalization
from torch.utils.data.dataloader import DataLoader, default_collate

from icefall.utils import str2bool


class LibriHeavyASRDataset(torch.utils.data.Dataset):
    """A PyTorch Dataset for the Libriheavy corpus."""

    def __init__(
        self,
        return_cuts: bool = False,
        cut_transforms: List[Callable[[CutSet], CutSet]] = None,
        input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
        input_strategy: BatchIO = PrecomputedFeatures(),
        text_sampling_func: Optional[Callable[[List[str]], str]] = None,
    ):
        """
        Icefall ASR Dataset constructor. See
        https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py
        for more details.

        :param return_cuts: When ``True``, will additionally return a "cut" field
            in each batch with the Cut objects used to create that batch.
        :param cut_transforms: A list of transforms to be applied on each sampled batch,
            before converting cuts to an input representation (audio/features).
            Examples: cut concatenation, noise cuts mixing, etc.
        :param input_transforms: A list of transforms to be applied on each sampled batch,
            after the cuts are converted to audio/features.
            Examples: normalization, SpecAugment, etc.
        :param input_strategy: Converts cuts into a collated batch of audio/features.
            By default, reads pre-computed features from disk.
        :param text_sampling_func: Samples one transcription from a list of texts.
        """
        super().__init__()
        # Initialize the fields
        self.return_cuts = return_cuts
        self.cut_transforms = ifnone(cut_transforms, [])
        self.input_transforms = ifnone(input_transforms, [])
        self.input_strategy = input_strategy

        # a text selection function
        self.text_sampling_func = text_sampling_func

    def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
        """
        Return a new batch, with the batch size automatically determined using the
        constraints of max_frames and max_cuts.
        """
        validate_for_asr(cuts)

        # Sort the cuts by duration so that the first one determines the batch time
        # dimensions.
        cuts = cuts.sort_by_duration(ascending=False)

        # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
        # the supervision boundaries.
        for tnfm in self.cut_transforms:
            cuts = tnfm(cuts)

        # Sort the cuts again after transforms
        cuts = cuts.sort_by_duration(ascending=False)

        # Get a tensor with batched feature matrices, shape (B, T, F)
        # Collation performs auto-padding, if necessary.
        input_tpl = self.input_strategy(cuts)
        if len(input_tpl) == 3:
            # An input strategy with fault tolerant audio reading mode.
            # "cuts" may be a subset of the original "cuts" variable,
            # that only has cuts for which we successfully read the audio.
            inputs, _, cuts = input_tpl
        else:
            inputs, _ = input_tpl

        # Get a dict of tensors that encode the positional information about supervisions
        # in the batch of feature matrices. The tensors are named "sequence_idx",
        # "start_frame/sample" and "num_frames/samples".
        supervision_intervals = self.input_strategy.supervision_intervals(cuts)

        # Apply all available transforms on the inputs, i.e. either audio or features.
        # This could be feature extraction, global MVN, SpecAugment, etc.
        segments = torch.stack(list(supervision_intervals.values()), dim=1)
        for tnfm in self.input_transforms:
            inputs = tnfm(inputs, supervision_segments=segments)

        batch = {
            "inputs": inputs,
            "supervisions": default_collate(
                [
                    {
                        "text": simple_normalization(
                            self.text_sampling_func(texts=supervision.texts)
                        ),
                    }
                    if self.text_sampling_func is not None
                    else {
                        "text": simple_normalization(supervision.texts[0]),
                    }
                    for sequence_idx, cut in enumerate(cuts)
                    for supervision in cut.supervisions
                ]
            ),
        }

        # Update the 'supervisions' field with sequence_idx and start/num frames/samples
        batch["supervisions"].update(supervision_intervals)
        if self.return_cuts:
            batch["supervisions"]["cut"] = [
                cut for cut in cuts for sup in cut.supervisions
            ]

        return batch
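

# Illustrative sketch (not part of the original recipe): ``text_sampling_func``
# is expected to accept the keyword argument ``texts`` (the list of candidate
# transcriptions attached to one supervision) and return a single string,
# matching how it is invoked in ``__getitem__`` above.  The function name below
# is hypothetical; a real recipe may instead pick one transcription style
# deterministically or sample according to task-specific rules.
def random_text_sampling(texts: List[str]) -> str:
    """Pick one of the available transcriptions uniformly at random."""
    import random  # local import keeps this sketch self-contained

    return random.choice(texts)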


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


class LibriHeavyAsrDataModule:
    """
    DataModule for k2 ASR experiments.
    It assumes there is always one train and valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
    and test-other).

    It contains all the common data pipeline modules used in ASR
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - cut concatenation,
    - augmentation,
    - on-the-fly feature extraction

    This class should be derived for specific corpora used in ASR tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="ASR data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/fbank"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=int,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )
        group.add_argument(
            "--concatenate-cuts",
            type=str2bool,
            default=False,
            help="When enabled, utterances (cuts) will be concatenated "
            "to minimize the amount of padding.",
        )
        group.add_argument(
            "--duration-factor",
            type=float,
            default=1.0,
            help="Determines the maximum duration of a concatenated cut "
            "relative to the duration of the longest cut in a batch.",
        )
        group.add_argument(
            "--gap",
            type=float,
            default=1.0,
            help="The amount of padding (in seconds) inserted between "
            "concatenated cuts. This padding is filled with noise when "
            "noise augmentation is used.",
        )
        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['supervisions']['cut'] with the cuts that "
            "were used to construct it.",
        )
        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )
        group.add_argument(
            "--enable-spec-aug",
            type=str2bool,
            default=True,
            help="When enabled, use SpecAugment for training dataset.",
        )
        group.add_argument(
            "--spec-aug-time-warp-factor",
            type=int,
            default=80,
            help="Used only when --enable-spec-aug is True. "
            "It specifies the factor for time warping in SpecAugment. "
            "Larger values mean more warping. "
            "A value less than 1 means to disable time warp.",
        )
        group.add_argument(
            "--enable-musan",
            type=str2bool,
            default=True,
            help="When enabled, select noise from MUSAN and mix it "
            "with training dataset.",
        )

        # Libriheavy specific arguments
        group.add_argument(
            "--subset",
            type=str,
            default="small",
            help="Select the Libriheavy subset (small|medium|large)",
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
        text_sampling_func: Optional[Callable[[List[str]], str]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        transforms = []
        if self.args.enable_musan:
            logging.info("Enable MUSAN")
            logging.info("About to get Musan cuts")
            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
            transforms.append(
                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
            )
        else:
            logging.info("Disable MUSAN")

        if self.args.concatenate_cuts:
            logging.info(
                f"Using cut concatenation with duration factor "
                f"{self.args.duration_factor} and gap {self.args.gap}."
            )
            # Cut concatenation should be the first transform in the list,
            # so that if we e.g. mix noise in, it will fill the gaps between
            # different utterances.
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                )
            ] + transforms

        input_transforms = []
        if self.args.enable_spec_aug:
            logging.info("Enable SpecAugment")
            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
            # Set the value of num_frame_masks according to Lhotse's version.
            # In different Lhotse's versions, the default of num_frame_masks is
            # different.
            num_frame_masks = 10
            num_frame_masks_parameter = inspect.signature(
                SpecAugment.__init__
            ).parameters["num_frame_masks"]
            if num_frame_masks_parameter.default == 1:
                num_frame_masks = 2
            logging.info(f"Num frame mask: {num_frame_masks}")
            input_transforms.append(
                SpecAugment(
                    time_warp_factor=self.args.spec_aug_time_warp_factor,
                    num_frame_masks=num_frame_masks,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100,
                )
            )
        else:
            logging.info("Disable SpecAugment")

        logging.info("About to create train dataset")
        train = LibriHeavyASRDataset(
            cut_transforms=transforms,
            input_transforms=input_transforms,
            return_cuts=self.args.return_cuts,
            text_sampling_func=text_sampling_func,
        )

        if self.args.on_the_fly_feats:
            # NOTE: the PerturbSpeed transform should be added only if we
            # remove it from data prep stage.
            # Add on-the-fly speed perturbation; since originally it would
            # have increased epoch size by 3, we will apply prob 2/3 and use
            # 3x more epochs.
            # Speed perturbation probably should come first before
            # concatenation, but in principle the transforms order doesn't have
            # to be strict (e.g. could be randomized)
            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms  # noqa
            # Drop feats to be on the safe side.
            train = LibriHeavyASRDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                input_transforms=input_transforms,
                return_cuts=self.args.return_cuts,
                text_sampling_func=text_sampling_func,
            )

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                drop_last=True,
            )
        else:
            logging.info("Using SingleCutSampler.")
            train_sampler = SingleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def valid_dataloaders(
        self,
        cuts_valid: CutSet,
        text_sampling_func: Optional[Callable[[List[str]], str]] = None,
    ) -> DataLoader:
        transforms = []
        if self.args.concatenate_cuts:
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                )
            ] + transforms

        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            validate = LibriHeavyASRDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                return_cuts=self.args.return_cuts,
                text_sampling_func=text_sampling_func,
            )
        else:
            validate = LibriHeavyASRDataset(
                cut_transforms=transforms,
                return_cuts=self.args.return_cuts,
                text_sampling_func=text_sampling_func,
            )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.debug("About to create test dataset")
        test = LibriHeavyASRDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
            if self.args.on_the_fly_feats
            else PrecomputedFeatures(),
            return_cuts=self.args.return_cuts,
        )
        sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.debug("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info(f"About to get {self.args.subset} cuts")
        path = self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz"
        cuts_train = CutSet.from_jsonl_lazy(path)
        if self.args.subset == "medium":
            logging.info("Getting medium subset")
            path = self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz"
            cuts_train += CutSet.from_jsonl_lazy(path)
        elif self.args.subset == "large":
            logging.info("Getting large subset")
            path = self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz"
            cuts_train += CutSet.from_jsonl_lazy(path)
            path = self.args.manifest_dir / "libriheavy_cuts_large.jsonl.gz"
            cuts_train += CutSet.from_jsonl_lazy(path)
        return cuts_train

    def dev_cuts(self) -> CutSet:
        logging.info("About to get dev cuts")
        cuts = load_manifest_lazy(
            self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz"
        )
        return cuts

    @lru_cache()
    def test_clean_cuts(self) -> CutSet:
        logging.info("About to get test-clean cuts")
        cuts_valid = load_manifest_lazy(
            self.args.manifest_dir / "libriheavy_cuts_test-clean.jsonl.gz"
        )
        return cuts_valid

    @lru_cache()
    def test_other_cuts(self) -> CutSet:
        logging.info("About to get test-other cuts")
        cuts_valid = load_manifest_lazy(
            self.args.manifest_dir / "libriheavy_cuts_test-other.jsonl.gz"
        )
        return cuts_valid

    @lru_cache()
    def librispeech_test_clean_cuts(self) -> CutSet:
        logging.info("About to get LibriSpeech test-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
        )

    @lru_cache()
    def librispeech_test_other_cuts(self) -> CutSet:
        logging.info("About to get LibriSpeech test-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
        )


def validate_for_asr(cuts: CutSet) -> None:
    validate(cuts)
    tol = 2e-3  # 2 ms
    for cut in cuts:
        for supervision in cut.supervisions:
            assert supervision.start >= -tol, (
                f"Supervisions starting before the cut are not supported for ASR"
                f" (sup id: {supervision.id}, cut id: {cut.id})"
            )

            # Supervision start time is relative to Cut ...
            # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
            #
            # 'supervision.end' is end of supervision inside the Cut
            assert supervision.end <= cut.duration + tol, (
                f"Supervisions ending after the cut "
                f"are not supported for ASR"
                f" (sup id: {supervision.id}, cut id: {cut.id})"
            )
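

# Minimal usage sketch (an assumption about the surrounding recipe, not part of
# the original file): a training script would typically register the CLI
# options via `add_arguments`, build the data module from the parsed args, and
# then request cuts and dataloaders from it.  The snippet below only inspects
# one batch, and it assumes the Libriheavy manifests already exist under
# `--manifest-dir`.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    LibriHeavyAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    datamodule = LibriHeavyAsrDataModule(args)
    train_cuts = datamodule.train_cuts()
    train_dl = datamodule.train_dataloaders(train_cuts)

    for batch in train_dl:
        # "inputs" is a (B, T, F) feature tensor; "supervisions" carries the
        # collated texts plus the frame-level positional information.
        print(batch["inputs"].shape)
        print(batch["supervisions"]["text"][:2])
        break

    # To resume training mid-epoch, the sampler state can be captured with
    # `train_dl.sampler.state_dict()` and passed back later through the
    # `sampler_state_dict` argument of `train_dataloaders`.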