from typing import Callable, Dict, List, Union

import torch
from torch.utils.data.dataloader import DataLoader, default_collate

from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import compute_num_frames, ifnone
from lhotse.workarounds import Hdf5MemoryIssueFix


class K2SpeechRecognitionDataset(torch.utils.data.Dataset):
    """
    The PyTorch Dataset for the speech recognition task using the k2 library.

    This dataset expects to be queried with lists of cut IDs,
    for which it loads features and automatically collates/batches them.

    To use it with a PyTorch DataLoader, set ``batch_size=None``
    and provide a :class:`SimpleCutSampler` sampler.
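
    A minimal usage sketch (it assumes a :class:`CutSet` named ``cuts`` is
    available; the ``max_duration`` value is illustrative only):

    .. code-block::

        from torch.utils.data import DataLoader
        from lhotse.dataset import SimpleCutSampler

        dataset = K2SpeechRecognitionDataset()
        sampler = SimpleCutSampler(cuts, max_duration=200.0)
        dloader = DataLoader(dataset, sampler=sampler, batch_size=None)
        batch = next(iter(dloader))  # dict with 'inputs' and 'supervisions'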

    Each item in this dataset is a dict of:

    .. code-block::

        {
            'inputs': float tensor with shape determined by :attr:`input_strategy`:
                      - single-channel:
                        - features: (B, T, F)
                        - audio: (B, T)
                      - multi-channel: currently not supported
            'supervisions': [
                {
                    'sequence_idx': Tensor[int] of shape (S,)
                    'text': List[str] of len S

                    # For feature input strategies
                    'start_frame': Tensor[int] of shape (S,)
                    'num_frames': Tensor[int] of shape (S,)

                    # For audio input strategies
                    'start_sample': Tensor[int] of shape (S,)
                    'num_samples': Tensor[int] of shape (S,)

                    # Optionally, when return_cuts=True
                    'cut': List[AnyCut] of len S
                }
            ]
        }

    Dimension symbols legend:
    * ``B`` - batch size (number of Cuts)
    * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions)
    * ``T`` - number of frames of the longest Cut
    * ``F`` - number of features

    The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset.
    """

    def __init__(
        self,
        return_cuts: bool = False,
        cut_transforms: List[Callable[[CutSet], CutSet]] = None,
        input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
        input_strategy: BatchIO = PrecomputedFeatures(),
    ):
        """
        k2 ASR Dataset constructor.

        :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
            objects used to create that batch.
        :param cut_transforms: A list of transforms to be applied on each sampled batch,
            before converting cuts to an input representation (audio/features).
            Examples: cut concatenation, noise cuts mixing, etc.
        :param input_transforms: A list of transforms to be applied on each sampled batch,
            after the cuts are converted to audio/features.
            Examples: normalization, SpecAugment, etc.
        :param input_strategy: Converts cuts into a collated batch of audio/features.
            By default, reads pre-computed features from disk.
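
        A hedged construction sketch (the transform classes come from lhotse
        and are chosen for illustration only):

        .. code-block::

            from lhotse.dataset import CutConcatenate, SpecAugment

            dataset = K2SpeechRecognitionDataset(
                return_cuts=True,
                cut_transforms=[CutConcatenate()],
                input_transforms=[SpecAugment()],
            )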
"""
|
|
super().__init__()
|
|
# Initialize the fields
|
|
self.return_cuts = return_cuts
|
|
self.cut_transforms = ifnone(cut_transforms, [])
|
|
self.input_transforms = ifnone(input_transforms, [])
|
|
self.input_strategy = input_strategy
|
|
|
|
# This attribute is a workaround to constantly growing HDF5 memory
|
|
# throughout the epoch. It regularly closes open file handles to
|
|
# reset the internal HDF5 caches.
|
|
self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100)
|
|
|
|
    def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
        """
        Return a new batch, with the batch size automatically determined using the constraints
        of max_duration and max_cuts.
        """
        validate_for_asr(cuts)

        self.hdf5_fix.update()

        # Sort the cuts by duration so that the first one determines the batch time dimensions.
        cuts = cuts.sort_by_duration(ascending=False)

        # Optional CutSet transforms, e.g. cut concatenation, noise cuts mixing, etc.
        for tnfm in self.cut_transforms:
            cuts = tnfm(cuts)

        # Get a tensor with batched feature matrices, shape (B, T, F).
        # Collation performs auto-padding, if necessary.
        input_tpl = self.input_strategy(cuts)
        if len(input_tpl) == 3:
            # An input strategy with fault tolerant audio reading mode.
            # "cuts" may be a subset of the original "cuts" variable,
            # that only has cuts for which we successfully read the audio.
            inputs, _, cuts = input_tpl
        else:
            inputs, _ = input_tpl

        # Get a dict of tensors that encode the positional information about supervisions
        # in the batch of feature matrices. The tensors are named "sequence_idx",
        # "start_frame/sample" and "num_frames/samples".
        supervision_intervals = self.input_strategy.supervision_intervals(cuts)
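        # E.g. for feature inputs, each value is a 1-D int tensor of shape (S,):
        # {"sequence_idx": ..., "start_frame": ..., "num_frames": ...}.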

        # Apply all available transforms on the inputs, i.e. either audio or features.
        # This could be feature extraction, global MVN, SpecAugment, etc.
        segments = torch.stack(list(supervision_intervals.values()), dim=1)
        for tnfm in self.input_transforms:
            inputs = tnfm(inputs, supervision_segments=segments)
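
        # Note: default_collate merges the per-supervision dicts below into a
        # single dict of lists, so batch["supervisions"]["text"] is a
        # List[str] of length S.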
        batch = {
            "inputs": inputs,
            "supervisions": default_collate(
                [
                    {
                        "text": supervision.text,
                    }
                    for sequence_idx, cut in enumerate(cuts)
                    for supervision in cut.supervisions
                ]
            ),
        }
        # Update the 'supervisions' field with sequence_idx and start/num frames/samples.
        batch["supervisions"].update(supervision_intervals)
        if self.return_cuts:
            batch["supervisions"]["cut"] = [
                cut for cut in cuts for sup in cut.supervisions
            ]

        has_word_alignments = all(
            s.alignment is not None and "word" in s.alignment
            for c in cuts
            for s in c.supervisions
        )
        if has_word_alignments:
            # TODO: might need to refactor BatchIO API to move the following conditional logic
            # into these objects (e.g. use like: self.input_strategy.convert_timestamp(),
            # that returns either num_frames or num_samples depending on the strategy).
            words, starts, ends = [], [], []
            frame_shift = cuts[0].frame_shift
            sampling_rate = cuts[0].sampling_rate
            if frame_shift is None:
                try:
                    frame_shift = self.input_strategy.extractor.frame_shift
                except AttributeError:
                    raise ValueError(
                        "Can't determine the frame_shift -- it is present neither in the cuts nor the input_strategy."
                    )
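            # Convert each word alignment timestamp (in seconds) into a frame
            # index; compute_num_frames comes from lhotse.utils (imported above).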
            for c in cuts:
                for s in c.supervisions:
                    words.append([aliword.symbol for aliword in s.alignment["word"]])
                    starts.append(
                        [
                            compute_num_frames(
                                aliword.start,
                                frame_shift=frame_shift,
                                sampling_rate=sampling_rate,
                            )
                            for aliword in s.alignment["word"]
                        ]
                    )
                    ends.append(
                        [
                            compute_num_frames(
                                aliword.end,
                                frame_shift=frame_shift,
                                sampling_rate=sampling_rate,
                            )
                            for aliword in s.alignment["word"]
                        ]
                    )
            batch["supervisions"]["word"] = words
            batch["supervisions"]["word_start"] = starts
            batch["supervisions"]["word_end"] = ends

        return batch


def validate_for_asr(cuts: CutSet) -> None:
    """Validate that each supervision segment lies within its cut (up to a small tolerance)."""
    validate(cuts)
    tol = 2e-3  # 2ms tolerance for small rounding errors
    for cut in cuts:
        for supervision in cut.supervisions:
            assert supervision.start >= -tol, (
                f"Supervisions starting before the cut are not supported for ASR"
                f" (sup id: {supervision.id}, cut id: {cut.id})"
            )

            # Supervision start time is relative to the Cut ...
            # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
            #
            # 'supervision.end' is the end of the supervision inside the Cut.
            assert supervision.end <= cut.duration + tol, (
                f"Supervisions ending after the cut "
                f"are not supported for ASR"
                f" (sup id: {supervision.id}, cut id: {cut.id})"
            )