Mirror of https://github.com/k2-fsa/icefall.git

Commit d4c5a1c157 (parent cad01bfcb6): updates
@@ -488,7 +488,7 @@ class LibriHeavyAsrDataModule:
     def test_clean_cuts(self) -> CutSet:
         logging.info("About to get test-clean cuts")
         cuts_valid = load_manifest_lazy(
-            self.args.manifest_dir / "libriheavy_cuts_test-clean.jsonl.gz"
+            self.args.manifest_dir / "libriheavy_cuts_test-clean_official.jsonl.gz"
         )
         return cuts_valid
@@ -496,7 +496,7 @@ class LibriHeavyAsrDataModule:
     def test_other_cuts(self) -> CutSet:
         logging.info("About to get test-other cuts")
         cuts_valid = load_manifest_lazy(
-            self.args.manifest_dir / "libriheavy_cuts_test-other.jsonl.gz"
+            self.args.manifest_dir / "libriheavy_cuts_test-other_official.jsonl.gz"
         )
         return cuts_valid
@@ -513,6 +513,20 @@ class LibriHeavyAsrDataModule:
         return load_manifest_lazy(
             self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
         )

+    @lru_cache()
+    def npr1_dev_cuts(self) -> CutSet:
+        logging.info("About to get npr1 dev cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "npr1_cuts_dev.jsonl.gz"
+        )
+
+    @lru_cache()
+    def npr1_test_cuts(self) -> CutSet:
+        logging.info("About to get npr1 test cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "npr1_cuts_test.jsonl.gz"
+        )
+
     @lru_cache()
     def long_audio_cuts(self) -> CutSet:
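The two npr1 cut getters added above mirror the existing LibriSpeech test-set getters. A rough usage sketch (not part of this commit; the args object and the test_dataloaders() helper are assumptions based on how other icefall data modules are consumed in decode scripts):

    # Hypothetical decode-script snippet; names below are assumptions, not from this diff.
    libriheavy = LibriHeavyAsrDataModule(args)
    test_sets = {
        "npr1-dev": libriheavy.npr1_dev_cuts(),
        "npr1-test": libriheavy.npr1_test_cuts(),
    }
    for name, cuts in test_sets.items():
        dl = libriheavy.test_dataloaders(cuts)  # assumed helper wrapping a CutSet in a DataLoader
        # ... run decoding over dl for this test set ...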
@@ -188,6 +188,98 @@ def get_substring(s: str, min_len: int = 40, max_len: int = 250) -> str:
+def triplet_text_sampling(
+    texts: List[str],
+    pre_texts: List[str],
+    context_list: Optional[str] = None,
+    rare_word_list: Optional[List[str]] = None,
+    transforms: Optional[List[Callable[[str], str]]] = None,
+    min_len_style: Optional[int] = 80,
+) -> Dict[str, str]:
+    """This function generates a triplet of
+    (pre_text, style_text, ref_text). The styles of style_text and ref_text
+    must always match, whereas the style of pre_text is arbitrary.
+    Suppose we have three different transforms A, B and C, and refer to the
+    ground-truth text and pre_text simply as text and pre_text.
+    The following three tuples are all valid:
+
+    (A(pre_text), B(style_text), B(text))
+    (A(pre_text), C(style_text), C(text))
+    (A(pre_text), A(style_text), A(text))
+    ...
+
+    If transforms is not given, the following pre-defined transforms
+    are available:
+    0: original (normal case, with punc)
+    1: recog (upper, no punc)
+    2: upper_only_alpha (upper, no punc)
+    3: lower_only_alpha (lower, no punc)
+    4: upper_all (upper, with punc)
+    5: lower_all (lower, with punc)
+
+    When the transforms of text and pre_text match, we can use the whole
+    pre_text as the prompt text.
+
+    Args:
+        texts (List[str]):
+            A list of ref_texts whose first item is the ground-truth
+            text from books.
+        pre_texts (List[str]):
+            A list of pre_texts whose first item is the ground-truth
+            pre_text from books.
+        transforms (List[Callable[[str], str]]):
+            A list of possible transforms to be applied to the texts.
+
+    Returns:
+        Dict[str, str]: A dictionary containing the sampled "text",
+        "pre_text", "style_text" and the "transform_ids".
+    """
+    # import pdb; pdb.set_trace()
+    assert len(texts) == len(pre_texts)
+    assert len(texts) == 2
+
+    # we assume the first item to be the ground truth
+    gt_text = texts[0]
+    gt_pre_text = pre_texts[0]
+
+    if transforms is None:
+        transforms = [
+            lambda x: x, # identity: return the text itself
+            upper_only_alpha,
+            lower_only_alpha,
+            lower_all_char,
+        ]
+
+    # sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
+    sampling_weight = [0.7, 0.3, 0.0, 0.0]
+
+    total_transforms = len(transforms) # do not use the recognized trans
+
+    # Randomly select a transform for the text and one for the pre_text
+    i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
+
+    # get the normalized text and pre_text
+    text = transforms[i_text](gt_text)
+    pre_text = transforms[i_pre_text](gt_pre_text)
+
+    if i_text == i_pre_text:
+        style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
+    else:
+        # get a pre_text of the same style as the text
+        # For now, do not apply any transform to the style text
+        style_text = gt_pre_text
+        # style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
+        style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
+
+    return {
+        "text": train_text_normalization(text),
+        "pre_text": train_text_normalization(pre_text),
+        "style_text": train_text_normalization(style_text),
+        "transform_ids": i_text,
+    }
+
+
 def triplet_text_sampling2(
     texts: List[str],
     pre_texts: List[str],
     context_list: Optional[str] = None,
     rare_word_list: Optional[List[str]] = None,
     transforms: Optional[List[Callable[[str], str]]] = None,
     min_len_style: Optional[int] = 80,
 ) -> Dict[str, str]:
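For orientation, a minimal sketch of how triplet_text_sampling might be invoked (illustrative inputs only; in training, texts and pre_texts come from the manifest and are typically much longer than min_len_style):

    sample = triplet_text_sampling(
        texts=[
            "He went to the market, and bought some apples.",  # ground-truth text (mixed case, punctuated)
            "HE WENT TO THE MARKET AND BOUGHT SOME APPLES",    # recognized transcript
        ],
        pre_texts=[
            "It was a bright, cold morning in the village.",   # ground-truth pre_text
            "IT WAS A BRIGHT COLD MORNING IN THE VILLAGE",      # recognized pre_text
        ],
    )
    # sample is a dict with keys "text", "pre_text", "style_text" and "transform_ids"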
@@ -345,7 +437,7 @@ def triplet_text_sampling_with_context_list(

     # get the normalized text and pre_text
     text = transforms[i_text](gt_text)
-    pre_text = get_pre_text_with_context_list(
+    pre_text = get_pre_text_with_context_list2(
         text=gt_pre_text,
         pre_text=gt_pre_text,
         context_list=context_list,
@@ -402,7 +494,7 @@ def get_pre_text_with_context_list(
             i = random.randint(1, min(len(splitted), 20))
             splitted = list(np.random.choice(splitted, i, p=sampling_weights))
             num_distractors = random.randint(0,70)
-            distractors = random.sample(rare_words, num_distractors)
+            distractors = random.sample(rare_words_list, num_distractors)
             splitted += distractors
             random.shuffle(splitted) # shuffle the list
             pre_text = " ".join(splitted)
@@ -418,7 +510,7 @@ def get_pre_text_with_context_list(
             splitted = list(np.random.choice(splitted, i, p=sampling_weights))
             pre_text = " ".join(splitted)
             num_distractors = random.randint(0,70)
-            distractors = random.sample(rare_words, num_distractors)
+            distractors = random.sample(rare_words_list, num_distractors)
             splitted += distractors
             random.shuffle(splitted) # shuffle the list
         elif v < 0.2:
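Both hunks above follow the same recipe: keep a few length-weighted words from the reference text as "correct" context and pad them with distractors drawn from the rare-word list. A self-contained sketch of that sampling step, assuming rare_words_list is a plain Python list of words (the helper name is made up for illustration):

    import random
    import numpy as np

    def sample_weighted_context(text: str, rare_words_list: list, max_keep: int = 20) -> str:
        words = text.split()
        # Longer words are more likely to be content/rare words, hence the len(w) ** 1.2 weighting.
        weights = np.array([len(w) ** 1.2 for w in words])
        weights = weights / weights.sum()
        k = random.randint(1, min(len(words), max_keep))
        kept = list(np.random.choice(words, k, p=weights))
        # Mix in distractors sampled from the rare-word list, then shuffle.
        num_distractors = random.randint(0, min(70, len(rare_words_list)))
        mixed = kept + random.sample(rare_words_list, num_distractors)
        random.shuffle(mixed)
        return " ".join(mixed)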
@@ -436,6 +528,71 @@ def get_pre_text_with_context_list(

+def get_pre_text_with_context_list2(
+    text: str,
+    pre_text: str,
+    context_list: str,
+    rare_words_list: List[str] = None,
+) -> str:
+    # Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
+    # For a small proportion of the time, use a substring of ref_text as pre_text
+
+    if context_list != "" and context_list is not None:
+        v = random.random()
+        if v < 0.4:
+            # correct + distractors
+            # sample distractors
+            num_distractors = random.randint(50, 100)
+            distractors = random.sample(rare_words_list, num_distractors)
+            # sample correct
+            correct = context_list.split()
+            i = random.randint(1, len(correct))
+            correct = random.sample(correct, i)
+            # combine correct and distractors
+            pre_text = distractors + correct
+            random.shuffle(pre_text)
+            pre_text = " ".join(pre_text)
+        elif v < 0.55:
+            splitted = text.split()
+            sampling_weights = [len(w)**1.2 for w in splitted]
+            sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
+            i = random.randint(1, min(len(splitted), 20))
+            splitted = list(np.random.choice(splitted, i, p=sampling_weights))
+            num_distractors = random.randint(50,100)
+            distractors = random.sample(rare_words_list, num_distractors)
+            splitted += distractors
+            random.shuffle(splitted) # shuffle the list
+            pre_text = " ".join(splitted)
+        else:
+            pre_text = pre_text
+    else:
+        v = random.random()
+        if v < 0.3:
+            splitted = text.split()
+            sampling_weights = [len(w)**1.2 for w in splitted]
+            sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
+            i = random.randint(1, min(len(splitted), 20))
+            splitted = list(np.random.choice(splitted, i, p=sampling_weights))
+            pre_text = " ".join(splitted)
+            num_distractors = random.randint(50,100)
+            distractors = random.sample(rare_words_list, num_distractors)
+            splitted += distractors
+            random.shuffle(splitted) # shuffle the list
+        elif v < 0.4:
+            # full distractors
+            num_distractors = random.randint(5, 100)
+            distractors = random.sample(rare_words_list, num_distractors)
+            pre_text = " ".join(distractors)
+
+        elif v < 0.6:
+            pre_text = get_substring(text, min_len=15, max_len=150)
+        else:
+            pre_text = pre_text
+
+    return pre_text
+
+
 def joint_triplet_text_sampling(
     texts: List[str],
     pre_texts: List[str],
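Finally, an illustrative call of the new get_pre_text_with_context_list2 helper (the word lists below are made up; in the recipe, context_list would come from the biasing-list preparation and rare_words_list from a rare-word file):

    pre_text = get_pre_text_with_context_list2(
        text="THE QUICK BROWN FOX JUMPS OVER THE LAZY DOG",
        pre_text="A SENTENCE THAT PRECEDES THE CURRENT UTTERANCE",
        context_list="FOX DOG",
        rare_words_list=["ZEBRA", "QUOKKA", "AXOLOTL", "NARWHAL"] * 30,  # needs >= 100 entries for the first branch
    )
    # With probability ~0.4 the result is a shuffled mix of context words and distractors;
    # otherwise it is words sampled from text, or the original pre_text left unchanged.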