from typing import Callable, Dict, List, Optional, Union

import random

import numpy as np
import torch
from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset import K2SpeechRecognitionDataset
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import compute_num_frames, ifnone
from torch.utils.data.dataloader import DataLoader, default_collate

from text_normalization import (
    lower_all_char,
    lower_only_alpha,
    train_text_normalization,
    upper_all_char,
    upper_only_alpha,
)


def _default_transforms() -> List[Callable[[str], str]]:
    """Return the default list of text transforms.

    Index 0 is the identity (mixed case, with punctuation); the remaining
    entries strip punctuation and/or force a single case.  The sampling
    functions below index into this list, so the order matters.
    """
    return [
        lambda x: x,  # identity: keep mixed-case text with punctuation
        upper_only_alpha,
        lower_only_alpha,
        lower_all_char,
    ]


class PromptASRDataset(torch.utils.data.Dataset):
    """A dataset for Prompt ASR.

    It supports selecting a tuple of ``(text, pre_text, style_text)``
    randomly from a list of texts as supervisions.
    """

    def __init__(
        self,
        return_cuts: bool = False,
        cut_transforms: Optional[List[Callable[[CutSet], CutSet]]] = None,
        input_transforms: Optional[
            List[Callable[[torch.Tensor], torch.Tensor]]
        ] = None,
        input_strategy: BatchIO = PrecomputedFeatures(),
        text_sampling_func: Optional[Callable[[List[str]], str]] = None,
    ):
        """Icefall ASR IterableDataset constructor.

        See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py
        for more details.

        :param return_cuts: When ``True``, will additionally return a "cut"
            field in each batch with the Cut objects used to create that batch.
        :param cut_transforms: A list of transforms to be applied on each
            sampled batch, before converting cuts to an input representation
            (audio/features).  Examples: cut concatenation, noise cuts mixing.
        :param input_transforms: A list of transforms to be applied on each
            sampled batch, after the cuts are converted to audio/features.
            Examples: normalization, SpecAugment.
        :param input_strategy: Converts cuts into a collated batch of
            audio/features.  By default, reads pre-computed features from disk.
        :param text_sampling_func: Sampling a text as transcription from a
            list of texts.
        """
        super().__init__()
        # Initialize the fields
        self.return_cuts = return_cuts
        self.cut_transforms = ifnone(cut_transforms, [])
        self.input_transforms = ifnone(input_transforms, [])
        self.input_strategy = input_strategy

        # a text sampling function
        self.text_sampling_func = text_sampling_func

    def __getitem__(
        self, cuts: CutSet
    ) -> Dict[str, Union[torch.Tensor, List[str]]]:
        """Return a new batch, with the batch size automatically determined
        using the constraints of max_frames and max_cuts.
        """
        validate_for_asr(cuts)

        # Sort the cuts by duration so that the first one determines the
        # batch time dimensions.
        cuts = cuts.sort_by_duration(ascending=False)

        # Optional CutSet transforms - e.g. padding, or speed perturbation
        # that adjusts the supervision boundaries.
        for tnfm in self.cut_transforms:
            cuts = tnfm(cuts)

        # Sort the cuts again after transforms
        cuts = cuts.sort_by_duration(ascending=False)

        # Get a tensor with batched feature matrices, shape (B, T, F)
        # Collation performs auto-padding, if necessary.
        input_tpl = self.input_strategy(cuts)
        if len(input_tpl) == 3:
            # An input strategy with fault tolerant audio reading mode.
            # "cuts" may be a subset of the original "cuts" variable,
            # that only has cuts for which we succesfully read the audio.
            inputs, _, cuts = input_tpl
        else:
            inputs, _ = input_tpl

        # Get a dict of tensors that encode the positional information about
        # supervisions in the batch of feature matrices.  The tensors are
        # named "sequence_idx", "start_frame/sample" and "num_frames/samples".
        supervision_intervals = self.input_strategy.supervision_intervals(cuts)

        # Apply all available transforms on the inputs, i.e. either audio or
        # features.  This could be feature extraction, global MVN,
        # SpecAugment, etc.
        segments = torch.stack(list(supervision_intervals.values()), dim=1)
        for tnfm in self.input_transforms:
            inputs = tnfm(inputs, supervision_segments=segments)

        batch = {
            "inputs": inputs,
            "supervisions": default_collate(
                [
                    # Either sample (text, pre_text, style_text) via the
                    # user-provided function, or fall back to the first
                    # (ground-truth) entry with identity transform id 0.
                    self.text_sampling_func(
                        texts=supervision.texts,
                        pre_texts=supervision.pre_texts,
                    )
                    if self.text_sampling_func is not None
                    else {
                        "text": train_text_normalization(
                            supervision.texts[0]
                        ),
                        "pre_text": train_text_normalization(
                            supervision.pre_texts[0]
                        ),
                        "style_text": train_text_normalization(
                            supervision.pre_texts[0]
                        ),
                        "transform_ids": 0,
                    }
                    for sequence_idx, cut in enumerate(cuts)
                    for supervision in cut.supervisions
                ]
            ),
        }

        # Update the 'supervisions' field with sequence_idx and
        # start/num frames/samples.
        batch["supervisions"].update(supervision_intervals)
        if self.return_cuts:
            batch["supervisions"]["cut"] = [
                cut for cut in cuts for sup in cut.supervisions
            ]

        return batch


