support using context list and random substring as pre text

This commit is contained in:
marcoyang1998 2023-08-16 16:44:29 +08:00
parent 17d0918969
commit 4420788f66
2 changed files with 164 additions and 1 deletions

View File

@ -269,6 +269,8 @@ def triplet_text_sampling(
def multi_ref_text_triplet_text_sampling(
texts: List[str],
pre_texts: List[str],
context_list: Optional[str] = None,
rare_word_list: Optional[List[str]] = None,
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
@ -326,7 +328,7 @@ def multi_ref_text_triplet_text_sampling(
lower_all_char,
]
sampling_weight = [0.6, 0.2, 0.1, 0.1] # Mixed-punc should have the largest sampling prob
sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
total_transforms = len(transforms) # do not use the recognized trans
@ -511,6 +513,8 @@ def triplet_style_text_sampling(
def naive_triplet_text_sampling(
texts: List[str],
pre_texts: List[str],
context_list: str = None,
rare_word_list: List[str] = None,
min_len_style: Optional[int] = 120,
):

View File

@ -34,6 +34,7 @@ class PromptASRDataset(torch.utils.data.Dataset):
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
input_strategy: BatchIO = PrecomputedFeatures(),
text_sampling_func: Optional[Callable[[List[str]], str]] = None,
rare_word_list: Optional[List[str]] = None
):
"""
Icefall ASR IterableDataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py
@ -60,6 +61,7 @@ class PromptASRDataset(torch.utils.data.Dataset):
# a text sampling function
self.text_sampling_func = text_sampling_func
self.rare_word_list = rare_word_list
def __getitem__(
self, cuts: CutSet
@ -111,6 +113,7 @@ class PromptASRDataset(torch.utils.data.Dataset):
texts=supervision.texts,
pre_texts=supervision.pre_texts,
context_list=supervision.context_list if "context_list" in supervision.custom else None,
rare_word_list=self.rare_word_list,
)
if self.text_sampling_func is not None
else {
@ -269,6 +272,161 @@ def triplet_text_sampling(
}
def triplet_text_sampling_with_context_list(
    texts: List[str],
    pre_texts: List[str],
    context_list: str,
    rare_word_list: List[str],
    transforms: Optional[List[Callable[[str], str]]] = None,
    min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
    """Sample a (pre_text, style_text, ref_text) triplet using a context list.

    The ground-truth text and pre_text are the first items of ``texts`` and
    ``pre_texts``. One transform index is drawn for the text and one for the
    pre_text. The pre_text is first rewritten by
    ``get_pre_text_with_context_list`` (which may replace it with context
    words, distractors or words sampled from the reference text) and then
    passed through its transform.

    If ``transforms`` is not given, the following pre-defined transforms
    are used:
        0: identity (text returned unchanged)
        1: upper_only_alpha
        2: lower_only_alpha
        3: lower_all_char
    They are sampled with probabilities [0.7, 0.3, 0.0, 0.0], i.e. only the
    first two are ever selected at the moment.

    NOTE(review): the sampling weights are hard-coded for four transforms;
    a custom ``transforms`` list of a different length makes
    ``np.random.choice`` raise ``ValueError``.

    Args:
        texts (List[str]):
            A list of two ref_texts whose first item is the ground truth
            text.
        pre_texts (List[str]):
            A list of two pre_texts whose first item is the ground truth
            pre_text.
        context_list (str):
            Whitespace-separated context words for this utterance. ``None``
            (passed by the dataset when the cut has no context list) is
            treated as an empty string. The list is lower-cased before use.
        rare_word_list (List[str]):
            Pool of words from which distractors are drawn.
        transforms (List[Callable[[str], str]]):
            A list of possible transforms to be applied.
        min_len_style (int):
            Forwarded to ``get_substring`` as ``min_len`` when cutting the
            style text.

    Returns:
        A dictionary with the keys ``text``, ``pre_text`` and ``style_text``
        (normalized strings) and ``transform_ids`` (the integer index of the
        transform applied to the text).
    """
    assert len(texts) == len(pre_texts)
    assert len(texts) == 2

    # The dataset passes None when the supervision carries no context list;
    # calling .lower() on it directly would raise AttributeError.
    context_list = "" if context_list is None else context_list.lower()

    # we assume the first item to be ground truth
    gt_text = texts[0]
    gt_pre_text = pre_texts[0]

    if transforms is None:
        transforms = [
            lambda x: x,  # return it self
            upper_only_alpha,
            lower_only_alpha,
            lower_all_char,
        ]

    sampling_weight = [0.7, 0.3, 0.0, 0.0]
    total_transforms = len(transforms)

    # Select the transforms for text and pre_text independently
    # (sampling with replacement, so both may be the same index).
    i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)

    # get the normalized text and pre_text
    text = transforms[i_text](gt_text)
    # ``text`` must be the reference text so that the helper can build a
    # prompt out of words of the reference; previously gt_pre_text was passed
    # for both arguments, which made the ``text`` argument meaningless.
    pre_text = get_pre_text_with_context_list(
        text=gt_text,
        pre_text=gt_pre_text,
        context_list=context_list,
        rare_words_list=rare_word_list,
    )
    pre_text = transforms[i_pre_text](pre_text)

    # When both styles match, the (transformed) pre_text can serve as the
    # style prompt. Otherwise fall back to the untransformed ground-truth
    # pre_text; for now no transform is applied to the style text.
    style_source = pre_text if i_text == i_pre_text else gt_pre_text
    style_text = get_substring(style_source, min_len=min_len_style, max_len=150)

    return {
        "text": train_text_normalization(text),
        "pre_text": train_text_normalization(pre_text),
        "style_text": train_text_normalization(style_text),
        "transform_ids": i_text,
    }
def _sample_words_by_length(text: str, max_words: int = 20) -> Optional[str]:
    """Join randomly chosen words of ``text``, favouring longer words.

    Between 1 and ``min(number of words, max_words)`` words are drawn with
    replacement, each with probability proportional to ``len(word) ** 1.2``.
    Returns ``None`` when ``text`` contains no words.
    """
    words = text.split()
    if not words:
        return None
    weights = [len(w) ** 1.2 for w in words]
    total = sum(weights)  # hoisted: computed once instead of once per word
    probs = [w / total for w in weights]
    num_words = random.randint(1, min(len(words), max_words))
    return " ".join(np.random.choice(words, num_words, p=probs))


def get_pre_text_with_context_list(
    text: str,
    pre_text: str,
    context_list: str,
    rare_words_list: List[str] = None,
) -> str:
    """Randomly build the pre_text (prompt) for one training sample.

    When ``context_list`` contains at least one word:
        * 50%: a shuffled mix of 0-50 distractors sampled from
          ``rare_words_list`` and a non-empty random subset of the
          context words;
        * 20%: words sampled from ``text`` (longer words are more likely);
        * 30%: ``pre_text`` unchanged.

    When ``context_list`` is empty (``None`` and whitespace-only strings are
    treated as empty):
        * 10%: words sampled from ``text`` (longer words are more likely);
        * 10%: 5-100 distractors sampled from ``rare_words_list``;
        * 10%: a substring of ``text`` obtained from ``get_substring``;
        * 70%: ``pre_text`` unchanged.

    The number of distractors is capped at the size of ``rare_words_list``.
    Whenever a branch has nothing to sample from (no words in ``text``, no
    rare words available), ``pre_text`` is returned unchanged.

    Args:
        text: The text from which words / substrings may be sampled.
        pre_text: The original pre_text, used as the fallback.
        context_list: Whitespace-separated context words.
        rare_words_list: Pool of distractor words; ``None`` means no
            distractors are available.

    Returns:
        The sampled pre_text.
    """
    rare_words = rare_words_list if rare_words_list is not None else []
    context_words = context_list.split() if context_list else []

    v = random.random()

    if context_words:
        if v < 0.5:
            # correct + distractors
            # Cap the request so random.sample never asks for more items
            # than the population holds.
            num_distractors = min(random.randint(0, 50), len(rare_words))
            distractors = random.sample(rare_words, num_distractors)
            num_correct = random.randint(1, len(context_words))
            correct = random.sample(context_words, num_correct)
            mixed = distractors + correct
            random.shuffle(mixed)
            return " ".join(mixed)
        if v < 0.7:
            sampled = _sample_words_by_length(text)
            return pre_text if sampled is None else sampled
        return pre_text

    if v < 0.1:
        sampled = _sample_words_by_length(text)
        return pre_text if sampled is None else sampled
    if v < 0.2:
        # full distractors
        num_distractors = min(random.randint(5, 100), len(rare_words))
        if num_distractors == 0:
            return pre_text
        return " ".join(random.sample(rare_words, num_distractors))
    if v < 0.3:
        return get_substring(text, min_len=15, max_len=100)
    return pre_text
def joint_triplet_text_sampling(
texts: List[str],
pre_texts: List[str],
@ -426,6 +584,7 @@ def naive_triplet_text_sampling(
texts: List[str],
pre_texts: List[str],
context_list: str = None,
rare_word_list: List[str] = None,
min_len_style: Optional[int] = 120,
):