diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/dataset.py b/egs/libriheavy/ASR/zipformer_prompt_asr/dataset.py index 1dce78a5e..9602e9270 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/dataset.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/dataset.py @@ -1,23 +1,38 @@ -from typing import Callable, Dict, List, Optional, Union -import random -import numpy as np +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random +from typing import Callable, Dict, List, Optional, Union + +import numpy as np import torch from lhotse import validate from lhotse.cut import CutSet from lhotse.dataset import K2SpeechRecognitionDataset from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures from lhotse.utils import compute_num_frames, ifnone -from torch.utils.data.dataloader import DataLoader, default_collate - from text_normalization import ( - remove_non_alphabetic, - upper_only_alpha, - lower_only_alpha, - upper_all_char, lower_all_char, + lower_only_alpha, + remove_non_alphabetic, train_text_normalization, + upper_all_char, + upper_only_alpha, ) +from torch.utils.data.dataloader import DataLoader, default_collate class PromptASRDataset(torch.utils.data.Dataset): @@ -34,7 +49,7 @@ class PromptASRDataset(torch.utils.data.Dataset): input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None, input_strategy: BatchIO = PrecomputedFeatures(), text_sampling_func: Optional[Callable[[List[str]], str]] = None, - rare_word_list: Optional[List[str]] = None + rare_word_list: Optional[List[str]] = None, ): """ Icefall ASR IterableDataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py @@ -63,9 +78,7 @@ class PromptASRDataset(torch.utils.data.Dataset): self.text_sampling_func = text_sampling_func self.rare_word_list = rare_word_list - def __getitem__( - self, cuts: CutSet - ) -> Dict[str, Union[torch.Tensor, List[str]]]: + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: """ Return a new batch, with the batch size automatically determined using the constraints of max_frames and max_cuts. 
@@ -112,15 +125,15 @@ class PromptASRDataset(torch.utils.data.Dataset):
                     self.text_sampling_func(
                         texts=supervision.texts,
                         pre_texts=supervision.pre_texts,
-                        context_list=supervision.context_list if "context_list" in supervision.custom else None,
+                        context_list=supervision.context_list
+                        if "context_list" in supervision.custom
+                        else None,
                         rare_word_list=self.rare_word_list,
                     )
                     if self.text_sampling_func is not None
                     else {
                         "text": train_text_normalization(supervision.texts[0]),
-                        "pre_text": train_text_normalization(
-                            supervision.pre_texts[0]
-                        ),
+                        "pre_text": train_text_normalization(supervision.pre_texts[0]),
                         "style_text": train_text_normalization(
                             supervision.pre_texts[0]
                         ),
@@ -192,27 +205,21 @@ def triplet_text_sampling(
     rare_word_list: Optional[List[str]] = None,
     transforms: Optional[List[Callable[[str], str]]] = None,
     min_len_style: Optional[int] = 80,
 ) -> Dict[str, str]:
     """This function generates a triplet of
     (pre_text, style_text, ref_text). The style of style_text and ref_text
-    should always match, whereas the style of pre_text is arbitrary.
-    Suppose we have 3 different transforms A,B,C, and the groundtruth
-    text and pre_text are referred to as text and pre_text.
-    The following three tuples are all valid:
+    should **always** match, whereas the style of pre_text is arbitrary.
+    Suppose we have 2 different transforms A,B, and the preceding text is
+    referred to as pre_text. The following three tuples are all valid:

-    (A(pre_text), B(style_text), B(text))
-    (A(pre_text), C(style_text), C(text))
-    (A(pre_text), A(style_text), A(text))
-    ...
+    (A(pre_text), A(style_text), A(ref_text))
+    (A(pre_text), B(style_text), B(ref_text))
+    (B(pre_text), B(style_text), B(ref_text))

     If transforms is not given, the following pre-defined transforms
     are available:
-    0: original (normal case, with punc)
-    1: recog (upper, no punc)
-    2: upper_only_alpha (upper, no punc)
-    3: lower_only_alpha (lower, no punc)
-    4: upper_all (upper, with punc)
-    5: lower_all (lower, with punc)
+    0: original (mixed-case, with punc)
+    1: upper_only_alpha (upper-case, no punc)

     When the transform of text and pre_text match, we can use the whole
     pre_text as the prompt text.
@@ -224,12 +232,15 @@ def triplet_text_sampling(
         pre_texts (List[str]):
             A list of pre_texts, whose first item is the groundtruth
             pre_text from books.
+        context_list (Optional[str]):
+            A string of biasing words separated by spaces
+        rare_word_list (Optional[List[str]]):
+            A list of rare words (used as distractors)
         transforms (List[Callable[[str], str]]): A list of possible transforms to be applied

     Returns:
-        str: A dictionary
+        A dictionary of ref_text, pre_text, style_text
     """
-    # import pdb; pdb.set_trace()
     assert len(texts) == len(pre_texts)
     assert len(texts) == 2
@@ -244,13 +255,17 @@ def triplet_text_sampling(
             lower_only_alpha,
             lower_all_char,
         ]
-
-    # sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
-    sampling_weight = [0.7, 0.3, 0.0, 0.0]
+
+    sampling_weight = [
+        0.7,
+        0.3,
+        0.0,
+        0.0,
+    ]  # Mixed-punc should have the largest sampling prob

     total_transforms = len(transforms)  # do not use the recognized trans

-    # Select a transformation randomly
+    # Randomly sample two transforms
     i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)

     # get the normalized text and pre_text
