From 075e74bcb5beb9d01c728dc200e9f863f4bfd373 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 1 Jul 2025 16:32:25 +0800
Subject: [PATCH 1/3] copy files from lhotse

---
 .../ASR/zipformer/speech_recognition.py       | 222 ++++++++++++++++++
 1 file changed, 222 insertions(+)
 create mode 100644 egs/librispeech/ASR/zipformer/speech_recognition.py

diff --git a/egs/librispeech/ASR/zipformer/speech_recognition.py b/egs/librispeech/ASR/zipformer/speech_recognition.py
new file mode 100644
index 000000000..4a3520b37
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer/speech_recognition.py
@@ -0,0 +1,222 @@
+from typing import Callable, Dict, List, Union
+
+import torch
+from torch.utils.data.dataloader import DataLoader, default_collate
+
+from lhotse import validate
+from lhotse.cut import CutSet
+from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
+from lhotse.utils import compute_num_frames, ifnone
+from lhotse.workarounds import Hdf5MemoryIssueFix
+
+
+class K2SpeechRecognitionDataset(torch.utils.data.Dataset):
+    """
+    The PyTorch Dataset for the speech recognition task using k2 library.
+
+    This dataset expects to be queried with lists of cut IDs,
+    for which it loads features and automatically collates/batches them.
+
+    To use it with a PyTorch DataLoader, set ``batch_size=None``
+    and provide a :class:`SimpleCutSampler` sampler.
+
+    Each item in this dataset is a dict of:
+
+    .. code-block::
+
+        {
+            'inputs': float tensor with shape determined by :attr:`input_strategy`:
+                      - single-channel:
+                        - features: (B, T, F)
+                        - audio: (B, T)
+                      - multi-channel: currently not supported
+            'supervisions': [
+                {
+                    'sequence_idx': Tensor[int] of shape (S,)
+                    'text': List[str] of len S
+
+                    # For feature input strategies
+                    'start_frame': Tensor[int] of shape (S,)
+                    'num_frames': Tensor[int] of shape (S,)
+
+                    # For audio input strategies
+                    'start_sample': Tensor[int] of shape (S,)
+                    'num_samples': Tensor[int] of shape (S,)
+
+                    # Optionally, when return_cuts=True
+                    'cut': List[AnyCut] of len S
+                }
+            ]
+        }
+
+    Dimension symbols legend:
+    * ``B`` - batch size (number of Cuts)
+    * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions)
+    * ``T`` - number of frames of the longest Cut
+    * ``F`` - number of features
+
+    The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset.
+    """
+
+    def __init__(
+        self,
+        return_cuts: bool = False,
+        cut_transforms: List[Callable[[CutSet], CutSet]] = None,
+        input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
+        input_strategy: BatchIO = PrecomputedFeatures(),
+    ):
+        """
+        k2 ASR IterableDataset constructor.
+
+        :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
+            objects used to create that batch.
+        :param cut_transforms: A list of transforms to be applied on each sampled batch,
+            before converting cuts to an input representation (audio/features).
+            Examples: cut concatenation, noise cuts mixing, etc.
+        :param input_transforms: A list of transforms to be applied on each sampled batch,
+            after the cuts are converted to audio/features.
+            Examples: normalization, SpecAugment, etc.
+        :param input_strategy: Converts cuts into a collated batch of audio/features.
+            By default, reads pre-computed features from disk.
+        """
+        super().__init__()
+        # Initialize the fields
+        self.return_cuts = return_cuts
+        self.cut_transforms = ifnone(cut_transforms, [])
+        self.input_transforms = ifnone(input_transforms, [])
+        self.input_strategy = input_strategy
+
+        # This attribute is a workaround to constantly growing HDF5 memory
+        # throughout the epoch. It regularly closes open file handles to
+        # reset the internal HDF5 caches.
+        self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100)
+
+    def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
+        """
+        Return a new batch, with the batch size automatically determined using the constraints
+        of max_duration and max_cuts.
+        """
+        validate_for_asr(cuts)
+
+        self.hdf5_fix.update()
+
+        # Sort the cuts by duration so that the first one determines the batch time dimensions.
+        cuts = cuts.sort_by_duration(ascending=False)
+
+        # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
+        # the supervision boundaries.
+        for tnfm in self.cut_transforms:
+            cuts = tnfm(cuts)
+
+        # Sort the cuts again after transforms
+        cuts = cuts.sort_by_duration(ascending=False)
+
+        # Get a tensor with batched feature matrices, shape (B, T, F)
+        # Collation performs auto-padding, if necessary.
+        input_tpl = self.input_strategy(cuts)
+        if len(input_tpl) == 3:
+            # An input strategy with fault tolerant audio reading mode.
+            # "cuts" may be a subset of the original "cuts" variable,
+            # that only has cuts for which we succesfully read the audio.
+            inputs, _, cuts = input_tpl
+        else:
+            inputs, _ = input_tpl
+
+        # Get a dict of tensors that encode the positional information about supervisions
+        # in the batch of feature matrices. The tensors are named "sequence_idx",
+        # "start_frame/sample" and "num_frames/samples".
+        supervision_intervals = self.input_strategy.supervision_intervals(cuts)
+
+        # Apply all available transforms on the inputs, i.e. either audio or features.
+        # This could be feature extraction, global MVN, SpecAugment, etc.
+        segments = torch.stack(list(supervision_intervals.values()), dim=1)
+        for tnfm in self.input_transforms:
+            inputs = tnfm(inputs, supervision_segments=segments)
+
+        batch = {
+            "inputs": inputs,
+            "supervisions": default_collate(
+                [
+                    {
+                        "text": supervision.text,
+                    }
+                    for sequence_idx, cut in enumerate(cuts)
+                    for supervision in cut.supervisions
+                ]
+            ),
+        }
+        # Update the 'supervisions' field with sequence_idx and start/num frames/samples
+        batch["supervisions"].update(supervision_intervals)
+        if self.return_cuts:
+            batch["supervisions"]["cut"] = [
+                cut for cut in cuts for sup in cut.supervisions
+            ]
+
+        has_word_alignments = all(
+            s.alignment is not None and "word" in s.alignment
+            for c in cuts
+            for s in c.supervisions
+        )
+        if has_word_alignments:
+            # TODO: might need to refactor BatchIO API to move the following conditional logic
+            # into these objects (e.g. use like: self.input_strategy.convert_timestamp(),
+            # that returns either num_frames or num_samples depending on the strategy).
+            words, starts, ends = [], [], []
+            frame_shift = cuts[0].frame_shift
+            sampling_rate = cuts[0].sampling_rate
+            if frame_shift is None:
+                try:
+                    frame_shift = self.input_strategy.extractor.frame_shift
+                except AttributeError:
+                    raise ValueError(
+                        "Can't determine the frame_shift -- it is not present either in cuts or the input_strategy. "
+                    )
+            for c in cuts:
+                for s in c.supervisions:
+                    words.append([aliword.symbol for aliword in s.alignment["word"]])
+                    starts.append(
+                        [
+                            compute_num_frames(
+                                aliword.start,
+                                frame_shift=frame_shift,
+                                sampling_rate=sampling_rate,
+                            )
+                            for aliword in s.alignment["word"]
+                        ]
+                    )
+                    ends.append(
+                        [
+                            compute_num_frames(
+                                aliword.end,
+                                frame_shift=frame_shift,
+                                sampling_rate=sampling_rate,
+                            )
+                            for aliword in s.alignment["word"]
+                        ]
+                    )
+            batch["supervisions"]["word"] = words
+            batch["supervisions"]["word_start"] = starts
+            batch["supervisions"]["word_end"] = ends
+
+        return batch
+
+
+def validate_for_asr(cuts: CutSet) -> None:
+    validate(cuts)
+    tol = 2e-3  # 2ms
+    for cut in cuts:
+        for supervision in cut.supervisions:
+            assert supervision.start >= -tol, (
+                f"Supervisions starting before the cut are not supported for ASR"
+                f" (sup id: {supervision.id}, cut id: {cut.id})"
+            )
+
+            # Supervision start time is relative to Cut ...
+            # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
+            #
+            # 'supervision.end' is end of supervision inside the Cut
+            assert supervision.end <= cut.duration + tol, (
+                f"Supervisions ending after the cut "
+                f"are not supported for ASR"
+                f" (sup id: {supervision.id}, cut id: {cut.id})"
+            )
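Note (illustrative, not part of the patch): a minimal sketch of how training code
typically consumes a batch produced by this dataset, assuming the default
precomputed-features input strategy. `make_supervision_segments` is a hypothetical
helper name; the three interval tensors come from `supervision_intervals()` as
described in the docstring above.

    import torch

    def make_supervision_segments(batch: dict) -> torch.Tensor:
        # Rows of (sequence_idx, start_frame, num_frames) -- the layout that
        # k2/icefall training code expects for its supervision segments.
        sup = batch["supervisions"]
        return torch.stack(
            (sup["sequence_idx"], sup["start_frame"], sup["num_frames"]),
            dim=1,
        ).to(torch.int32)

    # inputs = batch["inputs"]                # (B, T, F) float tensor
    # texts = batch["supervisions"]["text"]   # list of S transcripts
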
From 85f6deb8d18c899000ff07fb51d47708eb42c8c3 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 1 Jul 2025 16:58:31 +0800
Subject: [PATCH 2/3] Support using different musan augmentations for the same
 audio.

In addition, the returned batch also contains the original audio without
augmentation.

---
 .../ASR/zipformer/speech_recognition.py       | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/speech_recognition.py b/egs/librispeech/ASR/zipformer/speech_recognition.py
index 4a3520b37..828602fcb 100644
--- a/egs/librispeech/ASR/zipformer/speech_recognition.py
+++ b/egs/librispeech/ASR/zipformer/speech_recognition.py
@@ -103,13 +103,15 @@ class K2SpeechRecognitionDataset(torch.utils.data.Dataset):
         # Sort the cuts by duration so that the first one determines the batch time dimensions.
         cuts = cuts.sort_by_duration(ascending=False)
 
-        # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
-        # the supervision boundaries.
-        for tnfm in self.cut_transforms:
-            cuts = tnfm(cuts)
+        if self.cut_transforms:
+            orig_cuts = cuts
 
-        # Sort the cuts again after transforms
-        cuts = cuts.sort_by_duration(ascending=False)
+            cuts = cuts.repeat(times=2)
+
+            for tnfm in self.cut_transforms:
+                cuts = tnfm(cuts)
+
+            cuts = orig_cuts + cuts
 
         # Get a tensor with batched feature matrices, shape (B, T, F)
         # Collation performs auto-padding, if necessary.
@@ -117,7 +119,7 @@ class K2SpeechRecognitionDataset(torch.utils.data.Dataset):
         if len(input_tpl) == 3:
             # An input strategy with fault tolerant audio reading mode.
             # "cuts" may be a subset of the original "cuts" variable,
-            # that only has cuts for which we succesfully read the audio.
+            # that only has cuts for which we successfully read the audio.
             inputs, _, cuts = input_tpl
         else:
             inputs, _ = input_tpl
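Note (illustrative, not part of the patch): the new `cut_transforms` branch,
distilled into a standalone function, under the assumption that the transforms
are random augmentations such as the CutMix used in icefall. A sampled CutSet of
B cuts grows to 3B cuts (B clean + 2B independently augmented), so the realized
batch is roughly three times what the sampler budgeted via max_duration.

    from typing import Callable, List

    from lhotse import CutSet

    def expand_with_augmentation(
        cuts: CutSet, transforms: List[Callable[[CutSet], CutSet]]
    ) -> CutSet:
        orig_cuts = cuts
        # Two copies per cut; random transforms (e.g. CutMix) are applied to
        # each copy independently, so the two repeats usually differ.
        cuts = cuts.repeat(times=2)
        for tnfm in transforms:
            cuts = tnfm(cuts)
        # 3B cuts total: B clean followed by 2B differently augmented copies.
        return orig_cuts + cuts
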
" + ) + for c in cuts: + for s in c.supervisions: + words.append([aliword.symbol for aliword in s.alignment["word"]]) + starts.append( + [ + compute_num_frames( + aliword.start, + frame_shift=frame_shift, + sampling_rate=sampling_rate, + ) + for aliword in s.alignment["word"] + ] + ) + ends.append( + [ + compute_num_frames( + aliword.end, + frame_shift=frame_shift, + sampling_rate=sampling_rate, + ) + for aliword in s.alignment["word"] + ] + ) + batch["supervisions"]["word"] = words + batch["supervisions"]["word_start"] = starts + batch["supervisions"]["word_end"] = ends + + return batch + + +def validate_for_asr(cuts: CutSet) -> None: + validate(cuts) + tol = 2e-3 # 1ms + for cut in cuts: + for supervision in cut.supervisions: + assert supervision.start >= -tol, ( + f"Supervisions starting before the cut are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) + + # Supervision start time is relative to Cut ... + # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html + # + # 'supervision.end' is end of supervision inside the Cut + assert supervision.end <= cut.duration + tol, ( + f"Supervisions ending after the cut " + f"are not supported for ASR" + f" (sup id: {supervision.id}, cut id: {cut.id})" + ) From 85f6deb8d18c899000ff07fb51d47708eb42c8c3 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 1 Jul 2025 16:58:31 +0800 Subject: [PATCH 2/3] Support using different musan augmentations for the same audio. In addition, it returns the original audio without augmentation. --- .../ASR/zipformer/speech_recognition.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/speech_recognition.py b/egs/librispeech/ASR/zipformer/speech_recognition.py index 4a3520b37..828602fcb 100644 --- a/egs/librispeech/ASR/zipformer/speech_recognition.py +++ b/egs/librispeech/ASR/zipformer/speech_recognition.py @@ -103,13 +103,15 @@ class K2SpeechRecognitionDataset(torch.utils.data.Dataset): # Sort the cuts by duration so that the first one determines the batch time dimensions. cuts = cuts.sort_by_duration(ascending=False) - # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts - # the supervision boundaries. - for tnfm in self.cut_transforms: - cuts = tnfm(cuts) + if self.cut_transforms: + orig_cuts = cuts - # Sort the cuts again after transforms - cuts = cuts.sort_by_duration(ascending=False) + cuts = cuts.repeat(times=2) + + for tnfm in self.cut_transforms: + cuts = tnfm(cuts) + + cuts = orig_cuts + cuts # Get a tensor with batched feature matrices, shape (B, T, F) # Collation performs auto-padding, if necessary. @@ -117,7 +119,7 @@ class K2SpeechRecognitionDataset(torch.utils.data.Dataset): if len(input_tpl) == 3: # An input strategy with fault tolerant audio reading mode. # "cuts" may be a subset of the original "cuts" variable, - # that only has cuts for which we succesfully read the audio. + # that only has cuts for which we successfully read the audio. 
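Note (illustrative, not part of the patches): a minimal end-to-end sketch wiring
the new dataset together the way asr_datamodule.py does. The manifest paths,
max_duration, and num_workers values are placeholders, not taken from the
patches.

    from lhotse import load_manifest
    from lhotse.dataset import CutMix, DynamicBucketingSampler
    from torch.utils.data import DataLoader

    from speech_recognition import K2SpeechRecognitionDataset

    # Placeholder manifests -- adjust the paths to your setup.
    train_cuts = load_manifest("data/fbank/librispeech_cuts_train-clean-100.jsonl.gz")
    cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz")

    dataset = K2SpeechRecognitionDataset(
        cut_transforms=[
            # p=1.0 so that every repeated copy receives a noise mix (PATCH 3/3).
            CutMix(cuts=cuts_musan, p=1.0, snr=(10, 20), preserve_id=True)
        ],
    )
    sampler = DynamicBucketingSampler(train_cuts, max_duration=200.0, shuffle=True)
    # batch_size=None: the sampler already yields whole batches (CutSets).
    train_dl = DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=2)

    batch = next(iter(train_dl))
    # "inputs" holds the clean cuts followed by the two augmented repeats.
    print(batch["inputs"].shape)  # roughly (3 * B, T, F)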