mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-08 00:24:19 +00:00
support using context list and random substring as pre text
This commit is contained in:
parent 17d0918969
commit 4420788f66
@@ -269,6 +269,8 @@ def triplet_text_sampling(
 def multi_ref_text_triplet_text_sampling(
     texts: List[str],
     pre_texts: List[str],
+    context_list: Optional[str] = None,
+    rare_word_list: Optional[List[str]] = None,
     transforms: Optional[List[Callable[[str], str]]] = None,
     min_len_style: Optional[int] = 80,
 ) -> Dict[str, str]:
@@ -326,7 +328,7 @@ def multi_ref_text_triplet_text_sampling(
             lower_all_char,
         ]

-    sampling_weight = [0.6, 0.2, 0.1, 0.1]  # Mixed-punc should have the largest sampling prob
+    sampling_weight = [0.5, 0.2, 0.15, 0.15]  # Mixed-punc should have the largest sampling prob

     total_transforms = len(transforms)  # do not use the recognized trans

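For reference, here is a minimal standalone sketch (not part of the diff; it only assumes numpy) of how a weight vector like the one above is consumed later in this function, where np.random.choice draws the transform indices for the reference text and the pre_text. Index 0, the mixed-case text with punctuation, keeps the largest probability:

import numpy as np

sampling_weight = [0.5, 0.2, 0.15, 0.15]  # index 0 = mixed case with punctuation
total_transforms = 4                      # number of available transforms

# Draw the transform index for the ref text and for the pre_text independently
# (sampling with replacement), so the two may end up in different styles.
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
print(i_text, i_pre_text)  # e.g. "0 2"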
@@ -511,6 +513,8 @@ def triplet_style_text_sampling(
 def naive_triplet_text_sampling(
     texts: List[str],
     pre_texts: List[str],
+    context_list: str = None,
+    rare_word_list: List[str] = None,
     min_len_style: Optional[int] = 120,
 ):

@@ -34,6 +34,7 @@ class PromptASRDataset(torch.utils.data.Dataset):
         input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
         input_strategy: BatchIO = PrecomputedFeatures(),
         text_sampling_func: Optional[Callable[[List[str]], str]] = None,
+        rare_word_list: Optional[List[str]] = None
     ):
         """
         Icefall ASR IterableDataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py
@@ -60,6 +61,7 @@ class PromptASRDataset(torch.utils.data.Dataset):

         # a text sampling function
         self.text_sampling_func = text_sampling_func
+        self.rare_word_list = rare_word_list

     def __getitem__(
         self, cuts: CutSet
@@ -111,6 +113,7 @@ class PromptASRDataset(torch.utils.data.Dataset):
                 texts=supervision.texts,
                 pre_texts=supervision.pre_texts,
                 context_list=supervision.context_list if "context_list" in supervision.custom else None,
+                rare_word_list=self.rare_word_list,
             )
             if self.text_sampling_func is not None
             else {
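Taken together, the dataset changes above mean that the sampling function now receives the per-cut context_list (when the cut carries one) plus a shared rare_word_list. A rough wiring sketch, with hypothetical variable names (rare_words, dataset) that are not part of this commit:

from lhotse.dataset import PrecomputedFeatures

rare_words = ["obsequious", "penumbra", "quixotic"]  # hypothetical distractor pool

dataset = PromptASRDataset(
    input_strategy=PrecomputedFeatures(),
    text_sampling_func=triplet_text_sampling_with_context_list,
    rare_word_list=rare_words,
)
# For each supervision, __getitem__ now calls
# text_sampling_func(texts=..., pre_texts=..., context_list=..., rare_word_list=rare_words).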
@@ -269,6 +272,161 @@ def triplet_text_sampling(
     }


+def triplet_text_sampling_with_context_list(
+    texts: List[str],
+    pre_texts: List[str],
+    context_list: str,
+    rare_word_list: List[str],
+    transforms: Optional[List[Callable[[str], str]]] = None,
+    min_len_style: Optional[int] = 80,
+) -> Dict[str, str]:
"""This function generates a triplet of
|
||||
(pre_text, style_text, ref_text). The style of style_text and ref_text
|
||||
should always match, whereas the style of pre_text is arbitrary.
|
||||
Suppose we have 3 different transforms A,B,C, and the groundtruth
|
||||
text and pre_text are referred to as text and pre_text.
|
||||
The following three tuples are all valid:
|
||||
|
||||
(A(pre_text), B(style_text), B(text))
|
||||
(A(pre_text), C(style_text), C(text))
|
||||
(A(pre_text), A(style_text), A(text))
|
||||
...
|
||||
|
||||
If transforms is not given, the following pre-defined transforms
|
||||
are available:
|
||||
0: original (normal case, with punc)
|
||||
1: recog (upper, no punc)
|
||||
2: upper_only_alpha (upper, no punc)
|
||||
3: lower_only_alpha (lower, no punc)
|
||||
4: upper_all (upper, with punc)
|
||||
5: lower_all (lower, with punc)
|
||||
|
||||
When the transform of text and pre_text match, we can use the whole
|
||||
pre_text as the prompt text.
|
||||
|
||||
Args:
|
||||
texts (List[str]):
|
||||
A list of ref_texts whose first item is the ground truth
|
||||
text from books.
|
||||
pre_texts (List[str]):
|
||||
A list of pre_texts, whose first item is the groundtruth
|
||||
pre_text from books.
|
||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
||||
|
||||
Returns:
|
||||
str: A dictionary
|
||||
"""
|
||||
+    # import pdb; pdb.set_trace()
+    assert len(texts) == len(pre_texts)
+    assert len(texts) == 2
+
+    context_list = context_list.lower()
+
+    # We assume the first item to be the ground truth.
+    gt_text = texts[0]
+    gt_pre_text = pre_texts[0]
+
+    if transforms is None:
+        transforms = [
+            lambda x: x,  # return itself
+            upper_only_alpha,
+            lower_only_alpha,
+            lower_all_char,
+        ]
+
+    # sampling_weight = [0.5, 0.2, 0.15, 0.15]  # Mixed-punc should have the largest sampling prob
+    sampling_weight = [0.7, 0.3, 0.0, 0.0]
+
+    total_transforms = len(transforms)  # do not use the recognized trans
+
+    # Select a transform randomly
+    i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
+
+    # Get the normalized text and pre_text
+    text = transforms[i_text](gt_text)
+    pre_text = get_pre_text_with_context_list(
+        text=gt_pre_text,
+        pre_text=gt_pre_text,
+        context_list=context_list,
+        rare_words_list=rare_word_list,
+    )
+    pre_text = transforms[i_pre_text](pre_text)
+
+    if i_text == i_pre_text:
+        style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
+    else:
+        # Get a pre_text of the same style as text.
+        # For now, do not apply a transform to the style text.
+        style_text = gt_pre_text
+        # style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text - 2](gt_pre_text)
+        style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
+
+    return {
+        "text": train_text_normalization(text),
+        "pre_text": train_text_normalization(pre_text),
+        "style_text": train_text_normalization(style_text),
+        "transform_ids": i_text,
+    }
+
+
+def get_pre_text_with_context_list(
+    text: str,
+    pre_text: str,
+    context_list: str,
+    rare_words_list: List[str] = None,
+) -> str:
+    # Always use the first item, which is the ground truth (mixed-cased transcript),
+    # but with upper_only_alpha applied.
+    # A small proportion of the time, use a substring of ref_text as pre_text.
+
+    if context_list != "":
+        v = random.random()
+        if v < 0.5:
+            # correct + distractors
+            # sample distractors
+            num_distractors = random.randint(0, 50)
+            distractors = random.sample(rare_words_list, num_distractors)
+            # sample correct
+            correct = context_list.split()
+            i = random.randint(1, len(correct))
+            correct = random.sample(correct, i)
+            # combine correct and distractors
+            pre_text = distractors + correct
+            random.shuffle(pre_text)
+            pre_text = " ".join(pre_text)
+        elif v < 0.7:
+            splitted = text.split()
+            sampling_weights = [len(w) ** 1.2 for w in splitted]
+            sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
+            i = random.randint(1, min(len(splitted), 20))
+            splitted = list(np.random.choice(splitted, i, p=sampling_weights))
+            pre_text = " ".join(splitted)
+        else:
+            pre_text = pre_text
+    else:
+        v = random.random()
+        if v < 0.1:
+            splitted = text.split()
+            sampling_weights = [len(w) ** 1.2 for w in splitted]
+            sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
+            i = random.randint(1, min(len(splitted), 20))
+            splitted = list(np.random.choice(splitted, i, p=sampling_weights))
+            pre_text = " ".join(splitted)
+        elif v < 0.2:
+            # full distractors
+            num_distractors = random.randint(5, 100)
+            distractors = random.sample(rare_words_list, num_distractors)
+            pre_text = " ".join(distractors)
+
+        elif v < 0.3:
+            pre_text = get_substring(text, min_len=15, max_len=100)
+        else:
+            pre_text = pre_text
+
+    return pre_text
+
+
 def joint_triplet_text_sampling(
     texts: List[str],
     pre_texts: List[str],
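As a rough illustration of the new triplet_text_sampling_with_context_list above: the example strings below are invented and the commented output is only one possible draw, since the transform indices are sampled at random (with the weights [0.7, 0.3, 0.0, 0.0], index 0 is the identity and index 1 is upper_only_alpha):

# Hypothetical call; all strings are made up for illustration.
out = triplet_text_sampling_with_context_list(
    texts=["Hello, world!", "HELLO WORLD"],                  # ground truth first
    pre_texts=["He said, 'Good morning.'", "HE SAID GOOD MORNING"],
    context_list="zephyr quixote",                           # rare words tied to this cut
    rare_word_list=["obsequious", "penumbra", "quixote"],
)
# One possible result, if i_text = 1 (upper_only_alpha) and i_pre_text = 0 are drawn:
# {
#     "text": "HELLO WORLD",
#     "pre_text": "penumbra zephyr quixote",   # context words mixed with sampled distractors
#     "style_text": "He said, 'Good morning.'",
#     "transform_ids": 1,
# }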
@@ -426,6 +584,7 @@ def naive_triplet_text_sampling(
     texts: List[str],
     pre_texts: List[str],
     context_list: str = None,
+    rare_word_list: List[str] = None,
     min_len_style: Optional[int] = 120,
 ):
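Finally, a small standalone sketch (only assuming numpy and the standard random module; the sentence is invented) of the length-weighted word sampling that get_pre_text_with_context_list uses when it builds a random-substring style pre_text:

import random

import numpy as np

text = "the migratory osprey circled the abandoned lighthouse twice"  # made-up example
words = text.split()

# Longer words are more likely to be kept, mirroring the len(w) ** 1.2 weighting above.
weights = [len(w) ** 1.2 for w in words]
weights = [w / sum(weights) for w in weights]

k = random.randint(1, min(len(words), 20))
sampled = list(np.random.choice(words, k, p=weights))  # sampling with replacement
pre_text = " ".join(sampled)
print(pre_text)  # e.g. "lighthouse abandoned migratory"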