Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)

Merge 0f88a3a6c3c4051d3f8feb20ec1a207d504e2c53 into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9
This commit is contained in: commit c8816d7930

@@ -0,0 +1,415 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
import random
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import lhotse
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
from lhotse.cut import Cut
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    CutMix,
    DynamicBucketingSampler,
    K2SpeechRecognitionDataset,
    PrecomputedFeatures,
    SimpleCutSampler,
    SpecAugment,
)
from lhotse.dataset.input_strategies import (  # noqa F401 for AudioSamples
    AudioSamples,
    OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from icefall.speech_recognition_dataset import (
    ConsistencyRegularizationSpeechRecognitionDataset,
)
from icefall.utils import str2bool


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


"""
We set c.features = None below to suppress the following warning:

2025-05-29 16:49:55,253 WARNING [data.py:801] Attempting to perturb speed on a
DataCut that references pre-computed features. The feature manifest will be
detached, as we do not support feature-domain speed perturbation.
"""


def perturb_speed(c: Cut):
    factor = random.choice([0.9, 1.1])
    c.features = None
    return lhotse.MonoCut.perturb_speed(c, factor)


def perturb_volume(c: Cut):
    factor = random.choice([0.9, 1.1])
    c.features = None
    return lhotse.MonoCut.perturb_volume(c, factor)


def perturb_tempo(c: Cut):
    factor = random.choice([0.9, 1.1])
    c.features = None
    return lhotse.MonoCut.perturb_tempo(c, factor)
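

# A minimal illustration (not part of the original file): the dataset class
# configured below applies each of these cut-level transforms to the sampled
# CutSet via ``cuts.map(tf)``, producing one augmented copy of the batch per
# transform.  The helper name ``apply_cut_transforms`` is hypothetical.
def apply_cut_transforms(cuts: CutSet, transforms) -> list:
    """Return one transformed copy of ``cuts`` per transform, e.g.
    ``apply_cut_transforms(cuts, [perturb_speed, perturb_volume, perturb_tempo])``.
    """
    return [cuts.map(tf) for tf in transforms]
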
class LibriSpeechAsrDataModuleWithParallelAug:
    """
    DataModule for k2 ASR experiments.
    It assumes there is always one train and one valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
    and test-other).

    It contains all the common data pipeline modules used in ASR
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - augmentation,
    - on-the-fly feature extraction

    This class should be derived for specific corpora used in ASR tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="ASR data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--full-libri",
            type=str2bool,
            default=True,
            help="""Used only when --mini-libri is False. When enabled,
            use the 960h LibriSpeech. Otherwise, use the 100h subset.""",
        )
        group.add_argument(
            "--enable-augmentation",
            type=str2bool,
            default=True,
            help="True to enable augmentation for the training set",
        )
        group.add_argument(
            "--mini-libri",
            type=str2bool,
            default=False,
            help="True for mini LibriSpeech",
        )

        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/fbank"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=int,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop the last batch. Used by the sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['supervisions']['cut'] with the cuts that "
            "were used to construct it.",
        )

        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available. The training dataset always uses on-the-fly features.",
        )

        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples or PrecomputedFeatures",
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        if self.args.enable_augmentation:
            logging.info("Augmentation is enabled")
            transforms = [perturb_speed, perturb_volume, perturb_tempo]
        else:
            logging.info("Augmentation is disabled")
            transforms = []

        logging.info("About to create train dataset")
        train = ConsistencyRegularizationSpeechRecognitionDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
            cut_transforms=transforms,
            return_cuts=self.args.return_cuts,
        )
        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                buffer_size=self.args.num_buckets * 2000,
                shuffle_buffer_size=self.args.num_buckets * 5000,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        transforms = []

        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                return_cuts=self.args.return_cuts,
            )
        else:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                return_cuts=self.args.return_cuts,
            )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.debug("About to create test dataset")
        test = K2SpeechRecognitionDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
            if self.args.on_the_fly_feats
            else eval(self.args.input_strategy)(),
            return_cuts=self.args.return_cuts,
        )
        sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.debug("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_clean_5_cuts(self) -> CutSet:
        logging.info("mini_librispeech: About to get train-clean-5 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
        )

    @lru_cache()
    def train_clean_100_cuts(self) -> CutSet:
        logging.info("About to get train-clean-100 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
        )

    @lru_cache()
    def train_clean_360_cuts(self) -> CutSet:
        logging.info("About to get train-clean-360 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
        )

    @lru_cache()
    def train_other_500_cuts(self) -> CutSet:
        logging.info("About to get train-other-500 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
        )

    @lru_cache()
    def train_all_shuf_cuts(self) -> CutSet:
        logging.info(
            "About to get the shuffled train-clean-100, "
            "train-clean-360 and train-other-500 cuts"
        )
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
        )

    @lru_cache()
    def dev_clean_2_cuts(self) -> CutSet:
        logging.info("mini_librispeech: About to get dev-clean-2 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz"
        )

    @lru_cache()
    def dev_clean_cuts(self) -> CutSet:
        logging.info("About to get dev-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
        )

    @lru_cache()
    def dev_other_cuts(self) -> CutSet:
        logging.info("About to get dev-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
        )

    @lru_cache()
    def test_clean_cuts(self) -> CutSet:
        logging.info("About to get test-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
        )

    @lru_cache()
    def test_other_cuts(self) -> CutSet:
        logging.info("About to get test-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
        )

    @lru_cache()
    def gigaspeech_subset_small_cuts(self) -> CutSet:
        logging.info("About to get GigaSpeech subset-S cuts")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz")

    @lru_cache()
    def gigaspeech_dev_cuts(self) -> CutSet:
        logging.info("About to get GigaSpeech dev cuts")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")

    @lru_cache()
    def gigaspeech_test_cuts(self) -> CutSet:
        logging.info("About to get GigaSpeech test cuts")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
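

# A minimal usage sketch (not part of the original file).  It assumes the
# fbank cut manifests under data/fbank have already been prepared by the
# recipe's data preparation stage; otherwise loading the cuts will fail.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    LibriSpeechAsrDataModuleWithParallelAug.add_arguments(parser)
    args = parser.parse_args()

    librispeech = LibriSpeechAsrDataModuleWithParallelAug(args)
    train_dl = librispeech.train_dataloaders(librispeech.train_clean_100_cuts())
    valid_dl = librispeech.valid_dataloaders(librispeech.dev_clean_cuts())

    # Each training batch holds the un-augmented view plus one entry in
    # batch["aug"] per cut transform (speed, volume, tempo).
    for batch in train_dl:
        print(batch["inputs"].shape, len(batch.get("aug", [])))
        break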
 1560  egs/librispeech/ASR/zipformer/train_with_aug.py  (Executable file)
       File diff suppressed because it is too large.

 145   icefall/speech_recognition_dataset.py  (Normal file)

@@ -0,0 +1,145 @@
from typing import Callable, List

import torch
from lhotse import validate
from lhotse.cut import Cut, CutSet
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import ifnone
from lhotse.workarounds import Hdf5MemoryIssueFix
from torch.utils.data.dataloader import default_collate


class ConsistencyRegularizationSpeechRecognitionDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        return_cuts: bool = False,
        cut_transforms: List[Callable[[Cut], Cut]] = None,
        input_strategy: BatchIO = PrecomputedFeatures(),
    ):
        super().__init__()
        self.return_cuts = return_cuts
        self.cut_transforms = ifnone(cut_transforms, [])
        self.input_strategy = input_strategy

        # This attribute is a workaround to constantly growing HDF5 memory
        # throughout the epoch. It regularly closes open file handles to
        # reset the internal HDF5 caches.
        self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100)

    def __getitem__(self, cuts: CutSet) -> dict:
        """
        Return a dict of the following form:

        .. code-block::

            {
                'inputs': float tensor with shape determined by :attr:`input_strategy`:
                    - single-channel:
                        - features: (B, T, F)
                        - audio: (B, T)
                    - multi-channel: currently not supported
                'supervisions': [
                    'sequence_idx': Tensor[int] of shape (S,)
                    'text': List[str] of len S

                    # For feature input strategies
                    'start_frame': Tensor[int] of shape (S,)
                    'num_frames': Tensor[int] of shape (S,)

                    # For audio input strategies
                    'start_sample': Tensor[int] of shape (S,)
                    'num_samples': Tensor[int] of shape (S,)

                    # Optionally, when return_cuts=True
                    'cut': List[AnyCut] of len S
                ],
                'aug': [
                    # Info for the augmented cuts.
                    {'inputs': xxx, 'supervisions': [xxx]},
                    {'inputs': xxx, 'supervisions': [xxx]},
                    {'inputs': xxx, 'supervisions': [xxx]},

                    # xxx denotes the same kind of info as in the
                    # non-augmented entries above.

                    # aug[i] corresponds to self.cut_transforms[i]
                ]
            }
        """
        validate_for_asr(cuts)
        self.hdf5_fix.update()

        # Sort the cuts by duration so that the first one determines the batch time dimensions.
        cuts = cuts.sort_by_duration(ascending=False)

        batch = self._process(cuts)

        if self.cut_transforms:
            batch["aug"] = []

            for tf in self.cut_transforms:
                transformed_cuts = cuts.map(tf)

                batch["aug"].append(self._process(transformed_cuts))

        return batch

    def _process(self, cuts: CutSet):
        # Get a tensor with batched feature matrices, shape (B, T, F).
        # Collation performs auto-padding, if necessary.
        input_tpl = self.input_strategy(cuts)
        if len(input_tpl) == 3:
            # An input strategy with fault-tolerant audio reading mode.
            # "cuts" may be a subset of the original "cuts" variable
            # that only has cuts for which we successfully read the audio.
            inputs, _, cuts = input_tpl
        else:
            inputs, _ = input_tpl

        # Get a dict of tensors that encode the positional information about supervisions
        # in the batch of feature matrices. The tensors are named "sequence_idx",
        # "start_frame/sample" and "num_frames/samples".
        supervision_intervals = self.input_strategy.supervision_intervals(cuts)

        batch = {
            "inputs": inputs,
            "supervisions": default_collate(
                [
                    {
                        "text": supervision.text,
                    }
                    for cut in cuts
                    for supervision in cut.supervisions
                ]
            ),
        }

        # Update the 'supervisions' field with sequence_idx and start/num frames/samples.
        batch["supervisions"].update(supervision_intervals)
        if self.return_cuts:
            batch["supervisions"]["cut"] = [
                cut for cut in cuts for sup in cut.supervisions
            ]

        return batch


def validate_for_asr(cuts: CutSet) -> None:
    validate(cuts)
    tol = 2e-3  # 2ms
    for cut in cuts:
        for supervision in cut.supervisions:
            assert supervision.start >= -tol, (
                f"Supervisions starting before the cut are not supported for ASR"
                f" (sup id: {supervision.id}, cut id: {cut.id})"
            )

            # Supervision start time is relative to the Cut; see
            # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
            #
            # 'supervision.end' is the end of the supervision inside the Cut.
            assert supervision.end <= cut.duration + tol, (
                f"Supervisions ending after the cut "
                f"are not supported for ASR"
                f" (sup id: {supervision.id}, cut id: {cut.id})"
            )
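

# A minimal sketch (not part of this module) of how a training loop might
# unpack the dict returned by __getitem__ above.  The helper name
# ``split_views`` is hypothetical; batch["aug"][i] is the batch built from
# cut_transforms[i], so pairing each entry with the un-augmented "inputs"
# provides the views needed for a consistency-regularization loss.
def split_views(batch: dict):
    """Return (clean_inputs, clean_texts, augmented_batches)."""
    clean_inputs = batch["inputs"]                # (B, T, F) or (B, T)
    clean_texts = batch["supervisions"]["text"]   # List[str] of len S
    augmented_batches = batch.get("aug", [])      # one dict per cut transform
    return clean_inputs, clean_texts, augmented_batches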