def validate_for_asr(cuts: CutSet) -> None:
    """Validate that every supervision lies within its cut (with tolerance).

    Raises AssertionError if a supervision starts before or ends after
    its cut.
    """
    validate(cuts)
    tol = 2e-3  # 2 ms tolerance for floating-point boundary jitter
    for cut in cuts:
        for supervision in cut.supervisions:
            assert supervision.start >= -tol, (
                f"Supervisions starting before the cut are not supported for ASR"
                f" (sup id: {supervision.id}, cut id: {cut.id})"
            )

            # Supervision start time is relative to Cut ...
            # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
            #
            # 'supervision.end' is end of supervision inside the Cut
            assert supervision.end <= cut.duration + tol, (
                f"Supervisions ending after the cut "
                f"are not supported for ASR"
                f" (sup id: {supervision.id}, cut id: {cut.id})"
            )


def get_substring(s: str, min_len: int = 40, max_len: int = 250) -> str:
    """A helper function that generates a random substring from a given string.

    Args:
      s (str): Input string.
      min_len (int): Minimum length of the substring (capped at ``len(s)``).
      max_len (int): Maximum length of the substring.

    Returns:
      str: The randomly chosen substring.
    """
    min_len = min(len(s), min_len)

    start = random.randint(0, len(s) - min_len)
    end = min(start + max_len, random.randint(start + min_len, len(s)))

    return s[start:end]


