diff --git a/egs/libriheavy/ASR/zipformer/asr_datamodule.py b/egs/libriheavy/ASR/zipformer/asr_datamodule.py index 80dc8134a..9d9ecc63c 100644 --- a/egs/libriheavy/ASR/zipformer/asr_datamodule.py +++ b/egs/libriheavy/ASR/zipformer/asr_datamodule.py @@ -20,27 +20,150 @@ import inspect import logging from functools import lru_cache from pathlib import Path -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union import torch -from dataset import PromptASRDataset -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + load_manifest, + load_manifest_lazy, + validate, +) from lhotse.dataset import ( CutConcatenate, CutMix, DynamicBucketingSampler, K2SpeechRecognitionDataset, - PrecomputedFeatures, SingleCutSampler, SpecAugment, ) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader +from lhotse.dataset.input_strategies import ( + BatchIO, + OnTheFlyFeatures, + PrecomputedFeatures, +) +from lhotse.utils import fix_random_seed, ifnone +from text_normalization import ( + ref_text_normalization, + replace_full_width_symbol, + simple_normalization, +) +from torch.utils.data.dataloader import DataLoader, default_collate from icefall.utils import str2bool +class LibriHeavyASRDataset(torch.utils.data.Dataset): + """This is a dataset for LibriHeavy dataset""" + + def __init__( + self, + return_cuts: bool = False, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, + input_strategy: BatchIO = PrecomputedFeatures(), + text_sampling_func: Optional[Callable[[List[str]], str]] = None, + ): + """ + Icefall ASR IterableDataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py + for more details. + + :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut + objects used to create that batch. + :param cut_transforms: A list of transforms to be applied on each sampled batch, + before converting cuts to an input representation (audio/features). + Examples: cut concatenation, noise cuts mixing, etc. + :param input_transforms: A list of transforms to be applied on each sampled batch, + after the cuts are converted to audio/features. + Examples: normalization, SpecAugment, etc. + :param input_strategy: Converts cuts into a collated batch of audio/features. + By default, reads pre-computed features from disk. + :param text_sampling_func: Sampling a text as transcription from a list of texts. + """ + super().__init__() + # Initialize the fields + self.return_cuts = return_cuts + self.cut_transforms = ifnone(cut_transforms, []) + self.input_transforms = ifnone(input_transforms, []) + self.input_strategy = input_strategy + + # a text selection function + self.text_sampling_func = text_sampling_func + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_frames and max_cuts. + """ + validate_for_asr(cuts) + + # Sort the cuts by duration so that the first one determines the batch time dimensions. + cuts = cuts.sort_by_duration(ascending=False) + + # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts + # the supervision boundaries. + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + # Sort the cuts again after transforms + cuts = cuts.sort_by_duration(ascending=False) + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we succesfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + # Apply all available transforms on the inputs, i.e. either audio or features. + # This could be feature extraction, global MVN, SpecAugment, etc. + segments = torch.stack(list(supervision_intervals.values()), dim=1) + for tnfm in self.input_transforms: + inputs = tnfm(inputs, supervision_segments=segments) + + batch = { + "inputs": inputs, + "supervisions": default_collate( + [ + simple_normalization( + self.text_sampling_func(texts=supervision.texts) + ) + if self.text_sampling_func is not None + else { + "text": simple_normalization(supervision.texts[0]), + } + for sequence_idx, cut in enumerate(cuts) + for supervision in cut.supervisions + ] + ), + } + # Update the 'supervisions' field with sequence_idx and start/num frames/samples + batch["supervisions"].update(supervision_intervals) + if self.return_cuts: + batch["supervisions"]["cut"] = [ + cut for cut in cuts for sup in cut.supervisions + ] + + has_word_alignments = all( + s.alignment is not None and "word" in s.alignment + for c in cuts + for s in c.supervisions + ) + + return batch + + class _SeedWorkers: def __init__(self, seed: int): self.seed = seed @@ -197,7 +320,7 @@ class LibriHeavyAsrDataModule: self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, - text_sampling_func: Callable[[List[str]], str] = None, + text_sampling_func: Optional[Callable[[List[str]], str]] = None, ) -> DataLoader: """ Args: @@ -259,7 +382,7 @@ class LibriHeavyAsrDataModule: logging.info("Disable SpecAugment") logging.info("About to create train dataset") - train = PromptASRDataset( + train = LibriHeavyASRDataset( cut_transforms=transforms, input_transforms=input_transforms, return_cuts=self.args.return_cuts, @@ -277,7 +400,7 @@ class LibriHeavyAsrDataModule: # to be strict (e.g. could be randomized) # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa # Drop feats to be on the safe side. - train = PromptASRDataset( + train = LibriHeavyASRDataset( cut_transforms=transforms, input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_transforms=input_transforms, @@ -326,7 +449,7 @@ class LibriHeavyAsrDataModule: def valid_dataloaders( self, cuts_valid: CutSet, - text_sampling_func: Callable[[List[str]], str] = None, + text_sampling_func: Optional[Callable[[List[str]], str]] = None, ) -> DataLoader: transforms = [] if self.args.concatenate_cuts: @@ -338,14 +461,14 @@ class LibriHeavyAsrDataModule: logging.info("About to create dev dataset") if self.args.on_the_fly_feats: - validate = PromptASRDataset( + validate = LibriHeavyASRDataset( cut_transforms=transforms, input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), return_cuts=self.args.return_cuts, text_sampling_func=text_sampling_func, ) else: - validate = PromptASRDataset( + validate = LibriHeavyASRDataset( cut_transforms=transforms, return_cuts=self.args.return_cuts, text_sampling_func=text_sampling_func, @@ -368,7 +491,7 @@ class LibriHeavyAsrDataModule: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( + test = LibriHeavyASRDataset( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) if self.args.on_the_fly_feats else PrecomputedFeatures(), @@ -391,49 +514,44 @@ class LibriHeavyAsrDataModule: @lru_cache() def train_cuts(self) -> CutSet: logging.info(f"About to get {self.args.subset} cuts") - path = self.args.manifest_dir / f"librilight_cuts_{self.args.subset}.jsonl.gz" + + path = self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz" cuts_train = CutSet.from_jsonl_lazy(path) + if self.args.subset == "medium": + logging.info("Getting medium subset") + path = self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz" + cuts_train += CutSet.from_jsonl_lazy(path) + elif self.args.subset == "large": + logging.info("Getting large subset") + path = self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz" + cuts_train += CutSet.from_jsonl_lazy(path) + path = self.args.manifest_dir / "libriheavy_cuts_large.jsonl.gz" + cuts_train += CutSet.from_jsonl_lazy(path) + return cuts_train - @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "librilight_cuts_dev.jsonl.gz" + cuts = load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz" ) - return cuts_valid - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - cuts_valid = load_manifest_lazy( - self.args.manifest_dir / "librilight_cuts_test.jsonl.gz" - ) - return cuts_valid - - @lru_cache() - def test_medium_cuts(self) -> CutSet: - logging.info("About to get 2000 cuts from the medium set") - cuts_medium_2k = load_manifest_lazy( - self.args.manifest_dir / "librilight_cuts_medium_2000.jsonl.gz" - ) - return cuts_medium_2k + return cuts @lru_cache() def test_clean_cuts(self) -> CutSet: logging.info("About to get test-clean cuts") - cuts = load_manifest_lazy( - self.args.manifest_dir / "librilight_finetuning_clean.jsonl.gz" + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_test-clean.jsonl.gz" ) - return cuts + return cuts_valid @lru_cache() def test_other_cuts(self) -> CutSet: logging.info("About to get test-other cuts") - cuts = load_manifest_lazy( - self.args.manifest_dir / "librilight_finetuning_other.jsonl.gz" + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "libriheavy_cuts_test-other.jsonl.gz" ) - return cuts + return cuts_valid @lru_cache() def librispeech_test_clean_cuts(self) -> CutSet: @@ -448,3 +566,24 @@ class LibriHeavyAsrDataModule: return load_manifest_lazy( self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" ) + + +def validate_for_asr(cuts: CutSet) -> None: + validate(cuts) + tol = 2e-3 # 1ms + for cut in cuts: + for supervision in cut.supervisions: + assert supervision.start >= -tol, ( + f"Supervisions starting before the cut are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) + + # Supervision start time is relative to Cut ... + # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html + # + # 'supervision.end' is end of supervision inside the Cut + assert supervision.end <= cut.duration + tol, ( + f"Supervisions ending after the cut " + f"are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + )