From d4c5a1c157845d1a5079470d91181c3942a852bb Mon Sep 17 00:00:00 2001
From: marcoyang1998
Date: Fri, 8 Sep 2023 09:55:41 +0800
Subject: [PATCH] updates

---
 .../zipformer_prompt_asr/asr_datamodule.py |  18 +-
 .../ASR/zipformer_prompt_asr/dataset2.py   | 163 +++++++++++++++++-
 2 files changed, 176 insertions(+), 5 deletions(-)

diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py
index 0c082c801..80faba038 100644
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py
@@ -488,7 +488,7 @@ class LibriHeavyAsrDataModule:
     def test_clean_cuts(self) -> CutSet:
         logging.info("About to get test-clean cuts")
         cuts_valid = load_manifest_lazy(
-            self.args.manifest_dir / "libriheavy_cuts_test-clean.jsonl.gz"
+            self.args.manifest_dir / "libriheavy_cuts_test-clean_official.jsonl.gz"
         )
         return cuts_valid

@@ -496,7 +496,7 @@ class LibriHeavyAsrDataModule:
     def test_other_cuts(self) -> CutSet:
         logging.info("About to get test-other cuts")
         cuts_valid = load_manifest_lazy(
-            self.args.manifest_dir / "libriheavy_cuts_test-other.jsonl.gz"
+            self.args.manifest_dir / "libriheavy_cuts_test-other_official.jsonl.gz"
         )
         return cuts_valid

@@ -513,6 +513,20 @@ class LibriHeavyAsrDataModule:
         return load_manifest_lazy(
             self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
         )
+
+    @lru_cache()
+    def npr1_dev_cuts(self) -> CutSet:
+        logging.info("About to get npr1 dev cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "npr1_cuts_dev.jsonl.gz"
+        )
+
+    @lru_cache()
+    def npr1_test_cuts(self) -> CutSet:
+        logging.info("About to get npr1 test cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "npr1_cuts_test.jsonl.gz"
+        )

     @lru_cache()
     def long_audio_cuts(self) -> CutSet:
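Note on usage: the two npr1 accessors follow the same pattern as the existing test-cut helpers, so a decode script can consume them without further changes. A minimal sketch, assuming the usual icefall convention of a test_dataloaders() method on the datamodule (argument parsing elided; the names here are illustrative):

    # Hypothetical decode-script snippet; `args` comes from the recipe's
    # argument parser, and test_dataloaders() is the standard datamodule
    # helper assumed to exist in this recipe.
    libriheavy = LibriHeavyAsrDataModule(args)

    npr1_dev_cuts = libriheavy.npr1_dev_cuts()    # cached by @lru_cache()
    npr1_test_cuts = libriheavy.npr1_test_cuts()

    for name, cuts in (("npr1-dev", npr1_dev_cuts), ("npr1-test", npr1_test_cuts)):
        dl = libriheavy.test_dataloaders(cuts)
        # run recognition on this split, e.g. decode_dataset(dl, ...)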
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/dataset2.py b/egs/libriheavy/ASR/zipformer_prompt_asr/dataset2.py
index 0a51b88b4..1dce78a5e 100644
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/dataset2.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/dataset2.py
@@ -188,6 +188,98 @@ def get_substring(s: str, min_len: int = 40, max_len: int = 250) -> str:
 def triplet_text_sampling(
     texts: List[str],
     pre_texts: List[str],
+    context_list: Optional[str] = None,
+    rare_word_list: Optional[List[str]] = None,
+    transforms: Optional[List[Callable[[str], str]]] = None,
+    min_len_style: Optional[int] = 80,
+) -> Dict[str, str]:
+    """Generate a triplet of (pre_text, style_text, ref_text).
+
+    The styles of style_text and ref_text should always match, whereas
+    the style of pre_text is arbitrary. Suppose we have three different
+    transforms A, B and C, and the ground-truth text and pre_text are
+    referred to as text and pre_text. The following tuples are all valid:
+
+    (A(pre_text), B(style_text), B(text))
+    (A(pre_text), C(style_text), C(text))
+    (A(pre_text), A(style_text), A(text))
+    ...
+
+    If transforms is not given, the following pre-defined transforms
+    are available:
+    0: original (mixed case, with punctuation)
+    1: upper_only_alpha (upper case, no punctuation)
+    2: lower_only_alpha (lower case, no punctuation)
+    3: lower_all_char (lower case, with punctuation)
+
+    When the transforms of text and pre_text match, the whole pre_text
+    can be used as the prompt text.
+
+    Args:
+      texts (List[str]):
+        A list of ref_texts whose first item is the ground-truth
+        text from books.
+      pre_texts (List[str]):
+        A list of pre_texts whose first item is the ground-truth
+        pre_text from books.
+      transforms (List[Callable[[str], str]]):
+        A list of possible transforms to be applied.
+
+    Returns:
+      Dict[str, str]: A dictionary with the keys "text", "pre_text",
+      "style_text" and "transform_ids".
+    """
+    assert len(texts) == len(pre_texts)
+    assert len(texts) == 2
+
+    # We assume the first item to be the ground truth.
+    gt_text = texts[0]
+    gt_pre_text = pre_texts[0]
+
+    if transforms is None:
+        transforms = [
+            lambda x: x,  # return itself
+            upper_only_alpha,
+            lower_only_alpha,
+            lower_all_char,
+        ]
+
+    # sampling_weight = [0.5, 0.2, 0.15, 0.15]  # Mixed-punc should have the largest sampling prob
+    sampling_weight = [0.7, 0.3, 0.0, 0.0]
+
+    total_transforms = len(transforms)  # the recognized (ASR) transcript is not used
+
+    # Randomly select one transform for text and one for pre_text
+    i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
+
+    # Get the normalized text and pre_text
+    text = transforms[i_text](gt_text)
+    pre_text = transforms[i_pre_text](gt_pre_text)
+
+    if i_text == i_pre_text:
+        style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
+    else:
+        # Get a pre_text of the same style as text.
+        # For now, do not apply a transform to the style text.
+        style_text = gt_pre_text
+        # style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
+        style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
+
+    return {
+        "text": train_text_normalization(text),
+        "pre_text": train_text_normalization(pre_text),
+        "style_text": train_text_normalization(style_text),
+        "transform_ids": i_text,
+    }
+
+
+def triplet_text_sampling2(
+    texts: List[str],
+    pre_texts: List[str],
+    context_list: Optional[str] = None,
+    rare_word_list: Optional[List[str]] = None,
     transforms: Optional[List[Callable[[str], str]]] = None,
     min_len_style: Optional[int] = 80,
 ) -> Dict[str, str]:
@@ -345,7 +437,7 @@ def triplet_text_sampling_with_context_list(

     # get the normalized text and pre_text
     text = transforms[i_text](gt_text)
-    pre_text = get_pre_text_with_context_list(
+    pre_text = get_pre_text_with_context_list2(
         text=gt_pre_text,
         pre_text=gt_pre_text,
         context_list=context_list,
@@ -402,7 +494,7 @@ def get_pre_text_with_context_list(
             i = random.randint(1, min(len(splitted), 20))
             splitted = list(np.random.choice(splitted, i, p=sampling_weights))
             num_distractors = random.randint(0,70)
-            distractors = random.sample(rare_words, num_distractors)
+            distractors = random.sample(rare_words_list, num_distractors)
             splitted += distractors
             random.shuffle(splitted) # shuffle the list
             pre_text = " ".join(splitted)
@@ -418,7 +510,7 @@ def get_pre_text_with_context_list(
             splitted = list(np.random.choice(splitted, i, p=sampling_weights))
             pre_text = " ".join(splitted)
             num_distractors = random.randint(0,70)
-            distractors = random.sample(rare_words, num_distractors)
+            distractors = random.sample(rare_words_list, num_distractors)
             splitted += distractors
             random.shuffle(splitted) # shuffle the list
         elif v < 0.2:
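To make the contract of triplet_text_sampling concrete, here is a small usage sketch. The strings are invented, and the exact output depends on the random draw of (i_text, i_pre_text) and on train_text_normalization:

    texts = [
        "He went to the store, and bought milk.",  # ground-truth book text
        "HE WENT TO THE STORE AND BOUGHT MILK",    # recognized transcript
    ]
    pre_texts = [
        "Earlier that morning, John had left the house in a hurry.",
        "EARLIER THAT MORNING JOHN HAD LEFT THE HOUSE IN A HURRY",
    ]

    sample = triplet_text_sampling(texts, pre_texts)
    # One possible draw (i_text == 0, i_pre_text == 1) yields roughly:
    #   sample["text"]       -> "He went to the store, and bought milk."  (original style)
    #   sample["pre_text"]   -> "EARLIER THAT MORNING JOHN HAD LEFT THE HOUSE IN A HURRY"
    #   sample["style_text"] -> a substring of the ground-truth pre_text, same style as "text"
    #   sample["transform_ids"] -> 0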
@@ -436,6 +528,71 @@



+def get_pre_text_with_context_list2(
+    text: str,
+    pre_text: str,
+    context_list: str,
+    rare_words_list: List[str] = None,
+) -> str:
+    # The caller always passes the first item, i.e. the ground-truth
+    # (mixed-case) transcript processed with upper_only_alpha. A small
+    # proportion of the time, a substring of ref_text is used as pre_text.
+
+    if context_list != "" and context_list is not None:
+        v = random.random()
+        if v < 0.4:
+            # correct words + distractors
+            # sample distractors
+            num_distractors = random.randint(50, 100)
+            distractors = random.sample(rare_words_list, num_distractors)
+            # sample correct words from the context list
+            correct = context_list.split()
+            i = random.randint(1, len(correct))
+            correct = random.sample(correct, i)
+            # combine correct words and distractors
+            pre_text = distractors + correct
+            random.shuffle(pre_text)
+            pre_text = " ".join(pre_text)
+        elif v < 0.55:
+            splitted = text.split()
+            sampling_weights = [len(w) ** 1.2 for w in splitted]
+            total_weight = sum(sampling_weights)
+            sampling_weights = [p / total_weight for p in sampling_weights]
+            i = random.randint(1, min(len(splitted), 20))
+            splitted = list(np.random.choice(splitted, i, p=sampling_weights))
+            num_distractors = random.randint(50, 100)
+            distractors = random.sample(rare_words_list, num_distractors)
+            splitted += distractors
+            random.shuffle(splitted)  # shuffle the list
+            pre_text = " ".join(splitted)
+        else:
+            pre_text = pre_text
+    else:
+        v = random.random()
+        if v < 0.3:
+            splitted = text.split()
+            sampling_weights = [len(w) ** 1.2 for w in splitted]
+            total_weight = sum(sampling_weights)
+            sampling_weights = [p / total_weight for p in sampling_weights]
+            i = random.randint(1, min(len(splitted), 20))
+            splitted = list(np.random.choice(splitted, i, p=sampling_weights))
+            num_distractors = random.randint(50, 100)
+            distractors = random.sample(rare_words_list, num_distractors)
+            splitted += distractors
+            random.shuffle(splitted)  # shuffle the list
+            # join only after the distractors have been added and shuffled
+            pre_text = " ".join(splitted)
+        elif v < 0.4:
+            # distractors only
+            num_distractors = random.randint(5, 100)
+            distractors = random.sample(rare_words_list, num_distractors)
+            pre_text = " ".join(distractors)
+        elif v < 0.6:
+            pre_text = get_substring(text, min_len=15, max_len=150)
+        else:
+            pre_text = pre_text
+
+    return pre_text
+
+
 def joint_triplet_text_sampling(
     texts: List[str],
     pre_texts: List[str],
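The biased-sampling branches in get_pre_text_with_context_list2 share one idea: words are drawn from the reference text with probability proportional to len(w) ** 1.2, so longer (and typically rarer) words are favored as prompt content. A self-contained sketch of that weighting with a made-up sentence:

    import numpy as np

    words = "the quick brown fox jumps over the lazy dog".split()
    weights = [len(w) ** 1.2 for w in words]
    total = sum(weights)
    weights = [p / total for p in weights]

    # Draw 4 words (with replacement, as in the function above);
    # "quick", "brown", "jumps" are favored over "the".
    picked = list(np.random.choice(words, 4, p=weights))
    print(picked)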