mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 17:14:20 +00:00
change the valid/test sets; only do simple normalization in the dataloader, i.e only replace full-width symbol, replace double hyphen with space
This commit is contained in:
parent
0d1cd4f595
commit
0aee07fb4c
@ -20,27 +20,150 @@ import inspect
|
|||||||
import logging
|
import logging
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from dataset import PromptASRDataset
|
from lhotse import (
|
||||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
CutSet,
|
||||||
|
Fbank,
|
||||||
|
FbankConfig,
|
||||||
|
load_manifest,
|
||||||
|
load_manifest_lazy,
|
||||||
|
validate,
|
||||||
|
)
|
||||||
from lhotse.dataset import (
|
from lhotse.dataset import (
|
||||||
CutConcatenate,
|
CutConcatenate,
|
||||||
CutMix,
|
CutMix,
|
||||||
DynamicBucketingSampler,
|
DynamicBucketingSampler,
|
||||||
K2SpeechRecognitionDataset,
|
K2SpeechRecognitionDataset,
|
||||||
PrecomputedFeatures,
|
|
||||||
SingleCutSampler,
|
SingleCutSampler,
|
||||||
SpecAugment,
|
SpecAugment,
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
from lhotse.dataset.input_strategies import (
|
||||||
from lhotse.utils import fix_random_seed
|
BatchIO,
|
||||||
from torch.utils.data import DataLoader
|
OnTheFlyFeatures,
|
||||||
|
PrecomputedFeatures,
|
||||||
|
)
|
||||||
|
from lhotse.utils import fix_random_seed, ifnone
|
||||||
|
from text_normalization import (
|
||||||
|
ref_text_normalization,
|
||||||
|
replace_full_width_symbol,
|
||||||
|
simple_normalization,
|
||||||
|
)
|
||||||
|
from torch.utils.data.dataloader import DataLoader, default_collate
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
|
class LibriHeavyASRDataset(torch.utils.data.Dataset):
|
||||||
|
"""This is a dataset for LibriHeavy dataset"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
return_cuts: bool = False,
|
||||||
|
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
|
||||||
|
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
|
||||||
|
input_strategy: BatchIO = PrecomputedFeatures(),
|
||||||
|
text_sampling_func: Optional[Callable[[List[str]], str]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Icefall ASR IterableDataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py
|
||||||
|
for more details.
|
||||||
|
|
||||||
|
:param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
|
||||||
|
objects used to create that batch.
|
||||||
|
:param cut_transforms: A list of transforms to be applied on each sampled batch,
|
||||||
|
before converting cuts to an input representation (audio/features).
|
||||||
|
Examples: cut concatenation, noise cuts mixing, etc.
|
||||||
|
:param input_transforms: A list of transforms to be applied on each sampled batch,
|
||||||
|
after the cuts are converted to audio/features.
|
||||||
|
Examples: normalization, SpecAugment, etc.
|
||||||
|
:param input_strategy: Converts cuts into a collated batch of audio/features.
|
||||||
|
By default, reads pre-computed features from disk.
|
||||||
|
:param text_sampling_func: Sampling a text as transcription from a list of texts.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
# Initialize the fields
|
||||||
|
self.return_cuts = return_cuts
|
||||||
|
self.cut_transforms = ifnone(cut_transforms, [])
|
||||||
|
self.input_transforms = ifnone(input_transforms, [])
|
||||||
|
self.input_strategy = input_strategy
|
||||||
|
|
||||||
|
# a text selection function
|
||||||
|
self.text_sampling_func = text_sampling_func
|
||||||
|
|
||||||
|
def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
|
||||||
|
"""
|
||||||
|
Return a new batch, with the batch size automatically determined using the constraints
|
||||||
|
of max_frames and max_cuts.
|
||||||
|
"""
|
||||||
|
validate_for_asr(cuts)
|
||||||
|
|
||||||
|
# Sort the cuts by duration so that the first one determines the batch time dimensions.
|
||||||
|
cuts = cuts.sort_by_duration(ascending=False)
|
||||||
|
|
||||||
|
# Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
|
||||||
|
# the supervision boundaries.
|
||||||
|
for tnfm in self.cut_transforms:
|
||||||
|
cuts = tnfm(cuts)
|
||||||
|
|
||||||
|
# Sort the cuts again after transforms
|
||||||
|
cuts = cuts.sort_by_duration(ascending=False)
|
||||||
|
|
||||||
|
# Get a tensor with batched feature matrices, shape (B, T, F)
|
||||||
|
# Collation performs auto-padding, if necessary.
|
||||||
|
input_tpl = self.input_strategy(cuts)
|
||||||
|
if len(input_tpl) == 3:
|
||||||
|
# An input strategy with fault tolerant audio reading mode.
|
||||||
|
# "cuts" may be a subset of the original "cuts" variable,
|
||||||
|
# that only has cuts for which we succesfully read the audio.
|
||||||
|
inputs, _, cuts = input_tpl
|
||||||
|
else:
|
||||||
|
inputs, _ = input_tpl
|
||||||
|
|
||||||
|
# Get a dict of tensors that encode the positional information about supervisions
|
||||||
|
# in the batch of feature matrices. The tensors are named "sequence_idx",
|
||||||
|
# "start_frame/sample" and "num_frames/samples".
|
||||||
|
supervision_intervals = self.input_strategy.supervision_intervals(cuts)
|
||||||
|
|
||||||
|
# Apply all available transforms on the inputs, i.e. either audio or features.
|
||||||
|
# This could be feature extraction, global MVN, SpecAugment, etc.
|
||||||
|
segments = torch.stack(list(supervision_intervals.values()), dim=1)
|
||||||
|
for tnfm in self.input_transforms:
|
||||||
|
inputs = tnfm(inputs, supervision_segments=segments)
|
||||||
|
|
||||||
|
batch = {
|
||||||
|
"inputs": inputs,
|
||||||
|
"supervisions": default_collate(
|
||||||
|
[
|
||||||
|
simple_normalization(
|
||||||
|
self.text_sampling_func(texts=supervision.texts)
|
||||||
|
)
|
||||||
|
if self.text_sampling_func is not None
|
||||||
|
else {
|
||||||
|
"text": simple_normalization(supervision.texts[0]),
|
||||||
|
}
|
||||||
|
for sequence_idx, cut in enumerate(cuts)
|
||||||
|
for supervision in cut.supervisions
|
||||||
|
]
|
||||||
|
),
|
||||||
|
}
|
||||||
|
# Update the 'supervisions' field with sequence_idx and start/num frames/samples
|
||||||
|
batch["supervisions"].update(supervision_intervals)
|
||||||
|
if self.return_cuts:
|
||||||
|
batch["supervisions"]["cut"] = [
|
||||||
|
cut for cut in cuts for sup in cut.supervisions
|
||||||
|
]
|
||||||
|
|
||||||
|
has_word_alignments = all(
|
||||||
|
s.alignment is not None and "word" in s.alignment
|
||||||
|
for c in cuts
|
||||||
|
for s in c.supervisions
|
||||||
|
)
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
class _SeedWorkers:
|
class _SeedWorkers:
|
||||||
def __init__(self, seed: int):
|
def __init__(self, seed: int):
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
@ -197,7 +320,7 @@ class LibriHeavyAsrDataModule:
|
|||||||
self,
|
self,
|
||||||
cuts_train: CutSet,
|
cuts_train: CutSet,
|
||||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||||
text_sampling_func: Callable[[List[str]], str] = None,
|
text_sampling_func: Optional[Callable[[List[str]], str]] = None,
|
||||||
) -> DataLoader:
|
) -> DataLoader:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -259,7 +382,7 @@ class LibriHeavyAsrDataModule:
|
|||||||
logging.info("Disable SpecAugment")
|
logging.info("Disable SpecAugment")
|
||||||
|
|
||||||
logging.info("About to create train dataset")
|
logging.info("About to create train dataset")
|
||||||
train = PromptASRDataset(
|
train = LibriHeavyASRDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_transforms=input_transforms,
|
input_transforms=input_transforms,
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
@ -277,7 +400,7 @@ class LibriHeavyAsrDataModule:
|
|||||||
# to be strict (e.g. could be randomized)
|
# to be strict (e.g. could be randomized)
|
||||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||||
# Drop feats to be on the safe side.
|
# Drop feats to be on the safe side.
|
||||||
train = PromptASRDataset(
|
train = LibriHeavyASRDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||||
input_transforms=input_transforms,
|
input_transforms=input_transforms,
|
||||||
@ -326,7 +449,7 @@ class LibriHeavyAsrDataModule:
|
|||||||
def valid_dataloaders(
|
def valid_dataloaders(
|
||||||
self,
|
self,
|
||||||
cuts_valid: CutSet,
|
cuts_valid: CutSet,
|
||||||
text_sampling_func: Callable[[List[str]], str] = None,
|
text_sampling_func: Optional[Callable[[List[str]], str]] = None,
|
||||||
) -> DataLoader:
|
) -> DataLoader:
|
||||||
transforms = []
|
transforms = []
|
||||||
if self.args.concatenate_cuts:
|
if self.args.concatenate_cuts:
|
||||||
@ -338,14 +461,14 @@ class LibriHeavyAsrDataModule:
|
|||||||
|
|
||||||
logging.info("About to create dev dataset")
|
logging.info("About to create dev dataset")
|
||||||
if self.args.on_the_fly_feats:
|
if self.args.on_the_fly_feats:
|
||||||
validate = PromptASRDataset(
|
validate = LibriHeavyASRDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
text_sampling_func=text_sampling_func,
|
text_sampling_func=text_sampling_func,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
validate = PromptASRDataset(
|
validate = LibriHeavyASRDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
text_sampling_func=text_sampling_func,
|
text_sampling_func=text_sampling_func,
|
||||||
@ -368,7 +491,7 @@ class LibriHeavyAsrDataModule:
|
|||||||
|
|
||||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||||
logging.debug("About to create test dataset")
|
logging.debug("About to create test dataset")
|
||||||
test = K2SpeechRecognitionDataset(
|
test = LibriHeavyASRDataset(
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||||
if self.args.on_the_fly_feats
|
if self.args.on_the_fly_feats
|
||||||
else PrecomputedFeatures(),
|
else PrecomputedFeatures(),
|
||||||
@ -391,49 +514,44 @@ class LibriHeavyAsrDataModule:
|
|||||||
@lru_cache()
|
@lru_cache()
|
||||||
def train_cuts(self) -> CutSet:
|
def train_cuts(self) -> CutSet:
|
||||||
logging.info(f"About to get {self.args.subset} cuts")
|
logging.info(f"About to get {self.args.subset} cuts")
|
||||||
path = self.args.manifest_dir / f"librilight_cuts_{self.args.subset}.jsonl.gz"
|
|
||||||
|
path = self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz"
|
||||||
cuts_train = CutSet.from_jsonl_lazy(path)
|
cuts_train = CutSet.from_jsonl_lazy(path)
|
||||||
|
if self.args.subset == "medium":
|
||||||
|
logging.info("Getting medium subset")
|
||||||
|
path = self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz"
|
||||||
|
cuts_train += CutSet.from_jsonl_lazy(path)
|
||||||
|
elif self.args.subset == "large":
|
||||||
|
logging.info("Getting large subset")
|
||||||
|
path = self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz"
|
||||||
|
cuts_train += CutSet.from_jsonl_lazy(path)
|
||||||
|
path = self.args.manifest_dir / "libriheavy_cuts_large.jsonl.gz"
|
||||||
|
cuts_train += CutSet.from_jsonl_lazy(path)
|
||||||
|
|
||||||
return cuts_train
|
return cuts_train
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def dev_cuts(self) -> CutSet:
|
def dev_cuts(self) -> CutSet:
|
||||||
logging.info("About to get dev cuts")
|
logging.info("About to get dev cuts")
|
||||||
cuts_valid = load_manifest_lazy(
|
cuts = load_manifest_lazy(
|
||||||
self.args.manifest_dir / "librilight_cuts_dev.jsonl.gz"
|
self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz"
|
||||||
)
|
)
|
||||||
return cuts_valid
|
return cuts
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def test_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get test cuts")
|
|
||||||
cuts_valid = load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "librilight_cuts_test.jsonl.gz"
|
|
||||||
)
|
|
||||||
return cuts_valid
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def test_medium_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get 2000 cuts from the medium set")
|
|
||||||
cuts_medium_2k = load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "librilight_cuts_medium_2000.jsonl.gz"
|
|
||||||
)
|
|
||||||
return cuts_medium_2k
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def test_clean_cuts(self) -> CutSet:
|
def test_clean_cuts(self) -> CutSet:
|
||||||
logging.info("About to get test-clean cuts")
|
logging.info("About to get test-clean cuts")
|
||||||
cuts = load_manifest_lazy(
|
cuts_valid = load_manifest_lazy(
|
||||||
self.args.manifest_dir / "librilight_finetuning_clean.jsonl.gz"
|
self.args.manifest_dir / "libriheavy_cuts_test-clean.jsonl.gz"
|
||||||
)
|
)
|
||||||
return cuts
|
return cuts_valid
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def test_other_cuts(self) -> CutSet:
|
def test_other_cuts(self) -> CutSet:
|
||||||
logging.info("About to get test-other cuts")
|
logging.info("About to get test-other cuts")
|
||||||
cuts = load_manifest_lazy(
|
cuts_valid = load_manifest_lazy(
|
||||||
self.args.manifest_dir / "librilight_finetuning_other.jsonl.gz"
|
self.args.manifest_dir / "libriheavy_cuts_test-other.jsonl.gz"
|
||||||
)
|
)
|
||||||
return cuts
|
return cuts_valid
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def librispeech_test_clean_cuts(self) -> CutSet:
|
def librispeech_test_clean_cuts(self) -> CutSet:
|
||||||
@ -448,3 +566,24 @@ class LibriHeavyAsrDataModule:
|
|||||||
return load_manifest_lazy(
|
return load_manifest_lazy(
|
||||||
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
|
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_for_asr(cuts: CutSet) -> None:
|
||||||
|
validate(cuts)
|
||||||
|
tol = 2e-3 # 1ms
|
||||||
|
for cut in cuts:
|
||||||
|
for supervision in cut.supervisions:
|
||||||
|
assert supervision.start >= -tol, (
|
||||||
|
f"Supervisions starting before the cut are not supported for ASR"
|
||||||
|
f" (sup id: {supervision.id}, cut id: {cut.id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Supervision start time is relative to Cut ...
|
||||||
|
# https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
|
||||||
|
#
|
||||||
|
# 'supervision.end' is end of supervision inside the Cut
|
||||||
|
assert supervision.end <= cut.duration + tol, (
|
||||||
|
f"Supervisions ending after the cut "
|
||||||
|
f"are not supported for ASR"
|
||||||
|
f" (sup id: {supervision.id}, cut id: {cut.id})"
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user