add documentation to different text sampling functions

This commit is contained in:
marcoyang1998 2023-09-20 09:57:03 +08:00
parent 6579800720
commit 93461fb77e
2 changed files with 117 additions and 832 deletions

View File

@ -1,23 +1,38 @@
from typing import Callable, Dict, List, Optional, Union
import random
import numpy as np
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset import K2SpeechRecognitionDataset
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import compute_num_frames, ifnone
from torch.utils.data.dataloader import DataLoader, default_collate
from text_normalization import (
remove_non_alphabetic,
upper_only_alpha,
lower_only_alpha,
upper_all_char,
lower_all_char,
lower_only_alpha,
remove_non_alphabetic,
train_text_normalization,
upper_all_char,
upper_only_alpha,
)
from torch.utils.data.dataloader import DataLoader, default_collate
class PromptASRDataset(torch.utils.data.Dataset):
@ -34,7 +49,7 @@ class PromptASRDataset(torch.utils.data.Dataset):
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
input_strategy: BatchIO = PrecomputedFeatures(),
text_sampling_func: Optional[Callable[[List[str]], str]] = None,
rare_word_list: Optional[List[str]] = None
rare_word_list: Optional[List[str]] = None,
):
"""
Icefall ASR IterableDataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py
@ -63,9 +78,7 @@ class PromptASRDataset(torch.utils.data.Dataset):
self.text_sampling_func = text_sampling_func
self.rare_word_list = rare_word_list
def __getitem__(
self, cuts: CutSet
) -> Dict[str, Union[torch.Tensor, List[str]]]:
def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
"""
Return a new batch, with the batch size automatically determined using the constraints
of max_frames and max_cuts.
@ -112,15 +125,15 @@ class PromptASRDataset(torch.utils.data.Dataset):
self.text_sampling_func(
texts=supervision.texts,
pre_texts=supervision.pre_texts,
context_list=supervision.context_list if "context_list" in supervision.custom else None,
context_list=supervision.context_list
if "context_list" in supervision.custom
else None,
rare_word_list=self.rare_word_list,
)
if self.text_sampling_func is not None
else {
"text": train_text_normalization(supervision.texts[0]),
"pre_text": train_text_normalization(
supervision.pre_texts[0]
),
"pre_text": train_text_normalization(supervision.pre_texts[0]),
"style_text": train_text_normalization(
supervision.pre_texts[0]
),
@ -192,27 +205,22 @@ def triplet_text_sampling(
rare_word_list: Optional[List[str]] = None,
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
"""This function generates a triplet of
(pre_text, style_text, ref_text). The style of style_text and ref_text
should always match, whereas the style of pre_text is arbitrary.
Suppose we have 3 different transforms A,B,C, and the groundtruth
text and pre_text are referred to as text and pre_text.
The following three tuples are all valid:
should **always** match, whereas the style of pre_text is arbitrary.
Suppose we have 2 different transforms A,B, and the preceding text is
referred to as pre_text. The following three tuples are all valid:
(A(pre_text), B(style_text), B(text))
(A(pre_text), C(style_text), C(text))
(A(pre_text), A(style_text), A(text))
...
(A(pre_text), A(style_text), A(ref_text))
(A(pre_text), B(style_text), B(ref_text))
(B(pre_text), B(style_text), B(ref_text))
If transforms is not given, the following pre-defined transforms
are available:
0: original (normal case, with punc)
1: recog (upper, no punc)
2: upper_only_alpha (upper, no punc)
3: lower_only_alpha (lower, no punc)
4: upper_all (upper, with punc)
5: lower_all (lower, with punc)
0: original (mixed-cased, with punc)
1: upper_only_alpha (upper-cased, no punc)
When the transform of text and pre_text match, we can use the whole
pre_text as the prompt text.
@ -224,12 +232,15 @@ def triplet_text_sampling(
pre_texts (List[str]):
A list of pre_texts, whose first item is the groundtruth
pre_text from books.
context_list (Optional[str]):
A string of biasing words separated by spaces
rare_word_list (Optional[List[str]]):
A list of rare words (used as distractors)
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
Returns:
A dictionary of ref_text, pre_text, style_text
"""
assert len(texts) == len(pre_texts)
assert len(texts) == 2
@ -244,13 +255,17 @@ def triplet_text_sampling(
lower_only_alpha,
lower_all_char,
]
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
sampling_weight = [0.7, 0.3, 0.0, 0.0]
sampling_weight = [
0.7,
0.3,
0.0,
0.0,
] # Mixed-punc should have the largest sampling prob
total_transforms = len(transforms) # do not use the recognized trans
# Select a transformation randomly
# Randomly sample transforms
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
# get the normalized text and pre_text
@ -261,97 +276,7 @@ def triplet_text_sampling(
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
else:
# get the pre_text of the same style as text
# For now, do not transform the style text
style_text = gt_pre_text
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
return {
"text": train_text_normalization(text),
"pre_text": train_text_normalization(pre_text),
"style_text": train_text_normalization(style_text),
"transform_ids": i_text,
}
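As an illustrative usage sketch (toy inputs, not from this commit): index 0 of each list is the ground truth and index 1 a normalized variant, matching the asserts above; min_len_style is lowered here because get_substring (defined elsewhere in this module) expects reasonably long pre_texts.
texts = ["Hello, my friend. How are you doing today?", "HELLO MY FRIEND HOW ARE YOU DOING TODAY"]
pre_texts = ["Yesterday we met at the train station downtown.", "YESTERDAY WE MET AT THE TRAIN STATION DOWNTOWN"]
sampled = triplet_text_sampling(texts=texts, pre_texts=pre_texts, min_len_style=10)
# sampled has keys "text", "pre_text", "style_text", "transform_ids";
# "style_text" and "text" always share a transform, "pre_text" may differ.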
def triplet_text_sampling2(
texts: List[str],
pre_texts: List[str],
context_list: Optional[str] = None,
rare_word_list: Optional[List[str]] = None,
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
"""This function generates a triplet of
(pre_text, style_text, ref_text). The style of style_text and ref_text
should always match, whereas the style of pre_text is arbitrary.
Suppose we have 3 different transforms A,B,C, and the groundtruth
text and pre_text are referred to as text and pre_text.
The following three tuples are all valid:
(A(pre_text), B(style_text), B(text))
(A(pre_text), C(style_text), C(text))
(A(pre_text), A(style_text), A(text))
...
If transforms is not given, the following pre-defined transforms
are available:
0: original (normal case, with punc)
1: recog (upper, no punc)
2: upper_only_alpha (upper, no punc)
3: lower_only_alpha (lower, no punc)
4: upper_all (upper, with punc)
5: lower_all (lower, with punc)
When the transform of text and pre_text match, we can use the whole
pre_text as the prompt text.
Args:
texts (List[str]):
A list of ref_texts whose first item is the ground truth
text from books.
pre_texts (List[str]):
A list of pre_texts, whose first item is the groundtruth
pre_text from books.
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
Returns:
str: A dictionary
"""
assert len(texts) == len(pre_texts)
assert len(texts) == 2
# we assume the first item to be ground truth
gt_text = texts[0]
gt_pre_text = pre_texts[0]
if transforms is None:
transforms = [
lambda x: x,  # return itself
upper_only_alpha,
lower_only_alpha,
lower_all_char,
]
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
sampling_weight = [0.7, 0.3, 0.0, 0.0]
total_transforms = len(transforms) # do not use the recognized trans
# Select a transformation randomly
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
# get the normalized text and pre_text
text = transforms[i_text](gt_text)
pre_text = transforms[i_pre_text](gt_pre_text)
if i_text == i_pre_text:
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
else:
# get the pre_text of the same style as text
# For now, do not transform the style text, because we do it after the dataloader
style_text = gt_pre_text
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
@ -373,25 +298,22 @@ def triplet_text_sampling_with_context_list(
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
"""This function generates a triplet of
(pre_text, style_text, ref_text). The style of style_text and ref_text
should always match, whereas the style of pre_text is arbitrary.
Suppose we have 3 different transforms A,B,C, and the groundtruth
text and pre_text are referred to as text and pre_text.
The following three tuples are all valid:
(pre_text, style_text, ref_text). The pre_text is either the preceding text
or a list of words (context words + distractors).
The style of style_text and ref_text should **always** match, whereas
the style of pre_text is arbitrary.
Suppose we have 2 different transforms A,B, and the preceding text is
referred to as pre_text. The following three tuples are all valid:
(A(pre_text), B(style_text), B(text))
(A(pre_text), C(style_text), C(text))
(A(pre_text), A(style_text), A(text))
...
(A(pre_text), A(style_text), A(ref_text))
(A(pre_text), B(style_text), B(ref_text))
(B(pre_text), B(style_text), B(ref_text))
If transforms is not given, the following pre-defined transforms
are available:
0: original (normal case, with punc)
1: recog (upper, no punc)
2: upper_only_alpha (upper, no punc)
3: lower_only_alpha (lower, no punc)
4: upper_all (upper, with punc)
5: lower_all (lower, with punc)
0: original (mixed-cased, with punc)
1: upper_only_alpha (upper-cased, no punc)
When the transform of text and pre_text match, we can use the whole
pre_text as the prompt text.
@ -403,15 +325,21 @@ def triplet_text_sampling_with_context_list(
pre_texts (List[str]):
A list of pre_texts, whose first item is the groundtruth
pre_text from books.
context_list (Optional[str]):
A string of biasing words separated by spaces
rare_word_list (Optional[List[str]]):
A list of rare words (used as distractors)
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
Returns:
A dictionary of ref_text, pre_text, style_text
"""
assert len(texts) == len(pre_texts)
assert len(texts) == 2
if context_list is not None:
context_list = context_list.lower()
@ -426,9 +354,13 @@ def triplet_text_sampling_with_context_list(
lower_only_alpha,
lower_all_char,
]
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
sampling_weight = [0.7, 0.3, 0.0, 0.0]
sampling_weight = [
0.7,
0.3,
0.0,
0.0,
] # Mixed-punc should have the largest sampling prob
total_transforms = len(transforms) # do not use the recognized trans
@ -446,11 +378,10 @@ def triplet_text_sampling_with_context_list(
pre_text = transforms[i_pre_text](pre_text)
if i_text == i_pre_text:
style_text = gt_pre_text
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
else:
# get the pre_text of the same style as text
# For now, do not transform the style text
style_text = gt_pre_text
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
@ -471,7 +402,7 @@ def get_pre_text_with_context_list(
) -> str:
# Always use the first item, which is the ground-truth (mixed-case transcript), with upper_only_alpha applied
# A small proportion of the time, use a substring of ref_text as the pre_text
if context_list != "" and context_list is not None:
v = random.random()
if v < 0.5:
@ -489,14 +420,14 @@ def get_pre_text_with_context_list(
pre_text = " ".join(pre_text)
elif v < 0.7:
splitted = text.split()
sampling_weights = [len(w)**1.2 for w in splitted]
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
sampling_weights = [len(w) ** 1.2 for w in splitted]
sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
i = random.randint(1, min(len(splitted), 20))
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
num_distractors = random.randint(0,70)
num_distractors = random.randint(0, 70)
distractors = random.sample(rare_words_list, num_distractors)
splitted += distractors
random.shuffle(splitted) # shuffle the list
random.shuffle(splitted) # shuffle the list
pre_text = " ".join(splitted)
else:
pre_text = pre_text
@ -504,21 +435,21 @@ def get_pre_text_with_context_list(
v = random.random()
if v < 0.1:
splitted = text.split()
sampling_weights = [len(w)**1.2 for w in splitted]
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
sampling_weights = [len(w) ** 1.2 for w in splitted]
sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
i = random.randint(1, min(len(splitted), 20))
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
num_distractors = random.randint(0, 70)
distractors = random.sample(rare_words_list, num_distractors)
splitted += distractors
random.shuffle(splitted)  # shuffle the list
pre_text = " ".join(splitted)  # join after adding distractors so they are included
elif v < 0.2:
# full distractors
num_distractors = random.randint(5, 100)
distractors = random.sample(rare_words_list, num_distractors)
pre_text = " ".join(distractors)
pre_text = " ".join(distractors)
elif v < 0.3:
pre_text = get_substring(text, min_len=15, max_len=150)
else:
@ -527,20 +458,19 @@ def get_pre_text_with_context_list(
return pre_text
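The branches above favor longer words via the len(w) ** 1.2 weights, on the assumption that longer words tend to be rarer. A standalone sketch of just that weighting step, with a toy sentence:
import random
import numpy as np
text = "the quick brown fox jumps over the lazy dog"
splitted = text.split()
sampling_weights = [len(w) ** 1.2 for w in splitted]
sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
k = random.randint(1, min(len(splitted), 20))
picked = list(np.random.choice(splitted, k, p=sampling_weights))
print(picked)  # longer words such as "quick" and "jumps" are sampled more often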
def get_pre_text_with_context_list2(
text: str,
pre_text: str,
context_list: str,
rare_words_list: List[str] = None,
) -> str:
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
# Get the pre_text: either the ground-truth preceding text or
# a list of words consisting of biasing words and distractors
# A small proportion of the time, use a substring of ref_text as the pre_text
if context_list != "" and context_list is not None:
v = random.random()
if v < 0.4:
# correct + distractors
# sample distractors
num_distractors = random.randint(50, 100)
distractors = random.sample(rare_words_list, num_distractors)
@ -554,14 +484,16 @@ def get_pre_text_with_context_list2(
pre_text = " ".join(pre_text)
elif v < 0.55:
splitted = text.split()
sampling_weights = [len(w)**1.2 for w in splitted]
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
sampling_weights = [
len(w) ** 1.2 for w in splitted
] # longer words with higher weights
sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
i = random.randint(1, min(len(splitted), 20))
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
num_distractors = random.randint(50,100)
num_distractors = random.randint(50, 100)
distractors = random.sample(rare_words_list, num_distractors)
splitted += distractors
random.shuffle(splitted) # shuffle the list
random.shuffle(splitted) # shuffle the list
pre_text = " ".join(splitted)
else:
pre_text = pre_text
@ -569,21 +501,20 @@ def get_pre_text_with_context_list2(
v = random.random()
if v < 0.3:
splitted = text.split()
sampling_weights = [len(w)**1.2 for w in splitted]
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
sampling_weights = [len(w) ** 1.2 for w in splitted]
sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
i = random.randint(1, min(len(splitted), 20))
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
num_distractors = random.randint(50, 100)
distractors = random.sample(rare_words_list, num_distractors)
splitted += distractors
random.shuffle(splitted)  # shuffle the list
pre_text = " ".join(splitted)  # join after adding distractors so they are included
elif v < 0.4:
# full distractors
num_distractors = random.randint(5, 100)
distractors = random.sample(rare_words_list, num_distractors)
pre_text = " ".join(distractors)
pre_text = " ".join(distractors)
elif v < 0.6:
pre_text = get_substring(text, min_len=15, max_len=150)
else:
@ -592,160 +523,6 @@ def get_pre_text_with_context_list2(
return pre_text
def joint_triplet_text_sampling(
texts: List[str],
pre_texts: List[str],
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
"""This function generates a triplet of
(pre_text, style_text, ref_text). The style of pre_text, style_text
and ref_text should **always** match.
Suppose we have 3 different transforms A,B,C, and the groundtruth
text and pre_text are referred to as text and pre_text.
The following three tuples are all valid:
(A(pre_text), A(style_text), A(text))
(B(pre_text), B(style_text), B(text))
(C(pre_text), C(style_text), C(text))
...
If transforms is not given, the following pre-defined transforms
are available:
0: original (normal case, with punc)
1: recog (upper, no punc)
2: upper_only_alpha (upper, no punc)
3: lower_only_alpha (lower, no punc)
4: upper_all (upper, with punc)
5: lower_all (lower, with punc)
When the transform of text and pre_text match, we can use the whole
pre_text as the prompt text.
Args:
texts (List[str]):
A list of ref_texts whose first item is the ground truth
text from books.
pre_texts (List[str]):
A list of pre_texts, whose first item is the groundtruth
pre_text from books.
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
Returns:
str: A dictionary
"""
assert len(texts) == len(pre_texts)
assert len(texts) == 2
# we assume the first item to be ground truth
gt_text = texts[0]
gt_pre_text = pre_texts[0]
if transforms is None:
transforms = [
lambda x: x,  # return itself
upper_only_alpha,
lower_only_alpha,
lower_all_char,
]
sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
total_transforms = len(transforms) # do not use the recognized trans
# Select a transformation randomly
i_text = np.random.choice(total_transforms, 1, p=sampling_weight)[0]
# get the normalized text and pre_text
text = transforms[i_text](gt_text)
pre_text = transforms[i_text](gt_pre_text)
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
return {
"text": train_text_normalization(text),
"pre_text": train_text_normalization(pre_text),
"style_text": train_text_normalization(style_text),
"transform_ids": i_text,
}
def triplet_style_text_sampling(
texts: List[str],
pre_texts: List[str],
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
"""This function generates a triplet of
(pre_text, style_text, ref_text). The style of style_text and ref_text
should always match, whereas the style of pre_text is fixed to mixed-trans.
Suppose we have 3 different transforms A,B,C, and the groundtruth
text and pre_text are referred to as text and pre_text.
The following three tuples are all valid:
(gt_pre_text, B(style_text), B(text))
(gt_pre_text, C(style_text), C(text))
(gt_pre_text, A(style_text), A(text))
...
If transforms is not given, the following pre-defined transforms
are available:
0: original (normal case, with punc)
1: recog (upper, no punc)
2: upper_only_alpha (upper, no punc)
3: lower_only_alpha (lower, no punc)
4: upper_all (upper, with punc)
5: lower_all (lower, with punc)
Args:
texts (List[str]):
A list of ref_texts whose first item is the ground truth
text from books.
pre_texts (List[str]):
A list of pre_texts, whose first item is the groundtruth
pre_text from books.
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
Returns:
str: A dictionary
"""
assert len(texts) == len(pre_texts)
assert len(texts) == 2
# we assume the first item to be ground truth
gt_text = texts[0]
gt_pre_text = pre_texts[0]
if transforms is None:
transforms = [
lambda x: x,  # return itself
upper_only_alpha,
lower_only_alpha,
lower_all_char,
]
sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
total_transforms = len(transforms) # do not use the recognized trans
# Select a transformation randomly
t_id = np.random.choice(total_transforms, 1, p=sampling_weight)[0]
# get the normalized text
text = transforms[t_id](gt_text)
# get the original un-processed style text
style_text = get_substring(gt_pre_text, min_len=min_len_style, max_len=150)
return {
"text": train_text_normalization(text),
"pre_text": train_text_normalization(gt_pre_text),
"style_text": train_text_normalization(style_text),
"transform_ids": t_id,
}
def naive_triplet_text_sampling(
texts: List[str],
pre_texts: List[str],
@ -753,18 +530,13 @@ def naive_triplet_text_sampling(
rare_word_list: List[str] = None,
min_len_style: Optional[int] = 120,
):
# The simplest text sampling function, used only for
# evaluation; use a fixed sentence as the style text
return {
"text": train_text_normalization(texts[0]),
"pre_text": train_text_normalization(pre_texts[0]),
#"pre_text": "HELLO IT THIS ENOUGH FOR THE MODEL TO LEARN THE STYLE",
#"pre_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
#"pre_text": "Hello, my friend. "*50,
#"style_text": train_text_normalization(pre_texts[0][:200]),
"style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related. What do you think?",
#"style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
#"style_text": "Mixed-case English transcription, with punctuation.",
# "style_text": train_text_normalization(get_substring(pre_texts[0], min_len=min_len_style)),
"transform_ids": 0,
}
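A hypothetical call, reusing the two-item convention of the sampling functions above (illustrative only):
out = naive_triplet_text_sampling(texts=["Hello there.", "HELLO THERE"], pre_texts=["Nice to see you.", "NICE TO SEE YOU"])
# out["style_text"] is always the fixed mixed-case sentence above, so
# evaluation is deterministic with respect to the style prompt.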
@ -775,11 +547,11 @@ def random_shuffle_subset(
p_mask: float = 0.05,
) -> List[str]:
"""
Randomly shuffle the subset by probability p, which means that p% of the samples
Randomly shuffle a subset of the batch with probability `p`, i.e. a proportion `p` of the samples
in the original batch are shuffled; the others are kept in their original order.
With a probability of p_mask, replace the original string with an empty string.
With a probability of `p_mask`, replace the original string with an empty string.
"""
num_to_shuffle = int(len(data) * p)
@ -789,9 +561,9 @@ def random_shuffle_subset(
for id, item in zip(id_to_shuffle, item_to_shuffle):
data[id] = item
# Randomly mask a proportion of the data to empty string
if p_mask > 0:
for i in range(len(data)):
if random.random() < p_mask:
data[i] = ""
@ -809,16 +581,6 @@ if __name__ == "__main__":
"EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg?",
"EE FF GG EE FF GG EE FF GG EE FF GG EE FF GG",
]
# for i in range(10):
# print(f"Run: {i}")
# print(triplet_text_sampling(texts, pre_texts))
import time
start = time.time()
data = [str(i) for i in range(30)]
random.shuffle(data)
print(data)
for i in range(1):
shuffled = random_shuffle_subset(data=data, p=0.4, p_mask=0.1)
print(shuffled)
print((time.time() - start)/100)
for i in range(10):
print(f"Run: {i}")
print(triplet_text_sampling(texts, pre_texts))

View File

@ -1,477 +0,0 @@
import argparse
import logging
import math
import warnings
from pathlib import Path
from typing import List
from tqdm import tqdm
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from lhotse import load_manifest, Fbank
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from text_normalization import (
ref_text_normalization,
remove_non_alphabetic,
upper_only_alpha,
upper_all_char,
lower_all_char,
lower_only_alpha,
train_text_normalization,
)
from train_subformer_with_style import (
add_model_arguments,
get_params,
get_tokenizer,
get_transducer_model,
_encode_text_as_tokens,
)
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="""Path to bpe.model.""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
)
parser.add_argument(
"--manifest-dir",
type=str,
default="data/long_audios/long_audio_pomonastravels_combined.jsonl.gz",
help="""This is the manfiest for long audio transcription.
It is intended to be sored, i.e first sort by recording ID and then sort by
start timestamp"""
)
parser.add_argument(
"--segment-length",
type=float,
default=30.0,
)
parser.add_argument(
"--use-pre-text",
type=str2bool,
default=False,
help="Whether use pre-text when decoding the current chunk"
)
parser.add_argument(
"--use-style-prompt",
type=str2bool,
default=True,
help="Use style prompt when evaluation"
)
parser.add_argument(
"--pre-text-transform",
type=str,
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
default="mixed-punc",
help="The style of content prompt, i.e pre_text"
)
parser.add_argument(
"--style-text-transform",
type=str,
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
default="mixed-punc",
help="The style of style prompt, i.e style_text"
)
parser.add_argument(
"--num-history",
type=int,
default=2,
help="How many previous chunks to look if using pre-text for decoding"
)
parser.add_argument(
"--use-gt-pre-text",
type=str2bool,
default=False,
help="Whether use gt pre text when using content prompt",
)
parser.add_argument(
"--post-normalization",
type=str2bool,
default=True,
)
add_model_arguments(parser)
return parser
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
"""Apply transform to a list of text. By default, the text are in
ground truth format, i.e mixed-punc.
Args:
text (List[str]): Input list of texts
transform (str): Transform to be applied
Returns:
List[str]: The transformed list of texts
"""
if transform == "mixed-punc":
return text
elif transform == "upper-no-punc":
return [upper_only_alpha(s) for s in text]
elif transform == "lower-no-punc":
return [lower_only_alpha(s) for s in text]
elif transform == "lower-punc":
return [lower_all_char(s) for s in text]
else:
raise NotImplementedError(f"Unseen transform: {transform}")
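For instance (the expected output assumes upper_only_alpha keeps only alphabetic characters and spaces and upper-cases them, per its name):
prompts = ["Hello, my friend.", "It's nine a.m., right?"]
print(_apply_style_transform(prompts, "upper-no-punc"))
# expected: ["HELLO MY FRIEND", "ITS NINE AM RIGHT"]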
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.res_dir = params.exp_dir / "long_audio_transcribe"
params.res_dir.mkdir(exist_ok=True)
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "beam_search" in params.method:
params.suffix += (
f"-{params.method}-beam-size-{params.beam_size}"
)
if params.use_pre_text:
if params.use_gt_pre_text:
params.suffix += f"-use-gt-pre-text-{params.pre_text_transform}-history-{params.num_history}"
else:
params.suffix += f"-pre-text-{params.pre_text_transform}-history-{params.num_history}"
book_name = params.manifest_dir.split('/')[-1].replace(".jsonl.gz", "")
setup_logger(f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
text_sp = spm.SentencePieceProcessor()
text_sp.load(params.text_encoder_bpe_model)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
model.device = device
# load manifest
manifest = load_manifest(params.manifest_dir)
results = []
count = 0
last_recording = ""
last_end = -1
history = []
num_pre_texts = []
for cut in manifest:
if cut.has_features:
feat = cut.load_features()
feat_lens = cut.num_frames
else:
feat = cut.compute_features(extractor=Fbank())
feat_lens = feat.shape[0]
cur_recording = cut.recording.id
if cur_recording != last_recording:
last_recording = cur_recording
history = [] # clean history
last_end = -1
logging.info(f"Moving on to the next recording")
else:
if cut.start < last_end - 0.2:  # overlap exists
logging.warning("An overlap exists between the current cut and the last cut")
logging.warning("Skipping this cut!")
continue
if cut.start > last_end + 10:
logging.warning(f"Large time gap between the current and previous utterance: {cut.start - last_end}.")
# prepare input
x = torch.tensor(feat, device=device).unsqueeze(0)
x_lens = torch.tensor([feat_lens,], device=device)
if params.use_pre_text:
if params.num_history > 0:
pre_texts = history[-params.num_history:]
else:
pre_texts = []
assert len(pre_texts) <= params.num_history
num_pre_texts.append(len(pre_texts))
pre_texts = [train_text_normalization(" ".join(pre_texts))]
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
style_texts = [fixed_sentence]
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
if params.use_style_prompt:
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
# encode pre_text
with warnings.catch_warnings():
warnings.simplefilter("ignore")
pre_texts, pre_texts_lens, style_text_lens = _encode_text_as_tokens(
pre_texts=pre_texts,
style_texts=style_texts,
bpe_model=text_sp,
device=device,
max_tokens=1500,
)
if params.num_history > 5:
logging.info(f"Shape of encoded texts: {pre_texts.shape} ")
memory, memory_key_padding_mask = model.encode_text(
text=pre_texts,
style_lens=style_text_lens,
text_lens=pre_texts_lens,
) # (T,B,C)
else:
memory = None
memory_key_padding_mask = None
with warnings.catch_warnings():
warnings.simplefilter("ignore")
encoder_out, encoder_out_lens = model.encode_audio(
feature=x,
feature_lens=x_lens,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
if params.method == "greedy_search":
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
hyp = sp.decode(hyp_tokens)[0] # in string format
ref_text = ref_text_normalization(cut.supervisions[0].texts[0]) # required to match the training
# extend the history, the history here is in original format
if params.use_gt_pre_text:
history.append(ref_text)
else:
history.append(hyp)
last_end = cut.end # update the last end timestamp
# append the current decoding result
hyp = hyp.split()
ref = ref_text.split()
results.append((cut.id, ref, hyp))
count += 1
if count % 100 == 0:
logging.info(f"Cuts processed until now: {count}/{len(manifest)}")
logging.info(f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}")
logging.info(f"A total of {count} cuts")
logging.info(f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}")
results = sorted(results)
recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt"
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"long-audio-{params.method}", results, enable_log=True, compute_CER=False,
)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
if params.post_normalization:
params.suffix += "-post-normalization"
new_res = []
for item in results:
id, ref, hyp = item
hyp = upper_only_alpha(" ".join(hyp)).split()
ref = upper_only_alpha(" ".join(ref)).split()
new_res.append((id, ref, hyp))
new_res = sorted(new_res)
recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
store_transcripts(filename=recog_path, texts=new_res)
logging.info(f"The transcripts are stored in {recog_path}")
errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"long-audio-{params.method}", new_res, enable_log=True, compute_CER=False,
)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
if __name__ == "__main__":
main()