def triplet_text_sampling(
    texts: List[str],
    pre_texts: List[str],
    transforms: Optional[List[Callable[[str], str]]] = None,
    min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
    """This function generates a triplet of
    (pre_text, style_text, ref_text). The style of style_text and ref_text
    should always match, whereas the style of pre_text is arbitrary.
    Suppose we have 3 different transforms A,B,C, and the groundtruth
    text and pre_text are referred to as text and pre_text.
    The following three tuples are all valid:

    (A(pre_text), B(style_text), B(text))
    (A(pre_text), C(style_text), C(text))
    (A(pre_text), A(style_text), A(text))
    ...

    If transforms is not given, the following pre-defined transforms
    are available:
    0: original (normal case, with punc)
    1: recog (upper, no punc)
    2: upper_only_alpha (upper, no punc)
    3: lower_only_alpha (lower, no punc)
    4: upper_all (upper, with punc)
    5: lower_all (lower, with punc)

    When the transform of text and pre_text match, we can use the whole
    pre_text as the prompt text.

    Args:
        texts (List[str]):
            A list of ref_texts whose first item is the ground truth
            text from books.
        pre_texts (List[str]):
            A list of pre_texts, whose first item is the groundtruth
            pre_text from books.
        transforms (List[Callable[[str], str]]):
            A list of possible transforms to be applied.

    Returns:
        Dict[str, str]: A dictionary with keys "text", "pre_text",
        "style_text" and "transform_ids".
    """
    assert len(texts) == len(pre_texts)
    assert len(texts) == 2

    # we assume the first item to be ground truth
    gt_text = texts[0]
    gt_pre_text = pre_texts[0]

    if transforms is None:
        transforms = _default_transforms()

    # Mixed-punc should have the largest sampling prob
    sampling_weight = [0.7, 0.3, 0.0, 0.0]

    total_transforms = len(transforms)  # do not use the recognized trans

    # Select a transformation randomly (with replacement, so the two
    # indices may coincide).
    i_text, i_pre_text = np.random.choice(
        total_transforms, 2, p=sampling_weight
    )

    # get the normalized text and pre_text
    text = transforms[i_text](gt_text)
    pre_text = transforms[i_pre_text](gt_pre_text)

    if i_text == i_pre_text:
        # Same style: a substring of the transformed pre_text works as
        # the style prompt.
        style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
    else:
        # get the pre_text of same style as text
        # For now, do not do transform to the style text
        style_text = gt_pre_text
        style_text = get_substring(
            style_text, min_len=min_len_style, max_len=150
        )

    return {
        "text": train_text_normalization(text),
        "pre_text": train_text_normalization(pre_text),
        "style_text": train_text_normalization(style_text),
        "transform_ids": i_text,
    }


def multi_ref_text_triplet_text_sampling(
    texts: List[str],
    pre_texts: List[str],
    context_list: Optional[str] = None,
    rare_word_list: Optional[List[str]] = None,
    transforms: Optional[List[Callable[[str], str]]] = None,
    min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
    """This function generates a triplet of
    (pre_text, style_text, ref_text). The style of style_text and ref_text
    should always match, whereas the style of pre_text is arbitrary.
    Unlike ``triplet_text_sampling``, the reference text is chosen with
    probability 0.5 from either the ground truth or a decoding result
    obtained with prompts.

    Args:
        texts (List[str]):
            A list of ref_texts: item 0 is the ground truth text from
            books, item 2 is a decoding result with prompts.
        pre_texts (List[str]):
            A list of pre_texts, whose first item is the groundtruth
            pre_text from books.
        context_list: Unused; kept for signature compatibility.
        rare_word_list: Unused; kept for signature compatibility.
        transforms (List[Callable[[str], str]]):
            A list of possible transforms to be applied.

    Returns:
        Dict[str, str]: A dictionary with keys "text", "pre_text",
        "style_text" and "transform_ids".
    """
    assert len(texts) == 3
    # we assume the first item to be ground truth, the third item to be the
    # decoding results with prompts
    if random.random() < 0.5:
        gt_text = texts[0]
    else:
        gt_text = texts[2]  # decoding res with prompt
    gt_pre_text = pre_texts[0]

    if transforms is None:
        transforms = _default_transforms()

    # Mixed-punc should have the largest sampling prob
    sampling_weight = [0.5, 0.2, 0.15, 0.15]

    total_transforms = len(transforms)  # do not use the recognized trans

    # Select a transformation randomly
    i_text, i_pre_text = np.random.choice(
        total_transforms, 2, p=sampling_weight
    )

    # get the normalized text and pre_text
    text = transforms[i_text](gt_text)
    pre_text = transforms[i_pre_text](gt_pre_text)

    if i_text == i_pre_text:
        style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
    else:
        # get the pre_text of same style as text
        # For now, do not do transform to the style text
        style_text = gt_pre_text
        style_text = get_substring(
            style_text, min_len=min_len_style, max_len=150
        )

    return {
        "text": train_text_normalization(text),
        "pre_text": train_text_normalization(pre_text),
        "style_text": train_text_normalization(style_text),
        "transform_ids": i_text,
    }


def joint_triplet_text_sampling(
    texts: List[str],
    pre_texts: List[str],
    transforms: Optional[List[Callable[[str], str]]] = None,
    min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
    """This function generates a triplet of
    (pre_text, style_text, ref_text). The style of pre_text, style_text
    and ref_text should **always** match.
    Suppose we have 3 different transforms A,B,C, and the groundtruth
    text and pre_text are referred to as text and pre_text.
    The following three tuples are all valid:

    (A(pre_text), A(style_text), A(text))
    (B(pre_text), B(style_text), B(text))
    (C(pre_text), C(style_text), C(text))
    ...

    Args:
        texts (List[str]):
            A list of ref_texts whose first item is the ground truth
            text from books.
        pre_texts (List[str]):
            A list of pre_texts, whose first item is the groundtruth
            pre_text from books.
        transforms (List[Callable[[str], str]]):
            A list of possible transforms to be applied.

    Returns:
        Dict[str, str]: A dictionary with keys "text", "pre_text",
        "style_text" and "transform_ids".
    """
    assert len(texts) == len(pre_texts)
    assert len(texts) == 2

    # we assume the first item to be ground truth
    gt_text = texts[0]
    gt_pre_text = pre_texts[0]

    if transforms is None:
        transforms = _default_transforms()

    # Mixed-punc should have the largest sampling prob
    sampling_weight = [0.5, 0.2, 0.15, 0.15]

    total_transforms = len(transforms)  # do not use the recognized trans

    # Select a single transformation randomly; it is shared by text,
    # pre_text and style_text so their styles always match.
    i_text = np.random.choice(total_transforms, 1, p=sampling_weight)[0]

    # get the normalized text and pre_text
    text = transforms[i_text](gt_text)
    pre_text = transforms[i_text](gt_pre_text)
    style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)

    return {
        "text": train_text_normalization(text),
        "pre_text": train_text_normalization(pre_text),
        "style_text": train_text_normalization(style_text),
        "transform_ids": i_text,
    }


def triplet_style_text_sampling(
    texts: List[str],
    pre_texts: List[str],
    transforms: Optional[List[Callable[[str], str]]] = None,
    min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
    """This function generates a triplet of
    (pre_text, style_text, ref_text). The style of style_text and ref_text
    should always match, whereas the pre_text is fixed to the original
    mixed-case, punctuated ground truth.
    The following tuples are all valid:

    (gt_pre_text, B(style_text), B(text))
    (gt_pre_text, C(style_text), C(text))
    (gt_pre_text, A(style_text), A(text))
    ...

    Args:
        texts (List[str]):
            A list of ref_texts whose first item is the ground truth
            text from books.
        pre_texts (List[str]):
            A list of pre_texts, whose first item is the groundtruth
            pre_text from books.
        transforms (List[Callable[[str], str]]):
            A list of possible transforms to be applied.

    Returns:
        Dict[str, str]: A dictionary with keys "text", "pre_text",
        "style_text" and "transform_ids".
    """
    assert len(texts) == len(pre_texts)
    assert len(texts) == 2

    # we assume the first item to be ground truth
    gt_text = texts[0]
    gt_pre_text = pre_texts[0]

    if transforms is None:
        transforms = _default_transforms()

    # Mixed-punc should have the largest sampling prob
    sampling_weight = [0.5, 0.2, 0.15, 0.15]

    total_transforms = len(transforms)  # do not use the recognized trans

    # Select a transformation randomly
    t_id = np.random.choice(total_transforms, 1, p=sampling_weight)[0]

    # get the normalized text
    text = transforms[t_id](gt_text)

    # get the original un-processed style text
    style_text = get_substring(gt_pre_text, min_len=min_len_style, max_len=150)

    return {
        "text": train_text_normalization(text),
        "pre_text": train_text_normalization(gt_pre_text),
        "style_text": train_text_normalization(style_text),
        "transform_ids": t_id,
    }


def naive_triplet_text_sampling(
    texts: List[str],
    pre_texts: List[str],
    context_list: str = None,
    rare_word_list: List[str] = None,
    min_len_style: Optional[int] = 120,
):
    """Return the ground-truth text/pre_text unchanged, with a fixed
    style prompt (no random transform is applied).

    ``context_list``, ``rare_word_list`` and ``min_len_style`` are unused;
    they are kept for signature compatibility with the other samplers.
    """
    return {
        "text": train_text_normalization(texts[0]),
        "pre_text": train_text_normalization(pre_texts[0]),
        # NOTE: the style prompt is intentionally fixed and unrelated to
        # the utterance; it only conveys the desired transcription style.
        "style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related. What do you think?",
        "transform_ids": 0,
    }


def random_shuffle_subset(
    data: List[str],
    p: float = 0.2,
    p_mask: float = 0.05,
) -> List[str]:
    """
    Randomly shuffle the subset by probability `p`, which means that p% of the
    samples in the original batch are shuffled, the others are kept in the
    original order. With a probability of `p_mask`, replace the original
    string with an empty string.

    Note: ``data`` is modified in place and also returned.
    """
    num_to_shuffle = int(len(data) * p)
    id_to_shuffle = np.random.choice(len(data), num_to_shuffle, replace=False)
    item_to_shuffle = [data[i] for i in id_to_shuffle]
    random.shuffle(item_to_shuffle)

    for i, item in zip(id_to_shuffle, item_to_shuffle):
        data[i] = item

    # Randomly mask a proportion of the data to empty string
    if p_mask > 0:
        for i in range(len(data)):
            if random.random() < p_mask:
                data[i] = ""

    return data


if __name__ == "__main__":
    texts = [
        "AA, BB, cC, dD!",
        "AA BB CC DD",
    ]

    pre_texts = [
        "EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg?",
        "EE FF GG EE FF GG EE FF GG EE FF GG EE FF GG",
    ]
    # for i in range(10):
    #     print(f"Run: {i}")
    #     print(triplet_text_sampling(texts, pre_texts))

    import time

    start = time.time()
    data = [str(i) for i in range(30)]
    random.shuffle(data)
    print(data)
    # Average the elapsed time over the actual number of runs (the
    # original divided by a hard-coded 100 while looping only once).
    num_runs = 1
    for _ in range(num_runs):
        shuffled = random_shuffle_subset(data=data, p=0.4, p_mask=0.1)
        print(shuffled)
    print((time.time() - start) / num_runs)