mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Merge 0f88a3a6c3c4051d3f8feb20ec1a207d504e2c53 into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9
This commit is contained in:
commit
c8816d7930
@ -0,0 +1,415 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import random
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import lhotse
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
AudioSamples,
|
||||
OnTheFlyFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.speech_recognition_dataset import (
|
||||
ConsistencyRegularizationSpeechRecognitionDataset,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
"""
|
||||
We use c.features = None below to suppress the following warnings
|
||||
|
||||
2025-05-29 16:49:55,253 WARNING [data.py:801] Attempting to perturb speed on a
|
||||
DataCut that references pre-computed features. The feature manifest will be
|
||||
detached, as we do not support feature-domain speed perturbation.
|
||||
"""
|
||||
|
||||
|
||||
def perturb_speed(c: Cut):
|
||||
factor = random.choice([0.9, 1.1])
|
||||
c.features = None
|
||||
|
||||
return lhotse.MonoCut.perturb_speed(c, factor)
|
||||
|
||||
|
||||
def perturb_volume(c: Cut):
|
||||
factor = random.choice([0.9, 1.1])
|
||||
c.features = None
|
||||
|
||||
return lhotse.MonoCut.perturb_volume(c, factor)
|
||||
|
||||
|
||||
def perturb_tempo(c: Cut):
|
||||
factor = random.choice([0.9, 1.1])
|
||||
|
||||
c.features = None
|
||||
return lhotse.MonoCut.perturb_tempo(c, factor)
|
||||
|
||||
|
||||
class LibriSpeechAsrDataModuleWithParallelAug:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--full-libri",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="""Used only when --mini-libri is False.When enabled,
|
||||
use 960h LibriSpeech. Otherwise, use 100h subset.""",
|
||||
)
|
||||
group.add_argument(
|
||||
"--enable-augmentation",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="True to enable augmentation for training set",
|
||||
)
|
||||
group.add_argument(
|
||||
"--mini-libri",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="True for mini librispeech",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available. For training dataset, it always uses on_the_fly_feats",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="PrecomputedFeatures",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
if self.args.enable_augmentation:
|
||||
logging.info("Augmentation is enabled")
|
||||
transforms = [perturb_speed, perturb_volume, perturb_tempo]
|
||||
else:
|
||||
logging.info("Augmentation is disabled")
|
||||
transforms = []
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = ConsistencyRegularizationSpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
buffer_size=self.args.num_buckets * 2000,
|
||||
shuffle_buffer_size=self.args.num_buckets * 5000,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_5_cuts(self) -> CutSet:
|
||||
logging.info("mini_librispeech: About to get train-clean-5 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_100_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-100 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_360_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-360 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_other_500_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-other-500 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_all_shuf_cuts(self) -> CutSet:
|
||||
logging.info(
|
||||
"About to get the shuffled train-clean-100, \
|
||||
train-clean-360 and train-other-500 cuts"
|
||||
)
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_clean_2_cuts(self) -> CutSet:
|
||||
logging.info("mini_librispeech: About to get dev-clean-2 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def gigaspeech_subset_small_cuts(self) -> CutSet:
|
||||
logging.info("About to get Gigaspeech subset-S cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def gigaspeech_dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get Gigaspeech dev cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def gigaspeech_test_cuts(self) -> CutSet:
|
||||
logging.info("About to get Gigaspeech test cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
|
1560
egs/librispeech/ASR/zipformer/train_with_aug.py
Executable file
1560
egs/librispeech/ASR/zipformer/train_with_aug.py
Executable file
File diff suppressed because it is too large
Load Diff
145
icefall/speech_recognition_dataset.py
Normal file
145
icefall/speech_recognition_dataset.py
Normal file
@ -0,0 +1,145 @@
|
||||
from typing import Callable, List
|
||||
|
||||
import torch
|
||||
from lhotse import validate
|
||||
from lhotse.cut import Cut, CutSet
|
||||
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
|
||||
from lhotse.utils import ifnone
|
||||
from lhotse.workarounds import Hdf5MemoryIssueFix
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
|
||||
|
||||
class ConsistencyRegularizationSpeechRecognitionDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
return_cuts: bool = False,
|
||||
cut_transforms: List[Callable[[Cut], Cut]] = None,
|
||||
input_strategy: BatchIO = PrecomputedFeatures(),
|
||||
):
|
||||
super().__init__()
|
||||
self.return_cuts = return_cuts
|
||||
self.cut_transforms = ifnone(cut_transforms, [])
|
||||
self.input_strategy = input_strategy
|
||||
|
||||
# This attribute is a workaround to constantly growing HDF5 memory
|
||||
# throughout the epoch. It regularly closes open file handles to
|
||||
# reset the internal HDF5 caches.
|
||||
self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100)
|
||||
|
||||
def __getitem__(self, cuts: CutSet) -> dict:
|
||||
"""
|
||||
Return a dict
|
||||
|
||||
.. code-block::
|
||||
|
||||
{
|
||||
'inputs': float tensor with shape determined by :attr:`input_strategy`:
|
||||
- single-channel:
|
||||
- features: (B, T, F)
|
||||
- audio: (B, T)
|
||||
- multi-channel: currently not supported
|
||||
'supervisions': [
|
||||
'sequence_idx': Tensor[int] of shape (S,)
|
||||
'text': List[str] of len S
|
||||
|
||||
# For feature input strategies
|
||||
'start_frame': Tensor[int] of shape (S,)
|
||||
'num_frames': Tensor[int] of shape (S,)
|
||||
|
||||
# For audio input strategies
|
||||
'start_sample': Tensor[int] of shape (S,)
|
||||
'num_samples': Tensor[int] of shape (S,)
|
||||
|
||||
# Optionally, when return_cuts=True
|
||||
'cut': List[AnyCut] of len S
|
||||
|
||||
],
|
||||
'aug': [
|
||||
# it contains augmented cut info
|
||||
{'inputs': xxx, 'supervisions': [xxx]},
|
||||
{'inputs': xxx, 'supervisions': [xxx]},
|
||||
{'inputs': xxx, 'supervisions': [xxx]},
|
||||
|
||||
# where xxx means it contains similar info as the non-augmented version
|
||||
|
||||
# aug[i] corresponds to self.cut_transforms[i]
|
||||
]
|
||||
}
|
||||
"""
|
||||
validate_for_asr(cuts)
|
||||
self.hdf5_fix.update()
|
||||
|
||||
# Sort the cuts by duration so that the first one determines the batch time dimensions.
|
||||
cuts = cuts.sort_by_duration(ascending=False)
|
||||
|
||||
batch = self._process(cuts)
|
||||
|
||||
if self.cut_transforms:
|
||||
batch["aug"] = []
|
||||
|
||||
for i, tf in enumerate(self.cut_transforms):
|
||||
transformed_cuts = cuts.map(tf)
|
||||
|
||||
batch["aug"].append(self._process(transformed_cuts))
|
||||
|
||||
return batch
|
||||
|
||||
def _process(self, cuts: CutSet):
|
||||
# Get a tensor with batched feature matrices, shape (B, T, F)
|
||||
# Collation performs auto-padding, if necessary.
|
||||
input_tpl = self.input_strategy(cuts)
|
||||
if len(input_tpl) == 3:
|
||||
# An input strategy with fault tolerant audio reading mode.
|
||||
# "cuts" may be a subset of the original "cuts" variable,
|
||||
# that only has cuts for which we successfully read the audio.
|
||||
inputs, _, cuts = input_tpl
|
||||
else:
|
||||
inputs, _ = input_tpl
|
||||
|
||||
# Get a dict of tensors that encode the positional information about supervisions
|
||||
# in the batch of feature matrices. The tensors are named "sequence_idx",
|
||||
# "start_frame/sample" and "num_frames/samples".
|
||||
supervision_intervals = self.input_strategy.supervision_intervals(cuts)
|
||||
|
||||
batch = {
|
||||
"inputs": inputs,
|
||||
"supervisions": default_collate(
|
||||
[
|
||||
{
|
||||
"text": supervision.text,
|
||||
}
|
||||
for sequence_idx, cut in enumerate(cuts)
|
||||
for supervision in cut.supervisions
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
# Update the 'supervisions' field with sequence_idx and start/num frames/samples
|
||||
batch["supervisions"].update(supervision_intervals)
|
||||
if self.return_cuts:
|
||||
batch["supervisions"]["cut"] = [
|
||||
cut for cut in cuts for sup in cut.supervisions
|
||||
]
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def validate_for_asr(cuts: CutSet) -> None:
|
||||
validate(cuts)
|
||||
tol = 2e-3 # 1ms
|
||||
for cut in cuts:
|
||||
for supervision in cut.supervisions:
|
||||
assert supervision.start >= -tol, (
|
||||
f"Supervisions starting before the cut are not supported for ASR"
|
||||
f" (sup id: {supervision.id}, cut id: {cut.id})"
|
||||
)
|
||||
|
||||
# Supervision start time is relative to Cut ...
|
||||
# https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
|
||||
#
|
||||
# 'supervision.end' is end of supervision inside the Cut
|
||||
assert supervision.end <= cut.duration + tol, (
|
||||
f"Supervisions ending after the cut "
|
||||
f"are not supported for ASR"
|
||||
f" (sup id: {supervision.id}, cut id: {cut.id})"
|
||||
)
|
Loading…
x
Reference in New Issue
Block a user