change the valid/test sets; only do simple normalization in the dataloader, i.e. only replace full-width symbols and replace double hyphens with spaces

marcoyang1998 2023-07-19 11:00:07 +08:00
parent 0d1cd4f595
commit 0aee07fb4c

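The text_normalization helpers imported in the diff (replace_full_width_symbol, simple_normalization) are not part of this change. Below is a minimal sketch consistent with the commit message, assuming the full-width replacements are limited to curly quotation marks (the exact symbol table is an assumption, not taken from the repository):

def replace_full_width_symbol(s: str) -> str:
    # Assumed mapping: replace full-width quotation marks with ASCII quotes.
    s = s.replace("“", '"').replace("”", '"')
    s = s.replace("‘", "'").replace("’", "'")
    return s


def simple_normalization(s: str) -> str:
    # Replace full-width symbols, then replace double hyphens with a space,
    # as described in the commit message.
    s = replace_full_width_symbol(s)
    s = s.replace("--", " ")
    return s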

@@ -20,27 +20,150 @@ import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from dataset import PromptASRDataset
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse import (
CutSet,
Fbank,
FbankConfig,
load_manifest,
load_manifest_lazy,
validate,
)
from lhotse.dataset import (
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from lhotse.dataset.input_strategies import (
BatchIO,
OnTheFlyFeatures,
PrecomputedFeatures,
)
from lhotse.utils import fix_random_seed, ifnone
from text_normalization import (
ref_text_normalization,
replace_full_width_symbol,
simple_normalization,
)
from torch.utils.data.dataloader import DataLoader, default_collate
from icefall.utils import str2bool
class LibriHeavyASRDataset(torch.utils.data.Dataset):
"""This is a dataset for LibriHeavy dataset"""
def __init__(
self,
return_cuts: bool = False,
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
input_strategy: BatchIO = PrecomputedFeatures(),
text_sampling_func: Optional[Callable[[List[str]], str]] = None,
):
"""
Icefall ASR dataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py
for more details.
:param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
objects used to create that batch.
:param cut_transforms: A list of transforms to be applied on each sampled batch,
before converting cuts to an input representation (audio/features).
Examples: cut concatenation, noise cuts mixing, etc.
:param input_transforms: A list of transforms to be applied on each sampled batch,
after the cuts are converted to audio/features.
Examples: normalization, SpecAugment, etc.
:param input_strategy: Converts cuts into a collated batch of audio/features.
By default, reads pre-computed features from disk.
:param text_sampling_func: A function that samples one of the available texts to use as the transcription.
"""
super().__init__()
# Initialize the fields
self.return_cuts = return_cuts
self.cut_transforms = ifnone(cut_transforms, [])
self.input_transforms = ifnone(input_transforms, [])
self.input_strategy = input_strategy
# a text selection function
self.text_sampling_func = text_sampling_func
def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
"""
Return a new batch, with the batch size automatically determined using the constraints
of max_frames and max_cuts.
"""
validate_for_asr(cuts)
# Sort the cuts by duration so that the first one determines the batch time dimensions.
cuts = cuts.sort_by_duration(ascending=False)
# Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
# the supervision boundaries.
for tnfm in self.cut_transforms:
cuts = tnfm(cuts)
# Sort the cuts again after transforms
cuts = cuts.sort_by_duration(ascending=False)
# Get a tensor with batched feature matrices, shape (B, T, F)
# Collation performs auto-padding, if necessary.
input_tpl = self.input_strategy(cuts)
if len(input_tpl) == 3:
# An input strategy with fault-tolerant audio reading mode.
# "cuts" may be a subset of the original "cuts" variable,
# containing only the cuts for which we successfully read the audio.
inputs, _, cuts = input_tpl
else:
inputs, _ = input_tpl
# Get a dict of tensors that encode the positional information about supervisions
# in the batch of feature matrices. The tensors are named "sequence_idx",
# "start_frame/sample" and "num_frames/samples".
supervision_intervals = self.input_strategy.supervision_intervals(cuts)
# Apply all available transforms on the inputs, i.e. either audio or features.
# This could be feature extraction, global MVN, SpecAugment, etc.
segments = torch.stack(list(supervision_intervals.values()), dim=1)
for tnfm in self.input_transforms:
inputs = tnfm(inputs, supervision_segments=segments)
batch = {
"inputs": inputs,
"supervisions": default_collate(
[
# Wrap the sampled text in a dict so that default_collate yields a dict
# of lists that can later be updated with supervision_intervals.
{
"text": simple_normalization(
self.text_sampling_func(texts=supervision.texts)
)
}
if self.text_sampling_func is not None
else {
"text": simple_normalization(supervision.texts[0]),
}
for sequence_idx, cut in enumerate(cuts)
for supervision in cut.supervisions
]
),
}
# Update the 'supervisions' field with sequence_idx and start/num frames/samples
batch["supervisions"].update(supervision_intervals)
if self.return_cuts:
batch["supervisions"]["cut"] = [
cut for cut in cuts for sup in cut.supervisions
]
has_word_alignments = all(
s.alignment is not None and "word" in s.alignment
for c in cuts
for s in c.supervisions
)
return batch
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
@@ -197,7 +320,7 @@ class LibriHeavyAsrDataModule:
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
text_sampling_func: Callable[[List[str]], str] = None,
text_sampling_func: Optional[Callable[[List[str]], str]] = None,
) -> DataLoader:
"""
Args:
@@ -259,7 +382,7 @@ class LibriHeavyAsrDataModule:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = PromptASRDataset(
train = LibriHeavyASRDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
@@ -277,7 +400,7 @@ class LibriHeavyAsrDataModule:
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = PromptASRDataset(
train = LibriHeavyASRDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
@@ -326,7 +449,7 @@ class LibriHeavyAsrDataModule:
def valid_dataloaders(
self,
cuts_valid: CutSet,
text_sampling_func: Callable[[List[str]], str] = None,
text_sampling_func: Optional[Callable[[List[str]], str]] = None,
) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
@@ -338,14 +461,14 @@ class LibriHeavyAsrDataModule:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = PromptASRDataset(
validate = LibriHeavyASRDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
text_sampling_func=text_sampling_func,
)
else:
validate = PromptASRDataset(
validate = LibriHeavyASRDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
text_sampling_func=text_sampling_func,
@@ -368,7 +491,7 @@ class LibriHeavyAsrDataModule:
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
test = LibriHeavyASRDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
@@ -391,49 +514,44 @@ class LibriHeavyAsrDataModule:
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info(f"About to get {self.args.subset} cuts")
path = self.args.manifest_dir / f"librilight_cuts_{self.args.subset}.jsonl.gz"
path = self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz"
cuts_train = CutSet.from_jsonl_lazy(path)
if self.args.subset == "medium":
logging.info("Getting medium subset")
path = self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz"
cuts_train += CutSet.from_jsonl_lazy(path)
elif self.args.subset == "large":
logging.info("Getting large subset")
path = self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz"
cuts_train += CutSet.from_jsonl_lazy(path)
path = self.args.manifest_dir / "libriheavy_cuts_large.jsonl.gz"
cuts_train += CutSet.from_jsonl_lazy(path)
return cuts_train
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
cuts_valid = load_manifest_lazy(
self.args.manifest_dir / "librilight_cuts_dev.jsonl.gz"
cuts = load_manifest_lazy(
self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz"
)
return cuts_valid
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
cuts_valid = load_manifest_lazy(
self.args.manifest_dir / "librilight_cuts_test.jsonl.gz"
)
return cuts_valid
@lru_cache()
def test_medium_cuts(self) -> CutSet:
logging.info("About to get 2000 cuts from the medium set")
cuts_medium_2k = load_manifest_lazy(
self.args.manifest_dir / "librilight_cuts_medium_2000.jsonl.gz"
)
return cuts_medium_2k
return cuts
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
cuts = load_manifest_lazy(
self.args.manifest_dir / "librilight_finetuning_clean.jsonl.gz"
cuts_valid = load_manifest_lazy(
self.args.manifest_dir / "libriheavy_cuts_test-clean.jsonl.gz"
)
return cuts
return cuts_valid
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
cuts = load_manifest_lazy(
self.args.manifest_dir / "librilight_finetuning_other.jsonl.gz"
cuts_valid = load_manifest_lazy(
self.args.manifest_dir / "libriheavy_cuts_test-other.jsonl.gz"
)
return cuts
return cuts_valid
@lru_cache()
def librispeech_test_clean_cuts(self) -> CutSet:
@@ -448,3 +566,24 @@ class LibriHeavyAsrDataModule:
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
)
def validate_for_asr(cuts: CutSet) -> None:
validate(cuts)
tol = 2e-3 # 2ms tolerance
for cut in cuts:
for supervision in cut.supervisions:
assert supervision.start >= -tol, (
f"Supervisions starting before the cut are not supported for ASR"
f" (sup id: {supervision.id}, cut id: {cut.id})"
)
# Supervision start time is relative to Cut ...
# https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
#
# 'supervision.end' is end of supervision inside the Cut
assert supervision.end <= cut.duration + tol, (
f"Supervisions ending after the cut "
f"are not supported for ASR"
f" (sup id: {supervision.id}, cut id: {cut.id})"
)