mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 00:24:19 +00:00
also sample from distractors when using separate words in the ref text; increase the max length of substring
This commit is contained in:
parent
8a238317a4
commit
f23882b9f6
@ -272,7 +272,6 @@ def triplet_text_sampling(
|
||||
}
|
||||
|
||||
|
||||
|
||||
def triplet_text_sampling_with_context_list(
|
||||
texts: List[str],
|
||||
pre_texts: List[str],
|
||||
@ -321,7 +320,8 @@ def triplet_text_sampling_with_context_list(
|
||||
assert len(texts) == len(pre_texts)
|
||||
assert len(texts) == 2
|
||||
|
||||
context_list = context_list.lower()
|
||||
if context_list is not None:
|
||||
context_list = context_list.lower()
|
||||
|
||||
# we assume the first item to be ground truth
|
||||
gt_text = texts[0]
|
||||
@ -354,6 +354,7 @@ def triplet_text_sampling_with_context_list(
|
||||
pre_text = transforms[i_pre_text](pre_text)
|
||||
|
||||
if i_text == i_pre_text:
|
||||
style_text = gt_pre_text
|
||||
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
|
||||
else:
|
||||
# get the pre_text of same style as text
|
||||
@ -379,7 +380,7 @@ def get_pre_text_with_context_list(
|
||||
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
|
||||
# By a small proportion of time, use the substring of ref_text as pre_text
|
||||
|
||||
if context_list != "":
|
||||
if context_list != "" and context_list is not None:
|
||||
v = random.random()
|
||||
if v < 0.5:
|
||||
# correct + distractors
|
||||
@ -400,6 +401,10 @@ def get_pre_text_with_context_list(
|
||||
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
|
||||
i = random.randint(1, min(len(splitted), 20))
|
||||
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
||||
num_distractors = random.randint(0,70)
|
||||
distractors = random.sample(rare_words, num_distractors)
|
||||
splitted += distractors
|
||||
random.shuffle(splitted) # shuffle the list
|
||||
pre_text = " ".join(splitted)
|
||||
else:
|
||||
pre_text = pre_text
|
||||
@ -412,6 +417,10 @@ def get_pre_text_with_context_list(
|
||||
i = random.randint(1, min(len(splitted), 20))
|
||||
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
||||
pre_text = " ".join(splitted)
|
||||
num_distractors = random.randint(0,70)
|
||||
distractors = random.sample(rare_words, num_distractors)
|
||||
splitted += distractors
|
||||
random.shuffle(splitted) # shuffle the list
|
||||
elif v < 0.2:
|
||||
# full distractors
|
||||
num_distractors = random.randint(5, 100)
|
||||
@ -419,7 +428,7 @@ def get_pre_text_with_context_list(
|
||||
pre_text = " ".join(distractors)
|
||||
|
||||
elif v < 0.3:
|
||||
pre_text = get_substring(text, min_len=15, max_len=100)
|
||||
pre_text = get_substring(text, min_len=15, max_len=150)
|
||||
else:
|
||||
pre_text = pre_text
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user