marcoyang1998 2023-09-08 09:55:41 +08:00
parent cad01bfcb6
commit d4c5a1c157
2 changed files with 176 additions and 5 deletions


@@ -488,7 +488,7 @@ class LibriHeavyAsrDataModule:
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
cuts_valid = load_manifest_lazy(
- self.args.manifest_dir / "libriheavy_cuts_test-clean.jsonl.gz"
+ self.args.manifest_dir / "libriheavy_cuts_test-clean_official.jsonl.gz"
)
return cuts_valid
@@ -496,7 +496,7 @@ class LibriHeavyAsrDataModule:
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
cuts_valid = load_manifest_lazy(
- self.args.manifest_dir / "libriheavy_cuts_test-other.jsonl.gz"
+ self.args.manifest_dir / "libriheavy_cuts_test-other_official.jsonl.gz"
)
return cuts_valid
@@ -513,6 +513,20 @@ class LibriHeavyAsrDataModule:
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
)
@lru_cache()
def npr1_dev_cuts(self) -> CutSet:
logging.info("About to get npr1 dev cuts")
return load_manifest_lazy(
self.args.manifest_dir / "npr1_cuts_dev.jsonl.gz"
)
@lru_cache()
def npr1_test_cuts(self) -> CutSet:
logging.info("About to get npr1 test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "npr1_cuts_test.jsonl.gz"
)
@lru_cache()
def long_audio_cuts(self) -> CutSet:


@@ -188,6 +188,98 @@ def get_substring(s: str, min_len: int = 40, max_len: int = 250) -> str:
def triplet_text_sampling(
texts: List[str],
pre_texts: List[str],
context_list: Optional[str] = None,
rare_word_list: Optional[List[str]] = None,
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
"""This function generates a triplet of
(pre_text, style_text, ref_text). The style of style_text and ref_text
should always match, whereas the style of pre_text is arbitrary.
Suppose we have 3 different transforms A,B,C, and the groundtruth
text and pre_text are referred to as text and pre_text.
The following three tuples are all valid:
(A(pre_text), B(style_text), B(text))
(A(pre_text), C(style_text), C(text))
(A(pre_text), A(style_text), A(text))
...
If transforms is not given, the following pre-defined transforms
are available:
0: original (normal case, with punc)
1: recog (upper, no punc)
2: upper_only_alpha (upper, no punc)
3: lower_only_alpha (lower, no punc)
4: upper_all (upper, with punc)
5: lower_all (lower, with punc)
When the transform of text and pre_text match, we can use the whole
pre_text as the prompt text.
Args:
texts (List[str]):
A list of ref_texts whose first item is the ground truth
text from books.
pre_texts (List[str]):
A list of pre_texts, whose first item is the groundtruth
pre_text from books.
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
Returns:
str: A dictionary
"""
assert len(texts) == len(pre_texts)
assert len(texts) == 2
# we assume the first item to be ground truth
gt_text = texts[0]
gt_pre_text = pre_texts[0]
if transforms is None:
transforms = [
lambda x: x,  # identity: return the text itself
upper_only_alpha,
lower_only_alpha,
lower_all_char,
]
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
sampling_weight = [0.7, 0.3, 0.0, 0.0]
total_transforms = len(transforms)  # the recognized transcript is not used as a transform
# Select a transformation randomly
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
# get the normalized text and pre_text
text = transforms[i_text](gt_text)
pre_text = transforms[i_pre_text](gt_pre_text)
if i_text == i_pre_text:
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
else:
# get a pre_text whose style matches that of text
# For now, no transform is applied to the style text
style_text = gt_pre_text
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
return {
"text": train_text_normalization(text),
"pre_text": train_text_normalization(pre_text),
"style_text": train_text_normalization(style_text),
"transform_ids": i_text,
}
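# Illustrative sketch (not part of the recipe): a minimal example of how
# triplet_text_sampling above might be called. The two-element lists follow
# the assert above (ground-truth item first); the strings are made up.
#
#   sample = triplet_text_sampling(
#       texts=[
#           "Mixed-case text, with punctuation.",
#           "MIXED CASE TEXT WITH PUNCTUATION",
#       ],
#       pre_texts=[
#           "The preceding context, also mixed case.",
#           "THE PRECEDING CONTEXT ALSO MIXED CASE",
#       ],
#   )
#   # sample is a dict with keys "text", "pre_text", "style_text" and
#   # "transform_ids" (the index of the transform applied to the text).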
def triplet_text_sampling2(
texts: List[str],
pre_texts: List[str],
context_list: Optional[str] = None,
rare_word_list: Optional[List[str]] = None,
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
@@ -345,7 +437,7 @@ def triplet_text_sampling_with_context_list(
# get the normalized text and pre_text
text = transforms[i_text](gt_text)
- pre_text = get_pre_text_with_context_list(
+ pre_text = get_pre_text_with_context_list2(
text=gt_pre_text,
pre_text=gt_pre_text,
context_list=context_list,
@@ -402,7 +494,7 @@ def get_pre_text_with_context_list(
i = random.randint(1, min(len(splitted), 20))
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
num_distractors = random.randint(0,70)
- distractors = random.sample(rare_words, num_distractors)
+ distractors = random.sample(rare_words_list, num_distractors)
splitted += distractors
random.shuffle(splitted) # shuffle the list
pre_text = " ".join(splitted)
@@ -418,7 +510,7 @@ def get_pre_text_with_context_list(
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
pre_text = " ".join(splitted)
num_distractors = random.randint(0,70)
- distractors = random.sample(rare_words, num_distractors)
+ distractors = random.sample(rare_words_list, num_distractors)
splitted += distractors
random.shuffle(splitted) # shuffle the list
elif v < 0.2:
@@ -436,6 +528,71 @@ def get_pre_text_with_context_list(
def get_pre_text_with_context_list2(
text: str,
pre_text: str,
context_list: str,
rare_words_list: Optional[List[str]] = None,
) -> str:
# Always use the first item, i.e. the ground-truth (mixed-case) transcript, but processed with upper_only_alpha
# For a small proportion of the time, use a substring of ref_text as the pre_text
if context_list != "" and context_list is not None:
v = random.random()
if v < 0.4:
# correct + distractors
# sample distractors
num_distractors = random.randint(50, 100)
distractors = random.sample(rare_words_list, num_distractors)
# sample correct
correct = context_list.split()
i = random.randint(1, len(correct))
correct = random.sample(correct, i)
# combine correct and distractors
pre_text = distractors + correct
random.shuffle(pre_text)
pre_text = " ".join(pre_text)
elif v < 0.55:
splitted = text.split()
sampling_weights = [len(w)**1.2 for w in splitted]
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
i = random.randint(1, min(len(splitted), 20))
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
num_distractors = random.randint(50, 100)
distractors = random.sample(rare_words_list, num_distractors)
splitted += distractors
random.shuffle(splitted) # shuffle the list
pre_text = " ".join(splitted)
else:
pre_text = pre_text  # keep the original pre_text
else:
v = random.random()
if v < 0.3:
splitted = text.split()
sampling_weights = [len(w)**1.2 for w in splitted]
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
i = random.randint(1, min(len(splitted), 20))
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
pre_text = " ".join(splitted)
num_distractors = random.randint(50,100)
distractors = random.sample(rare_words_list, num_distractors)
splitted += distractors
random.shuffle(splitted) # shuffle the list
elif v < 0.4:
# full distractors
num_distractors = random.randint(5, 100)
distractors = random.sample(rare_words_list, num_distractors)
pre_text = " ".join(distractors)
elif v < 0.6:
pre_text = get_substring(text, min_len=15, max_len=150)
else:
pre_text = pre_text  # keep the original pre_text
return pre_text
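# Illustrative sketch (not part of the recipe): a minimal example of how
# get_pre_text_with_context_list2 above might be called. The strings are made up;
# rare_words_list should contain at least ~100 entries, since up to 100
# distractors are sampled from it without replacement.
#
#   rare_words = [...]  # a long list (>100 entries) of rare words, e.g. read from a file
#   prompt = get_pre_text_with_context_list2(
#       text="THE BRIGANTINE LAY AT ANCHOR IN THE BAY",
#       pre_text="EARLIER THAT MORNING THE CREW HAD GATHERED ON DECK",
#       context_list="BRIGANTINE ANCHOR",
#       rare_words_list=rare_words,
#   )
#   # Depending on the randomly chosen branch, prompt mixes words from the
#   # context list with distractors sampled from rare_words_list, uses a
#   # substring of text, or keeps the original pre_text.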
def joint_triplet_text_sampling(
texts: List[str],
pre_texts: List[str],