
from typing import Callable, Dict, List, Optional, Union
import random
import numpy as np
import torch
from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset import K2SpeechRecognitionDataset
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import compute_num_frames, ifnone
from torch.utils.data.dataloader import DataLoader, default_collate
from text_normalization import (
upper_only_alpha,
lower_only_alpha,
upper_all_char,
lower_all_char,
train_text_normalization,
)
class PromptASRDataset(torch.utils.data.Dataset):
"""This is a dataset for Prompt ASR. It supports the following features:
1. Select a tuple of (text, pre_text, style_text) randomly from a
list of texts as supervisions.
"""
def __init__(
self,
return_cuts: bool = False,
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
input_strategy: BatchIO = PrecomputedFeatures(),
text_sampling_func: Optional[Callable[[List[str]], str]] = None,
):
"""
PromptASRDataset constructor. See
https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py
(lhotse's K2SpeechRecognitionDataset) for more details.
:param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
objects used to create that batch.
:param cut_transforms: A list of transforms to be applied on each sampled batch,
before converting cuts to an input representation (audio/features).
Examples: cut concatenation, noise cuts mixing, etc.
:param input_transforms: A list of transforms to be applied on each sampled batch,
after the cuts are converted to audio/features.
Examples: normalization, SpecAugment, etc.
:param input_strategy: Converts cuts into a collated batch of audio/features.
By default, reads pre-computed features from disk.
:param text_sampling_func: Optional callable that samples a (text, pre_text, style_text)
triplet from the lists of candidate texts attached to each supervision.
When ``None``, the first text and pre_text are used as-is.
"""
super().__init__()
# Initialize the fields
self.return_cuts = return_cuts
self.cut_transforms = ifnone(cut_transforms, [])
self.input_transforms = ifnone(input_transforms, [])
self.input_strategy = input_strategy
# a text sampling function
self.text_sampling_func = text_sampling_func
def __getitem__(
self, cuts: CutSet
) -> Dict[str, Union[torch.Tensor, List[str]]]:
"""
Return a new batch, with the batch size automatically determined using the constraints
of max_frames and max_cuts.
"""
validate_for_asr(cuts)
# Sort the cuts by duration so that the first one determines the batch time dimensions.
cuts = cuts.sort_by_duration(ascending=False)
# Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
# the supervision boundaries.
for tnfm in self.cut_transforms:
cuts = tnfm(cuts)
# Sort the cuts again after transforms
cuts = cuts.sort_by_duration(ascending=False)
# Get a tensor with batched feature matrices, shape (B, T, F)
# Collation performs auto-padding, if necessary.
input_tpl = self.input_strategy(cuts)
if len(input_tpl) == 3:
# An input strategy with fault tolerant audio reading mode.
# "cuts" may be a subset of the original "cuts" variable,
# that only has cuts for which we successfully read the audio.
inputs, _, cuts = input_tpl
else:
inputs, _ = input_tpl
# Get a dict of tensors that encode the positional information about supervisions
# in the batch of feature matrices. The tensors are named "sequence_idx",
# "start_frame/sample" and "num_frames/samples".
supervision_intervals = self.input_strategy.supervision_intervals(cuts)
# Apply all available transforms on the inputs, i.e. either audio or features.
# This could be feature extraction, global MVN, SpecAugment, etc.
segments = torch.stack(list(supervision_intervals.values()), dim=1)
for tnfm in self.input_transforms:
inputs = tnfm(inputs, supervision_segments=segments)
batch = {
"inputs": inputs,
"supervisions": default_collate(
[
self.text_sampling_func(
texts=supervision.texts,
pre_texts=supervision.pre_texts,
)
if self.text_sampling_func is not None
else {
"text": train_text_normalization(supervision.texts[0]),
"pre_text": train_text_normalization(
supervision.pre_texts[0]
),
"style_text": train_text_normalization(
supervision.pre_texts[0]
),
"transform_ids": 0,
}
for sequence_idx, cut in enumerate(cuts)
for supervision in cut.supervisions
]
),
}
# Update the 'supervisions' field with sequence_idx and start/num frames/samples
batch["supervisions"].update(supervision_intervals)
if self.return_cuts:
batch["supervisions"]["cut"] = [
cut for cut in cuts for sup in cut.supervisions
]
return batch
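# Hedged usage sketch (not from the recipe; "cuts_train" and the sampler
# settings below are placeholders). Lhotse samplers yield CutSet batches, so the
# DataLoader is created with batch_size=None and this dataset collates each
# sampled CutSet into a batch:
#   from lhotse.dataset.sampling import DynamicBucketingSampler
#   dataset = PromptASRDataset(text_sampling_func=triplet_text_sampling)
#   sampler = DynamicBucketingSampler(cuts_train, max_duration=200.0, shuffle=True)
#   dl = DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=2)
#   batch = next(iter(dl))  # batch["inputs"]: (B, T, F), batch["supervisions"]: dict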
def validate_for_asr(cuts: CutSet) -> None:
validate(cuts)
tol = 2e-3  # 2ms
for cut in cuts:
for supervision in cut.supervisions:
assert supervision.start >= -tol, (
f"Supervisions starting before the cut are not supported for ASR"
f" (sup id: {supervision.id}, cut id: {cut.id})"
)
# Supervision start time is relative to Cut ...
# https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
#
# 'supervision.end' is end of supervision inside the Cut
assert supervision.end <= cut.duration + tol, (
f"Supervisions ending after the cut "
f"are not supported for ASR"
f" (sup id: {supervision.id}, cut id: {cut.id})"
)
def get_substring(s: str, min_len: int = 40, max_len: int = 250) -> str:
"""A helper function that generates a random substring from a given string
Args:
s (str): Input string
Returns:
str: Returned substring
"""
min_len = min(len(s), min_len)
start = random.randint(0, len(s) - min_len)
end = min(start + max_len, random.randint(start + min_len, len(s)))
return s[start:end]
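# Illustrative example (the result is random; the value shown is hypothetical):
#   get_substring("abcdefghijklmnopqrstuvwxyz", min_len=5, max_len=10)
#   -> "ghijklmno"  # a slice whose length lies between min_len and max_len
# If the input is shorter than min_len, the whole string may be returned.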
def triplet_text_sampling(
texts: List[str],
pre_texts: List[str],
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
"""This function generates a triplet of
(pre_text, style_text, ref_text). The style of style_text and ref_text
should always match, whereas the style of pre_text is arbitrary.
Suppose we have 3 different transforms A,B,C, and the groundtruth
text and pre_text are referred to as text and pre_text.
The following three tuples are all valid:
(A(pre_text), B(style_text), B(text))
(A(pre_text), C(style_text), C(text))
(A(pre_text), A(style_text), A(text))
...
If transforms is not given, the following pre-defined transforms
are used by default:
0: original (mixed case, with punctuation)
1: upper_only_alpha (upper case, no punctuation)
2: lower_only_alpha (lower case, no punctuation)
3: lower_all_char (lower case, with punctuation)
When the transform of text and pre_text match, we can use the whole
pre_text as the prompt text.
Args:
texts (List[str]):
A list of ref_texts whose first item is the ground truth
text from books.
pre_texts (List[str]):
A list of pre_texts, whose first item is the groundtruth
pre_text from books.
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied.
min_len_style (int): Minimum length of the sampled style_text.
Returns:
Dict[str, str]: A dictionary with keys "text", "pre_text", "style_text"
and "transform_ids".
"""
assert len(texts) == len(pre_texts)
assert len(texts) == 2
# we assume the first item to be ground truth
gt_text = texts[0]
gt_pre_text = pre_texts[0]
if transforms is None:
transforms = [
lambda x: x,  # identity: keep the original text
upper_only_alpha,
lower_only_alpha,
lower_all_char,
]
sampling_weight = [0.7, 0.3, 0.0, 0.0]  # Mixed-punc should have the largest sampling prob
total_transforms = len(transforms)  # number of candidate transforms
# Select a transformation randomly
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
# get the normalized text and pre_text
text = transforms[i_text](gt_text)
pre_text = transforms[i_pre_text](gt_pre_text)
if i_text == i_pre_text:
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
else:
# Ideally style_text should share the style of text; for now we simply
# fall back to the original (un-transformed) pre_text.
style_text = gt_pre_text
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
return {
"text": train_text_normalization(text),
"pre_text": train_text_normalization(pre_text),
"style_text": train_text_normalization(style_text),
"transform_ids": i_text,
}
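# Hedged example of one possible outcome (the transform indices are drawn at
# random; here both happen to be 0, the most likely case under the default
# sampling weights, so style_text is simply a substring of pre_text):
#   triplet_text_sampling(
#       texts=["Hello, world! How are you?", "HELLO WORLD HOW ARE YOU"],
#       pre_texts=["Nice to meet you. I hope you are doing well.",
#                  "NICE TO MEET YOU I HOPE YOU ARE DOING WELL"],
#   )
#   -> {"text": "Hello, world! How are you?",
#       "pre_text": "Nice to meet you. I hope you are doing well.",
#       "style_text": "Nice to meet you. I hope you are doing well.",
#       "transform_ids": 0}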
def multi_ref_text_triplet_text_sampling(
texts: List[str],
pre_texts: List[str],
context_list: Optional[str] = None,
rare_word_list: Optional[List[str]] = None,
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
"""This function generates a triplet of
(pre_text, style_text, ref_text). The style of style_text and ref_text
should always match, whereas the style of pre_text is arbitrary.
Suppose we have 3 different transforms A,B,C, and the groundtruth
text and pre_text are referred to as text and pre_text.
The following three tuples are all valid:
(A(pre_text), B(style_text), B(text))
(A(pre_text), C(style_text), C(text))
(A(pre_text), A(style_text), A(text))
...
If transforms is not given, the following pre-defined transforms
are used by default:
0: original (mixed case, with punctuation)
1: upper_only_alpha (upper case, no punctuation)
2: lower_only_alpha (lower case, no punctuation)
3: lower_all_char (lower case, with punctuation)
When the transform of text and pre_text match, we can use the whole
pre_text as the prompt text.
Args:
texts (List[str]):
A list of ref_texts whose first item is the ground truth
text and whose third item is a decoding result obtained with prompts.
pre_texts (List[str]):
A list of pre_texts, whose first item is the groundtruth
pre_text from books.
context_list (str): Not used in this function.
rare_word_list (List[str]): Not used in this function.
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied.
min_len_style (int): Minimum length of the sampled style_text.
Returns:
Dict[str, str]: A dictionary with keys "text", "pre_text", "style_text"
and "transform_ids".
"""
assert len(texts) == 3
# we assume the first item to be ground truth, the third item to be the
# decoding results with prompts
if random.random() < 0.5:
gt_text = texts[0]
else:
gt_text = texts[2] # decoding res with prompt
gt_pre_text = pre_texts[0]
if transforms is None:
transforms = [
lambda x: x,  # identity: keep the original text
upper_only_alpha,
lower_only_alpha,
lower_all_char,
]
sampling_weight = [0.5, 0.2, 0.15, 0.15]  # Mixed-punc should have the largest sampling prob
total_transforms = len(transforms)  # number of candidate transforms
# Select a transformation randomly
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
# get the normalized text and pre_text
text = transforms[i_text](gt_text)
pre_text = transforms[i_pre_text](gt_pre_text)
if i_text == i_pre_text:
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
else:
# Ideally style_text should share the style of text; for now we simply
# fall back to the original (un-transformed) pre_text.
style_text = gt_pre_text
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
return {
"text": train_text_normalization(text),
"pre_text": train_text_normalization(pre_text),
"style_text": train_text_normalization(style_text),
"transform_ids": i_text,
}
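# Hedged note (inferred from the assert and comments above): this variant
# expects exactly three candidate texts,
#   texts = [ground_truth_text, ..., decoding_result_with_prompt]
# and with probability 0.5 it uses the prompted decoding result (texts[2])
# instead of the ground truth as the training target.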
def joint_triplet_text_sampling(
texts: List[str],
pre_texts: List[str],
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
"""This function generates a triplet of
(pre_text, style_text, ref_text). The style of pre_text, style_text
and ref_text should **always** match.
Suppose we have 3 different transforms A,B,C, and the groundtruth
text and pre_text are referred to as text and pre_text.
The following three tuples are all valid:
(A(pre_text), A(style_text), A(text))
(B(pre_text), B(style_text), B(text))
(C(pre_text), C(style_text), C(text))
...
If transforms is not given, the following pre-defined transforms
are used by default:
0: original (mixed case, with punctuation)
1: upper_only_alpha (upper case, no punctuation)
2: lower_only_alpha (lower case, no punctuation)
3: lower_all_char (lower case, with punctuation)
Since the same transform is applied to text and pre_text, the whole
pre_text can be used as the prompt text.
Args:
texts (List[str]):
A list of ref_texts whose first item is the ground truth
text from books.
pre_texts (List[str]):
A list of pre_texts, whose first item is the groundtruth
pre_text from books.
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied.
min_len_style (int): Minimum length of the sampled style_text.
Returns:
Dict[str, str]: A dictionary with keys "text", "pre_text", "style_text"
and "transform_ids".
"""
assert len(texts) == len(pre_texts)
assert len(texts) == 2
# we assume the first item to be ground truth
gt_text = texts[0]
gt_pre_text = pre_texts[0]
if transforms is None:
transforms = [
lambda x: x,  # identity: keep the original text
upper_only_alpha,
lower_only_alpha,
lower_all_char,
]
sampling_weight = [0.5, 0.2, 0.15, 0.15]  # Mixed-punc should have the largest sampling prob
total_transforms = len(transforms)  # number of candidate transforms
# Select a transformation randomly
i_text = np.random.choice(total_transforms, 1, p=sampling_weight)[0]
# get the normalized text and pre_text
text = transforms[i_text](gt_text)
pre_text = transforms[i_text](gt_pre_text)
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
return {
"text": train_text_normalization(text),
"pre_text": train_text_normalization(pre_text),
"style_text": train_text_normalization(style_text),
"transform_ids": i_text,
}
def triplet_style_text_sampling(
texts: List[str],
pre_texts: List[str],
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
"""This function generates a triplet of
(pre_text, style_text, ref_text). The style of style_text and ref_text
should always match, whereas the style of pre_text is fixed to mixed-trans.
Suppose we have 3 different transforms A,B,C, and the groundtruth
text and pre_text are referred to as text and pre_text.
The following three tuples are all valid:
(gt_pre_text, B(style_text), B(text))
(gt_pre_text, C(style_text), C(text))
(gt_pre_text, A(style_text), A(text))
...
If transforms is not given, the following pre-defined transforms
are used by default:
0: original (mixed case, with punctuation)
1: upper_only_alpha (upper case, no punctuation)
2: lower_only_alpha (lower case, no punctuation)
3: lower_all_char (lower case, with punctuation)
Args:
texts (List[str]):
A list of ref_texts whose first item is the ground truth
text from books.
pre_texts (List[str]):
A list of pre_texts, whose first item is the groundtruth
pre_text from books.
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied.
min_len_style (int): Minimum length of the sampled style_text.
Returns:
Dict[str, str]: A dictionary with keys "text", "pre_text", "style_text"
and "transform_ids".
"""
assert len(texts) == len(pre_texts)
assert len(texts) == 2
# we assume the first item to be ground truth
gt_text = texts[0]
gt_pre_text = pre_texts[0]
if transforms is None:
transforms = [
lambda x: x,  # identity: keep the original text
upper_only_alpha,
lower_only_alpha,
lower_all_char,
]
sampling_weight = [0.5, 0.2, 0.15, 0.15]  # Mixed-punc should have the largest sampling prob
total_transforms = len(transforms)  # number of candidate transforms
# Select a transformation randomly
t_id = np.random.choice(total_transforms, 1, p=sampling_weight)[0]
# get the normalized text
text = transforms[t_id](gt_text)
# get the original un-processed style text
style_text = get_substring(gt_pre_text, min_len=min_len_style, max_len=150)
return {
"text": train_text_normalization(text),
"pre_text": train_text_normalization(gt_pre_text),
"style_text": train_text_normalization(style_text),
"transform_ids": t_id,
}
def naive_triplet_text_sampling(
texts: List[str],
pre_texts: List[str],
context_list: str = None,
rare_word_list: List[str] = None,
min_len_style: Optional[int] = 120,
):
"""Return a fixed (text, pre_text, style_text) triplet without any random
transform sampling: the ground-truth text and pre_text are used as-is and
style_text is a fixed, unrelated mixed-case sentence. The commented-out
values below are alternative prompts kept from earlier experiments;
context_list, rare_word_list and min_len_style are unused here.
"""
return {
"text": train_text_normalization(texts[0]),
"pre_text": train_text_normalization(pre_texts[0]),
#"pre_text": "",
#"pre_text": "HELLO IT THIS ENOUGH FOR THE MODEL TO LEARN THE STYLE",
#"pre_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
#"pre_text": "Hello, my friend. "*50,
#"style_text": train_text_normalization(pre_texts[0][:200]),
"style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related. What do you think?",
#"style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
#"style_text": "Mixed-case English transcription, with punctuation.",
#"style_text": train_text_normalization(get_substring(pre_texts[0], min_len=min_len_style)),
"transform_ids": 0,
}
def random_shuffle_subset(
data: List[str],
p: float = 0.2,
p_mask: float = 0.05,
) -> List[str]:
"""
Randomly shuffle the subset by probability p, which means that p% of the samples
in the original batch are shuffled, the others are kept in the original order.
With a probability of p_mask, replace the original string with an empty string.
"""
num_to_shuffle = int(len(data) * p)
id_to_shuffle = np.random.choice(len(data), num_to_shuffle, replace=False)
item_to_shuffle = [data[id] for id in id_to_shuffle]
random.shuffle(item_to_shuffle)
for id, item in zip(id_to_shuffle, item_to_shuffle):
data[id] = item
# Randomly mask a proportion of the data to empty string
if p_mask > 0:
for i in range(len(data)):
if random.random() < p_mask:
data[i] = ""
return data
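# Illustrative example (results are random; the output shown is hypothetical):
#   random_shuffle_subset(["a", "b", "c", "d", "e"], p=0.4, p_mask=0.2)
#   -> ["a", "d", "c", "b", ""]   # two positions shuffled, one item masked
# Note that the input list is modified in place.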
if __name__ == "__main__":
texts = [
"AA, BB, cC, dD!",
"AA BB CC DD",
]
pre_texts = [
"EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg?",
"EE FF GG EE FF GG EE FF GG EE FF GG EE FF GG",
]
# for i in range(10):
# print(f"Run: {i}")
# print(triplet_text_sampling(texts, pre_texts))
import time
start = time.time()
data = [str(i) for i in range(30)]
random.shuffle(data)
print(data)
for i in range(1):
shuffled = random_shuffle_subset(data=data, p=0.4, p_mask=0.1)
print(shuffled)
print(time.time() - start)