@@ -261,97 +276,7 @@ def triplet_text_sampling(
     text = transforms[i_text](gt_text)
     pre_text = transforms[i_pre_text](gt_pre_text)

     if i_text == i_pre_text:
         style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
     else:
         # get the pre_text of same style as text
-        # For now, do not do transform to the style text
-        style_text = gt_pre_text
-        # style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
-        style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
-
-    return {
-        "text": train_text_normalization(text),
-        "pre_text": train_text_normalization(pre_text),
-        "style_text": train_text_normalization(style_text),
-        "transform_ids": i_text,
-    }
-
-
-
-def triplet_text_sampling2(
-    texts: List[str],
-    pre_texts: List[str],
-    context_list: Optional[str] = None,
-    rare_word_list: Optional[List[str]] = None,
-    transforms: Optional[List[Callable[[str], str]]] = None,
-    min_len_style: Optional[int] = 80,
-) -> Dict[str, str]:
-    """This function generates a triplet of
-    (pre_text, style_text, ref_text). The style of style_text and ref_text
-    should always match, whereas the style of pre_text is arbitrary.
-    Suppose we have 3 different transforms A,B,C, and the groundtruth
-    text and pre_text are referred to as text and pre_text.
-    The following three tuples are all valid:
-
-    (A(pre_text), B(style_text), B(text))
-    (A(pre_text), C(style_text), C(text))
-    (A(pre_text), A(style_text), A(text))
-    ...
-
-    If transforms is not given, the following pre-defined transforms
-    are available:
-    0: original (normal case, with punc)
-    1: recog (upper, no punc)
-    2: upper_only_alpha (upper, no punc)
-    3: lower_only_alpha (lower, no punc)
-    4: upper_all (upper, with punc)
-    5: lower_all (lower, with punc)
-
-    When the transform of text and pre_text match, we can use the whole
-    pre_text as the prompt text.
-
-    Args:
-        texts (List[str]):
-            A list of ref_texts whose first item is the ground truth
-            text from books.
-        pre_texts (List[str]):
-            A list of pre_texts, whose first item is the groundtruth
-            pre_text from books.
-        transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
-
-    Returns:
-        str: A dictionary
-    """
-    # import pdb; pdb.set_trace()
-    assert len(texts) == len(pre_texts)
-    assert len(texts) == 2
-
-    # we assume the first item to be ground truth
-    gt_text = texts[0]
-    gt_pre_text = pre_texts[0]
-
-    if transforms is None:
-        transforms = [
-            lambda x: x,  # return it self
-            upper_only_alpha,
-            lower_only_alpha,
-            lower_all_char,
-        ]
-
-    # sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
-    sampling_weight = [0.7, 0.3, 0.0, 0.0]
-
-    total_transforms = len(transforms)  # do not use the recognized trans
-
-    # Select a transformation randomly
-    i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
-
-    # get the normalized text and pre_text
-    text = transforms[i_text](gt_text)
-    pre_text = transforms[i_pre_text](gt_pre_text)
-
-    if i_text == i_pre_text:
-        style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
-    else:
-        # get the pre_text of same style as text
-        # For now, do not do transform to the style text
+        # For now, **don't** transform the style text, because we do it after the dataloader
         style_text = gt_pre_text
         # style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
         style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
@@ -373,25 +298,21 @@ def triplet_text_sampling_with_context_list(
     min_len_style: Optional[int] = 80,
 ) -> Dict[str, str]:
     """This function generates a triplet of
-    (pre_text, style_text, ref_text). The style of style_text and ref_text
-    should always match, whereas the style of pre_text is arbitrary.
-    Suppose we have 3 different transforms A,B,C, and the groundtruth
-    text and pre_text are referred to as text and pre_text.
-    The following three tuples are all valid:
+    (pre_text, style_text, ref_text). The pre_text is either the preceding text
+    or a list of words (context words + distractors).
+    The style of style_text and ref_text should **always** match, whereas
+    the style of pre_text is arbitrary.
+    Suppose we have 2 different transforms A,B, and the preceding text is
+    referred to as pre_text. The following three tuples are all valid:

-    (A(pre_text), B(style_text), B(text))
-    (A(pre_text), C(style_text), C(text))
-    (A(pre_text), A(style_text), A(text))
-    ...
+    (A(pre_text), A(style_text), A(ref_text))
+    (A(pre_text), B(style_text), B(ref_text))
+    (B(pre_text), B(style_text), B(ref_text))

     If transforms is not given, the following pre-defined transforms
     are available:
-    0: original (normal case, with punc)
-    1: recog (upper, no punc)
-    2: upper_only_alpha (upper, no punc)
-    3: lower_only_alpha (lower, no punc)
-    4: upper_all (upper, with punc)
-    5: lower_all (lower, with punc)
+    0: original (mixed-case, with punc)
+    1: upper_only_alpha (upper-case, no punc)

     When the transform of text and pre_text match, we can use the whole
     pre_text as the prompt text.
@@ -403,15 +325,18 @@ def triplet_text_sampling_with_context_list(
         pre_texts (List[str]):
             A list of pre_texts, whose first item is the groundtruth
             pre_text from books.
+        context_list (Optional[str]):
+            A string of biasing words separated by spaces
+        rare_word_list (Optional[List[str]]):
+            A list of rare words (used as distractors)
         transforms (List[Callable[[str], str]]): A list of possible transforms to be applied

     Returns:
-        str: A dictionary
+        A dictionary of ref_text, pre_text, style_text
     """
-    # import pdb; pdb.set_trace()
     assert len(texts) == len(pre_texts)
     assert len(texts) == 2
-
+
     if context_list is not None:
         context_list = context_list.lower()
@@ -426,9 +354,13 @@ def triplet_text_sampling_with_context_list(
             lower_only_alpha,
             lower_all_char,
         ]
-
-    # sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
-    sampling_weight = [0.7, 0.3, 0.0, 0.0]
+
+    sampling_weight = [
+        0.7,
+        0.3,
+        0.0,
+        0.0,
+    ]  # Mixed-punc should have the largest sampling prob

     total_transforms = len(transforms)  # do not use the recognized trans
@@ -446,11 +378,10 @@ def triplet_text_sampling_with_context_list(
     pre_text = transforms[i_pre_text](pre_text)

     if i_text == i_pre_text:
-        style_text = gt_pre_text
         style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
     else:
         # get the pre_text of same style as text
-        # For now, do not do transform to the style text
+        # For now, **don't** transform the style text
         style_text = gt_pre_text
         # style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
         style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
@@ -471,7 +402,7 @@ def get_pre_text_with_context_list(
 ) -> str:
     # Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
     # By a small proportion of time, use the substring of ref_text as pre_text
-
+
     if context_list != "" and context_list is not None:
         v = random.random()
         if v < 0.5:
@@ -489,14 +420,14 @@ def get_pre_text_with_context_list(
             pre_text = " ".join(pre_text)
         elif v < 0.7:
             splitted = text.split()
-            sampling_weights = [len(w)**1.2 for w in splitted]
-            sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
+            sampling_weights = [len(w) ** 1.2 for w in splitted]
+            sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
             i = random.randint(1, min(len(splitted), 20))
             splitted = list(np.random.choice(splitted, i, p=sampling_weights))
-            num_distractors = random.randint(0,70)
+            num_distractors = random.randint(0, 70)
             distractors = random.sample(rare_words_list, num_distractors)
             splitted += distractors
-            random.shuffle(splitted) # shuffle the list
+            random.shuffle(splitted)  # shuffle the list
             pre_text = " ".join(splitted)
         else:
             pre_text = pre_text
@@ -504,21 +435,21 @@ def get_pre_text_with_context_list(
         v = random.random()
         if v < 0.1:
             splitted = text.split()
-            sampling_weights = [len(w)**1.2 for w in splitted]
-            sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
+            sampling_weights = [len(w) ** 1.2 for w in splitted]
+            sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
             i = random.randint(1, min(len(splitted), 20))
             splitted = list(np.random.choice(splitted, i, p=sampling_weights))
             pre_text = " ".join(splitted)
-            num_distractors = random.randint(0,70)
+            num_distractors = random.randint(0, 70)
             distractors = random.sample(rare_words_list, num_distractors)
             splitted += distractors
-            random.shuffle(splitted) # shuffle the list
+            random.shuffle(splitted)  # shuffle the list
         elif v < 0.2:
             # full distractors
             num_distractors = random.randint(5, 100)
             distractors = random.sample(rare_words_list, num_distractors)
- pre_text = " ".join(distractors) - + pre_text = " ".join(distractors) + elif v < 0.3: pre_text = get_substring(text, min_len=15, max_len=150) else: @@ -527,20 +458,19 @@ def get_pre_text_with_context_list( return pre_text - def get_pre_text_with_context_list2( text: str, pre_text: str, context_list: str, rare_words_list: List[str] = None, ) -> str: - # Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha + # Get the pre_text, either the ground truth preceding text or + # a list of words consisting of biasing words and distrators # By a small proportion of time, use the substring of ref_text as pre_text - + if context_list != "" and context_list is not None: v = random.random() if v < 0.4: - # correct + distractors # sample distractors num_distractors = random.randint(50, 100) distractors = random.sample(rare_words_list, num_distractors) @@ -554,14 +484,16 @@ def get_pre_text_with_context_list2( pre_text = " ".join(pre_text) elif v < 0.55: splitted = text.split() - sampling_weights = [len(w)**1.2 for w in splitted] - sampling_weights = [p/sum(sampling_weights) for p in sampling_weights] + sampling_weights = [ + len(w) ** 1.2 for w in splitted + ] # longer words with higher weights + sampling_weights = [p / sum(sampling_weights) for p in sampling_weights] i = random.randint(1, min(len(splitted), 20)) splitted = list(np.random.choice(splitted, i, p=sampling_weights)) - num_distractors = random.randint(50,100) + num_distractors = random.randint(50, 100) distractors = random.sample(rare_words_list, num_distractors) splitted += distractors - random.shuffle(splitted) # shuffle the list + random.shuffle(splitted) # shuffle the list pre_text = " ".join(splitted) else: pre_text = pre_text @@ -569,21 +501,20 @@ def get_pre_text_with_context_list2( v = random.random() if v < 0.3: splitted = text.split() - sampling_weights = [len(w)**1.2 for w in splitted] - sampling_weights = [p/sum(sampling_weights) for p in sampling_weights] + sampling_weights = [len(w) ** 1.2 for w in splitted] + sampling_weights = [p / sum(sampling_weights) for p in sampling_weights] i = random.randint(1, min(len(splitted), 20)) splitted = list(np.random.choice(splitted, i, p=sampling_weights)) pre_text = " ".join(splitted) - num_distractors = random.randint(50,100) + num_distractors = random.randint(50, 100) distractors = random.sample(rare_words_list, num_distractors) splitted += distractors - random.shuffle(splitted) # shuffle the list + random.shuffle(splitted) # shuffle the list elif v < 0.4: # full distractors num_distractors = random.randint(5, 100) distractors = random.sample(rare_words_list, num_distractors) - pre_text = " ".join(distractors) - + pre_text = " ".join(distractors) elif v < 0.6: pre_text = get_substring(text, min_len=15, max_len=150) else: @@ -592,160 +523,6 @@ def get_pre_text_with_context_list2( return pre_text - -def joint_triplet_text_sampling( - texts: List[str], - pre_texts: List[str], - transforms: Optional[List[Callable[[str], str]]] = None, - min_len_style: Optional[int] = 80, -) -> Dict[str, str]: - """This function generates a triplet of - (pre_text, style_text, ref_text). The style of pre_text, style_text - and ref_text should **always** match. - Suppose we have 3 different transforms A,B,C, and the groundtruth - text and pre_text are referred to as text and pre_text. - The following three tuples are all valid: - - (A(pre_text), A(style_text), A(text)) - (B(pre_text), B(style_text), B(text)) - (C(pre_text), C(style_text), C(text)) - ... 
- - If transforms is not given, the following pre-defined transforms - are available: - 0: original (normal case, with punc) - 1: recog (upper, no punc) - 2: upper_only_alpha (upper, no punc) - 3: lower_only_alpha (lower, no punc) - 4: upper_all (upper, with punc) - 5: lower_all (lower, with punc) - - When the transform of text and pre_text match, we can use the whole - pre_text as the prompt text. - - Args: - texts (List[str]): - A list of ref_texts whose first item is the ground truth - text from books. - pre_texts (List[str]): - A list of pre_texts, whose first item is the groundtruth - pre_text from books. - transforms (List[Callable[[str], str]]): A list of possible transforms to be applied - - Returns: - str: A dictionary - """ - # import pdb; pdb.set_trace() - assert len(texts) == len(pre_texts) - assert len(texts) == 2 - - # we assume the first item to be ground truth - gt_text = texts[0] - gt_pre_text = pre_texts[0] - - if transforms is None: - transforms = [ - lambda x: x, # return it self - upper_only_alpha, - lower_only_alpha, - lower_all_char, - ] - - sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob - - total_transforms = len(transforms) # do not use the recognized trans - - # Select a transformation randomly - i_text = np.random.choice(total_transforms, 1, p=sampling_weight)[0] - - # get the normalized text and pre_text - text = transforms[i_text](gt_text) - pre_text = transforms[i_text](gt_pre_text) - - style_text = get_substring(pre_text, min_len=min_len_style, max_len=150) - - return { - "text": train_text_normalization(text), - "pre_text": train_text_normalization(pre_text), - "style_text": train_text_normalization(style_text), - "transform_ids": i_text, - } - - -def triplet_style_text_sampling( - texts: List[str], - pre_texts: List[str], - transforms: Optional[List[Callable[[str], str]]] = None, - min_len_style: Optional[int] = 80, -) -> Dict[str, str]: - """This function generates a triplet of - (pre_text, style_text, ref_text). The style of style_text and ref_text - should always match, whereas the style of pre_text is fixed to mixed-trans. - Suppose we have 3 different transforms A,B,C, and the groundtruth - text and pre_text are referred to as text and pre_text. - The following three tuples are all valid: - - (gt_pre_text, B(style_text), B(text)) - (gt_pre_text, C(style_text), C(text)) - (gt_pre_text, A(style_text), A(text)) - ... - - If transforms is not given, the following pre-defined transforms - are available: - 0: original (normal case, with punc) - 1: recog (upper, no punc) - 2: upper_only_alpha (upper, no punc) - 3: lower_only_alpha (lower, no punc) - 4: upper_all (upper, with punc) - 5: lower_all (lower, with punc) - - Args: - texts (List[str]): - A list of ref_texts whose first item is the ground truth - text from books. - pre_texts (List[str]): - A list of pre_texts, whose first item is the groundtruth - pre_text from books. 
-        transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
-
-    Returns:
-        str: A dictionary
-    """
-    # import pdb; pdb.set_trace()
-    assert len(texts) == len(pre_texts)
-    assert len(texts) == 2
-
-    # we assume the first item to be ground truth
-    gt_text = texts[0]
-    gt_pre_text = pre_texts[0]
-
-    if transforms is None:
-        transforms = [
-            lambda x: x,  # return it self
-            upper_only_alpha,
-            lower_only_alpha,
-            lower_all_char,
-        ]
-
-    sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
-    total_transforms = len(transforms)  # do not use the recognized trans
-
-    # Select a transformation randomly
-    t_id = np.random.choice(total_transforms, 1, p=sampling_weight)[0]
-
-    # get the normalized text
-    text = transforms[t_id](gt_text)
-    # get the original un-processed style text
-    style_text = get_substring(gt_pre_text, min_len=min_len_style, max_len=150)
-
-    return {
-        "text": train_text_normalization(text),
-        "pre_text": train_text_normalization(gt_pre_text),
-        "style_text": train_text_normalization(style_text),
-        "transform_ids": t_id,
-    }
-
-
 def naive_triplet_text_sampling(
     texts: List[str],
     pre_texts: List[str],
@@ -753,18 +530,13 @@ def naive_triplet_text_sampling(
     rare_word_list: List[str] = None,
     min_len_style: Optional[int] = 120,
 ):
+    # The simplest text sampling function, used only for
+    # evaluation; it uses a fixed sentence as the style text
     return {
         "text": train_text_normalization(texts[0]),
         "pre_text": train_text_normalization(pre_texts[0]),
-        #"pre_text": "HELLO IT THIS ENOUGH FOR THE MODEL TO LEARN THE STYLE",
-        #"pre_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
-        #"pre_text": "Hello, my friend. "*50,
-        #"style_text": train_text_normalization(pre_texts[0][:200]),
         "style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related. What do you think?",
-        #"style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
-        #"style_text": "Mixed-case English transcription, with punctuation.",
-        # "style_text": train_text_normalization(get_substring(pre_texts[0], min_len=min_len_style)),
         "transform_ids": 0,
     }
@@ -775,11 +547,11 @@ def random_shuffle_subset(
     p_mask: float = 0.05,
 ) -> List[str]:
     """
-    Randomly shuffle the subset by probability p, which means that p% of the samples
+    Randomly shuffle a subset with probability `p`: a proportion `p` of the samples
     in the original batch are shuffled, the others are kept in the original order.
-
-    With a probability of p_mask, replace the original string with an empty string.
-
+
+    With a probability of `p_mask`, replace the original string with an empty string.
+
     """

     num_to_shuffle = int(len(data) * p)
@@ -789,9 +561,9 @@
     for id, item in zip(id_to_shuffle, item_to_shuffle):
         data[id] = item
-
+
     # Randomly mask a proportion of the data to empty string
-    if p_mask > 0:
+    if p_mask > 0:
         for i in range(len(data)):
             if random.random() < p_mask:
                 data[i] = ""
@@ -809,16 +581,6 @@ if __name__ == "__main__":
         "EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg?
EE, Ff, Gg?", "EE FF GG EE FF GG EE FF GG EE FF GG EE FF GG", ] - # for i in range(10): - # print(f"Run: {i}") - # print(triplet_text_sampling(texts, pre_texts)) - - import time - start = time.time() - data = [str(i) for i in range(30)] - random.shuffle(data) - print(data) - for i in range(1): - shuffled = random_shuffle_subset(data=data, p=0.4, p_mask=0.1) - print(shuffled) - print((time.time() - start)/100) + for i in range(10): + print(f"Run: {i}") + print(triplet_text_sampling(texts, pre_texts)) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_subformer.py b/egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_subformer.py deleted file mode 100644 index f2caf8fbe..000000000 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/transcribe_subformer.py +++ /dev/null @@ -1,477 +0,0 @@ -import argparse -import logging -import math -import warnings -from pathlib import Path -from typing import List -from tqdm import tqdm - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from lhotse import load_manifest, Fbank - -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from text_normalization import ( - ref_text_normalization, - remove_non_alphabetic, - upper_only_alpha, - upper_all_char, - lower_all_char, - lower_only_alpha, - train_text_normalization, -) -from train_subformer_with_style import ( - add_model_arguments, - get_params, - get_tokenizer, - get_transducer_model, - _encode_text_as_tokens, -) - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="The experiment dir", - ) - - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - ) - - parser.add_argument( - "--manifest-dir", - type=str, - default="data/long_audios/long_audio_pomonastravels_combined.jsonl.gz", - help="""This is the manfiest for long audio transcription. 
- It is intended to be sored, i.e first sort by recording ID and then sort by - start timestamp""" - ) - - parser.add_argument( - "--segment-length", - type=float, - default=30.0, - ) - - parser.add_argument( - "--use-pre-text", - type=str2bool, - default=False, - help="Whether use pre-text when decoding the current chunk" - ) - - parser.add_argument( - "--use-style-prompt", - type=str2bool, - default=True, - help="Use style prompt when evaluation" - ) - - parser.add_argument( - "--pre-text-transform", - type=str, - choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"], - default="mixed-punc", - help="The style of content prompt, i.e pre_text" - ) - - parser.add_argument( - "--style-text-transform", - type=str, - choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"], - default="mixed-punc", - help="The style of style prompt, i.e style_text" - ) - - parser.add_argument( - "--num-history", - type=int, - default=2, - help="How many previous chunks to look if using pre-text for decoding" - ) - - parser.add_argument( - "--use-gt-pre-text", - type=str2bool, - default=False, - help="Whether use gt pre text when using content prompt", - ) - - parser.add_argument( - "--post-normalization", - type=str2bool, - default=True, - ) - - add_model_arguments(parser) - - return parser - -def _apply_style_transform(text: List[str], transform: str) -> List[str]: - """Apply transform to a list of text. By default, the text are in - ground truth format, i.e mixed-punc. - - Args: - text (List[str]): Input text string - transform (str): Transform to be applied - - Returns: - List[str]: _description_ - """ - if transform == "mixed-punc": - return text - elif transform == "upper-no-punc": - return [upper_only_alpha(s) for s in text] - elif transform == "lower-no-punc": - return [lower_only_alpha(s) for s in text] - elif transform == "lower-punc": - return [lower_all_char(s) for s in text] - else: - raise NotImplementedError(f"Unseen transform: {transform}") - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - params.res_dir = params.exp_dir / "long_audio_transcribe" - params.res_dir.mkdir(exist_ok=True) - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "beam_search" in params.method: - params.suffix += ( - f"-{params.method}-beam-size-{params.beam_size}" - ) - - if params.use_pre_text: - if params.use_gt_pre_text: - params.suffix += f"-use-gt-pre-text-{params.pre_text_transform}-history-{params.num_history}" - else: - params.suffix += f"-pre-text-{params.pre_text_transform}-history-{params.num_history}" - - - book_name = params.manifest_dir.split('/')[-1].replace(".jsonl.gz", "") - setup_logger(f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - text_sp = spm.SentencePieceProcessor() - text_sp.load(params.text_encoder_bpe_model) - - num_param = sum([p.numel() 
for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - # load manifest - manifest = load_manifest(params.manifest_dir) - - results = [] - count = 0 - - last_recording = "" - last_end = -1 - history = [] - num_pre_texts = [] - - for cut in manifest: - if cut.has_features: - feat = cut.load_features() - feat_lens = cut.num_frames - else: - feat = cut.compute_features(extractor=Fbank()) - feat_lens = feat.shape[0] - - - cur_recording = cut.recording.id - - if cur_recording != last_recording: - last_recording = cur_recording - history = [] # clean history - last_end = -1 - logging.info(f"Moving on to the next recording") - else: - if cut.start < last_end - 0.2: # overlap exits - logging.warning(f"An overlap exists between current cut and last cut") - logging.warning("Skipping this cut!") - continue - if cut.start > last_end + 10: - logging.warning(f"Large time gap between the current and previous utterance: {cut.start - last_end}.") - - # prepare input - x = torch.tensor(feat, device=device).unsqueeze(0) - x_lens = torch.tensor([feat_lens,], device=device) - - if params.use_pre_text: - if params.num_history > 0: - pre_texts = history[-params.num_history:] - else: - pre_texts = [] - assert len(pre_texts) <= params.num_history - num_pre_texts.append(len(pre_texts)) - pre_texts = [train_text_normalization(" ".join(pre_texts))] - fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related." 
- style_texts = [fixed_sentence] - - pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform) - if params.use_style_prompt: - style_texts = _apply_style_transform(style_texts, params.style_text_transform) - - # encode pre_text - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - pre_texts, pre_texts_lens, style_text_lens = _encode_text_as_tokens( - pre_texts=pre_texts, - style_texts=style_texts, - bpe_model=text_sp, - device=device, - max_tokens=1500, - ) - if params.num_history > 5: - logging.info(f"Shape of encoded texts: {pre_texts.shape} ") - - memory, memory_key_padding_mask = model.encode_text( - text=pre_texts, - style_lens=style_text_lens, - text_lens=pre_texts_lens, - ) # (T,B,C) - else: - memory = None - memory_key_padding_mask = None - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - encoder_out, encoder_out_lens = model.encode_audio( - feature=x, - feature_lens=x_lens, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - ) - - if params.method == "greedy_search": - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - hyp = sp.decode(hyp_tokens)[0] # in string format - ref_text = ref_text_normalization(cut.supervisions[0].texts[0]) # required to match the training - - # extend the history, the history here is in original format - if params.use_gt_pre_text: - history.append(ref_text) - else: - history.append(hyp) - last_end = cut.end # update the last end timestamp - - # append the current decoding result - hyp = hyp.split() - ref = ref_text.split() - results.append((cut.id, ref, hyp)) - - count += 1 - if count % 100 == 0: - logging.info(f"Cuts processed until now: {count}/{len(manifest)}") - logging.info(f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}") - logging.info(f"A total of {count} cuts") - logging.info(f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}") - - results = sorted(results) - recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt" - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"long-audio-{params.method}", results, enable_log=True, compute_CER=False, - ) - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - if params.post_normalization: - params.suffix += "-post-normalization" - - new_res = [] - for item in results: - id, ref, hyp = item - hyp = upper_only_alpha(" ".join(hyp)).split() - ref = upper_only_alpha(" ".join(ref)).split() - new_res.append((id,ref,hyp)) - - new_res = sorted(new_res) - recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt" - store_transcripts(filename=recog_path, texts=new_res) - logging.info(f"The transcripts are stored in {recog_path}") - - errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"long-audio-{params.method}", new_res, enable_log=True, compute_CER=False, - ) - - 
logging.info("Wrote detailed error stats to {}".format(errs_filename)) -if __name__=="__main__": - main() \ No newline at end of file