also sample from distractors when using separate words in the ref text; increase the max length of substring

This commit is contained in:
marcoyang1998 2023-08-17 12:11:33 +08:00
parent 8a238317a4
commit f23882b9f6

View File

@ -272,7 +272,6 @@ def triplet_text_sampling(
}
def triplet_text_sampling_with_context_list(
texts: List[str],
pre_texts: List[str],
@ -321,7 +320,8 @@ def triplet_text_sampling_with_context_list(
assert len(texts) == len(pre_texts)
assert len(texts) == 2
context_list = context_list.lower()
if context_list is not None:
context_list = context_list.lower()
# we assume the first item to be ground truth
gt_text = texts[0]
@ -354,6 +354,7 @@ def triplet_text_sampling_with_context_list(
pre_text = transforms[i_pre_text](pre_text)
if i_text == i_pre_text:
style_text = gt_pre_text
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
else:
# get the pre_text of same style as text
@ -379,7 +380,7 @@ def get_pre_text_with_context_list(
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
# By a small proportion of time, use the substring of ref_text as pre_text
if context_list != "":
if context_list != "" and context_list is not None:
v = random.random()
if v < 0.5:
# correct + distractors
@ -400,6 +401,10 @@ def get_pre_text_with_context_list(
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
i = random.randint(1, min(len(splitted), 20))
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
num_distractors = random.randint(0,70)
distractors = random.sample(rare_words, num_distractors)
splitted += distractors
random.shuffle(splitted) # shuffle the list
pre_text = " ".join(splitted)
else:
pre_text = pre_text
@ -412,6 +417,10 @@ def get_pre_text_with_context_list(
i = random.randint(1, min(len(splitted), 20))
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
pre_text = " ".join(splitted)
num_distractors = random.randint(0,70)
distractors = random.sample(rare_words, num_distractors)
splitted += distractors
random.shuffle(splitted) # shuffle the list
elif v < 0.2:
# full distractors
num_distractors = random.randint(5, 100)
@ -419,7 +428,7 @@ def get_pre_text_with_context_list(
pre_text = " ".join(distractors)
elif v < 0.3:
pre_text = get_substring(text, min_len=15, max_len=100)
pre_text = get_substring(text, min_len=15, max_len=150)
else:
pre_text = pre_text