mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 16:44:20 +00:00
add documentation to different text sampling function
This commit is contained in:
parent
6579800720
commit
93461fb77e
@ -1,23 +1,38 @@
|
|||||||
from typing import Callable, Dict, List, Optional, Union
|
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
|
||||||
import random
|
#
|
||||||
import numpy as np
|
# See ../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import random
|
||||||
|
from typing import Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from lhotse import validate
|
from lhotse import validate
|
||||||
from lhotse.cut import CutSet
|
from lhotse.cut import CutSet
|
||||||
from lhotse.dataset import K2SpeechRecognitionDataset
|
from lhotse.dataset import K2SpeechRecognitionDataset
|
||||||
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
|
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
|
||||||
from lhotse.utils import compute_num_frames, ifnone
|
from lhotse.utils import compute_num_frames, ifnone
|
||||||
from torch.utils.data.dataloader import DataLoader, default_collate
|
|
||||||
|
|
||||||
from text_normalization import (
|
from text_normalization import (
|
||||||
remove_non_alphabetic,
|
|
||||||
upper_only_alpha,
|
|
||||||
lower_only_alpha,
|
|
||||||
upper_all_char,
|
|
||||||
lower_all_char,
|
lower_all_char,
|
||||||
|
lower_only_alpha,
|
||||||
|
remove_non_alphabetic,
|
||||||
train_text_normalization,
|
train_text_normalization,
|
||||||
|
upper_all_char,
|
||||||
|
upper_only_alpha,
|
||||||
)
|
)
|
||||||
|
from torch.utils.data.dataloader import DataLoader, default_collate
|
||||||
|
|
||||||
|
|
||||||
class PromptASRDataset(torch.utils.data.Dataset):
|
class PromptASRDataset(torch.utils.data.Dataset):
|
||||||
@ -34,7 +49,7 @@ class PromptASRDataset(torch.utils.data.Dataset):
|
|||||||
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
|
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
|
||||||
input_strategy: BatchIO = PrecomputedFeatures(),
|
input_strategy: BatchIO = PrecomputedFeatures(),
|
||||||
text_sampling_func: Optional[Callable[[List[str]], str]] = None,
|
text_sampling_func: Optional[Callable[[List[str]], str]] = None,
|
||||||
rare_word_list: Optional[List[str]] = None
|
rare_word_list: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Icefall ASR IterableDataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py
|
Icefall ASR IterableDataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py
|
||||||
@ -63,9 +78,7 @@ class PromptASRDataset(torch.utils.data.Dataset):
|
|||||||
self.text_sampling_func = text_sampling_func
|
self.text_sampling_func = text_sampling_func
|
||||||
self.rare_word_list = rare_word_list
|
self.rare_word_list = rare_word_list
|
||||||
|
|
||||||
def __getitem__(
|
def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
|
||||||
self, cuts: CutSet
|
|
||||||
) -> Dict[str, Union[torch.Tensor, List[str]]]:
|
|
||||||
"""
|
"""
|
||||||
Return a new batch, with the batch size automatically determined using the constraints
|
Return a new batch, with the batch size automatically determined using the constraints
|
||||||
of max_frames and max_cuts.
|
of max_frames and max_cuts.
|
||||||
@ -112,15 +125,15 @@ class PromptASRDataset(torch.utils.data.Dataset):
|
|||||||
self.text_sampling_func(
|
self.text_sampling_func(
|
||||||
texts=supervision.texts,
|
texts=supervision.texts,
|
||||||
pre_texts=supervision.pre_texts,
|
pre_texts=supervision.pre_texts,
|
||||||
context_list=supervision.context_list if "context_list" in supervision.custom else None,
|
context_list=supervision.context_list
|
||||||
|
if "context_list" in supervision.custom
|
||||||
|
else None,
|
||||||
rare_word_list=self.rare_word_list,
|
rare_word_list=self.rare_word_list,
|
||||||
)
|
)
|
||||||
if self.text_sampling_func is not None
|
if self.text_sampling_func is not None
|
||||||
else {
|
else {
|
||||||
"text": train_text_normalization(supervision.texts[0]),
|
"text": train_text_normalization(supervision.texts[0]),
|
||||||
"pre_text": train_text_normalization(
|
"pre_text": train_text_normalization(supervision.pre_texts[0]),
|
||||||
supervision.pre_texts[0]
|
|
||||||
),
|
|
||||||
"style_text": train_text_normalization(
|
"style_text": train_text_normalization(
|
||||||
supervision.pre_texts[0]
|
supervision.pre_texts[0]
|
||||||
),
|
),
|
||||||
@ -192,27 +205,22 @@ def triplet_text_sampling(
|
|||||||
rare_word_list: Optional[List[str]] = None,
|
rare_word_list: Optional[List[str]] = None,
|
||||||
transforms: Optional[List[Callable[[str], str]]] = None,
|
transforms: Optional[List[Callable[[str], str]]] = None,
|
||||||
min_len_style: Optional[int] = 80,
|
min_len_style: Optional[int] = 80,
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, str, str]:
|
||||||
"""This function generates a triplet of
|
"""This function generates a triplet of
|
||||||
(pre_text, style_text, ref_text). The style of style_text and ref_text
|
(pre_text, style_text, ref_text). The style of style_text and ref_text
|
||||||
should always match, whereas the style of pre_text is arbitrary.
|
should **always** match, whereas the style of pre_text is arbitrary.
|
||||||
Suppose we have 3 different transforms A,B,C, and the groundtruth
|
Suppose we have 2 different transforms A,B, and the preceding text is
|
||||||
text and pre_text are referred to as text and pre_text.
|
referred to as pre_text. The following three tuples are all valid:
|
||||||
The following three tuples are all valid:
|
|
||||||
|
|
||||||
(A(pre_text), B(style_text), B(text))
|
(A(pre_text), A(style_text), A(ref_text))
|
||||||
(A(pre_text), C(style_text), C(text))
|
(A(pre_text), B(style_text), B(ref_text))
|
||||||
(A(pre_text), A(style_text), A(text))
|
(A(pre_text), A(style_text), A(ref_text))
|
||||||
...
|
(B(pre_text), B(style_text), B(ref_text))
|
||||||
|
|
||||||
If transforms is not given, the following pre-defined transforms
|
If transforms is not given, the following pre-defined transforms
|
||||||
are available:
|
are available:
|
||||||
0: original (normal case, with punc)
|
0: original (mixed-cased, with punc)
|
||||||
1: recog (upper, no punc)
|
1: upper_only_alpha (upper-cased, no punc)
|
||||||
2: upper_only_alpha (upper, no punc)
|
|
||||||
3: lower_only_alpha (lower, no punc)
|
|
||||||
4: upper_all (upper, with punc)
|
|
||||||
5: lower_all (lower, with punc)
|
|
||||||
|
|
||||||
When the transform of text and pre_text match, we can use the whole
|
When the transform of text and pre_text match, we can use the whole
|
||||||
pre_text as the prompt text.
|
pre_text as the prompt text.
|
||||||
@ -224,12 +232,15 @@ def triplet_text_sampling(
|
|||||||
pre_texts (List[str]):
|
pre_texts (List[str]):
|
||||||
A list of pre_texts, whose first item is the groundtruth
|
A list of pre_texts, whose first item is the groundtruth
|
||||||
pre_text from books.
|
pre_text from books.
|
||||||
|
context_list: Optional[str] = None,
|
||||||
|
A list of biasing words separated by space
|
||||||
|
rare_word_list: Optional[str] = None,
|
||||||
|
A list of rare-words separated by space (used as distractors)
|
||||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: A dictionary
|
A dictionary of ref_text, pre_text, style_text
|
||||||
"""
|
"""
|
||||||
# import pdb; pdb.set_trace()
|
|
||||||
assert len(texts) == len(pre_texts)
|
assert len(texts) == len(pre_texts)
|
||||||
assert len(texts) == 2
|
assert len(texts) == 2
|
||||||
|
|
||||||
@ -244,13 +255,17 @@ def triplet_text_sampling(
|
|||||||
lower_only_alpha,
|
lower_only_alpha,
|
||||||
lower_all_char,
|
lower_all_char,
|
||||||
]
|
]
|
||||||
|
|
||||||
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
|
sampling_weight = [
|
||||||
sampling_weight = [0.7, 0.3, 0.0, 0.0]
|
0.7,
|
||||||
|
0.3,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
] # Mixed-punc should have the largest sampling prob
|
||||||
|
|
||||||
total_transforms = len(transforms) # do not use the recognized trans
|
total_transforms = len(transforms) # do not use the recognized trans
|
||||||
|
|
||||||
# Select a transformation randomly
|
# Randomly sample transforms
|
||||||
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
|
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
|
||||||
|
|
||||||
# get the normalized text and pre_text
|
# get the normalized text and pre_text
|
||||||
@ -261,97 +276,7 @@ def triplet_text_sampling(
|
|||||||
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
|
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
|
||||||
else:
|
else:
|
||||||
# get the pre_text of same style as text
|
# get the pre_text of same style as text
|
||||||
# For now, do not do transform to the style text
|
# For now, **don't** do transform to the style text, because we do it after the dataloader
|
||||||
style_text = gt_pre_text
|
|
||||||
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
|
|
||||||
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"text": train_text_normalization(text),
|
|
||||||
"pre_text": train_text_normalization(pre_text),
|
|
||||||
"style_text": train_text_normalization(style_text),
|
|
||||||
"transform_ids": i_text,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def triplet_text_sampling2(
|
|
||||||
texts: List[str],
|
|
||||||
pre_texts: List[str],
|
|
||||||
context_list: Optional[str] = None,
|
|
||||||
rare_word_list: Optional[List[str]] = None,
|
|
||||||
transforms: Optional[List[Callable[[str], str]]] = None,
|
|
||||||
min_len_style: Optional[int] = 80,
|
|
||||||
) -> Dict[str, str]:
|
|
||||||
"""This function generates a triplet of
|
|
||||||
(pre_text, style_text, ref_text). The style of style_text and ref_text
|
|
||||||
should always match, whereas the style of pre_text is arbitrary.
|
|
||||||
Suppose we have 3 different transforms A,B,C, and the groundtruth
|
|
||||||
text and pre_text are referred to as text and pre_text.
|
|
||||||
The following three tuples are all valid:
|
|
||||||
|
|
||||||
(A(pre_text), B(style_text), B(text))
|
|
||||||
(A(pre_text), C(style_text), C(text))
|
|
||||||
(A(pre_text), A(style_text), A(text))
|
|
||||||
...
|
|
||||||
|
|
||||||
If transforms is not given, the following pre-defined transforms
|
|
||||||
are available:
|
|
||||||
0: original (normal case, with punc)
|
|
||||||
1: recog (upper, no punc)
|
|
||||||
2: upper_only_alpha (upper, no punc)
|
|
||||||
3: lower_only_alpha (lower, no punc)
|
|
||||||
4: upper_all (upper, with punc)
|
|
||||||
5: lower_all (lower, with punc)
|
|
||||||
|
|
||||||
When the transform of text and pre_text match, we can use the whole
|
|
||||||
pre_text as the prompt text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts (List[str]):
|
|
||||||
A list of ref_texts whose first item is the ground truth
|
|
||||||
text from books.
|
|
||||||
pre_texts (List[str]):
|
|
||||||
A list of pre_texts, whose first item is the groundtruth
|
|
||||||
pre_text from books.
|
|
||||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: A dictionary
|
|
||||||
"""
|
|
||||||
# import pdb; pdb.set_trace()
|
|
||||||
assert len(texts) == len(pre_texts)
|
|
||||||
assert len(texts) == 2
|
|
||||||
|
|
||||||
# we assume the first item to be ground truth
|
|
||||||
gt_text = texts[0]
|
|
||||||
gt_pre_text = pre_texts[0]
|
|
||||||
|
|
||||||
if transforms is None:
|
|
||||||
transforms = [
|
|
||||||
lambda x: x, # return it self
|
|
||||||
upper_only_alpha,
|
|
||||||
lower_only_alpha,
|
|
||||||
lower_all_char,
|
|
||||||
]
|
|
||||||
|
|
||||||
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
|
|
||||||
sampling_weight = [0.7, 0.3, 0.0, 0.0]
|
|
||||||
|
|
||||||
total_transforms = len(transforms) # do not use the recognized trans
|
|
||||||
|
|
||||||
# Select a transformation randomly
|
|
||||||
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
|
|
||||||
|
|
||||||
# get the normalized text and pre_text
|
|
||||||
text = transforms[i_text](gt_text)
|
|
||||||
pre_text = transforms[i_pre_text](gt_pre_text)
|
|
||||||
|
|
||||||
if i_text == i_pre_text:
|
|
||||||
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
|
|
||||||
else:
|
|
||||||
# get the pre_text of same style as text
|
|
||||||
# For now, do not do transform to the style text
|
|
||||||
style_text = gt_pre_text
|
style_text = gt_pre_text
|
||||||
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
|
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
|
||||||
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
|
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
|
||||||
@ -373,25 +298,22 @@ def triplet_text_sampling_with_context_list(
|
|||||||
min_len_style: Optional[int] = 80,
|
min_len_style: Optional[int] = 80,
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, str]:
|
||||||
"""This function generates a triplet of
|
"""This function generates a triplet of
|
||||||
(pre_text, style_text, ref_text). The style of style_text and ref_text
|
(pre_text, style_text, ref_text). The pre_text is either the preceding text
|
||||||
should always match, whereas the style of pre_text is arbitrary.
|
or a list of words (context words + distractors).
|
||||||
Suppose we have 3 different transforms A,B,C, and the groundtruth
|
The style of style_text and ref_text should **always** match, whereas
|
||||||
text and pre_text are referred to as text and pre_text.
|
the style of pre_text is arbitrary.
|
||||||
The following three tuples are all valid:
|
Suppose we have 2 different transforms A,B, and the preceding text is
|
||||||
|
referred to as pre_text. The following three tuples are all valid:
|
||||||
|
|
||||||
(A(pre_text), B(style_text), B(text))
|
(A(pre_text), A(style_text), A(ref_text))
|
||||||
(A(pre_text), C(style_text), C(text))
|
(A(pre_text), B(style_text), B(ref_text))
|
||||||
(A(pre_text), A(style_text), A(text))
|
(A(pre_text), A(style_text), A(ref_text))
|
||||||
...
|
(B(pre_text), B(style_text), B(ref_text))
|
||||||
|
|
||||||
If transforms is not given, the following pre-defined transforms
|
If transforms is not given, the following pre-defined transforms
|
||||||
are available:
|
are available:
|
||||||
0: original (normal case, with punc)
|
0: original (mixed-cased, with punc)
|
||||||
1: recog (upper, no punc)
|
1: upper_only_alpha (upper-cased, no punc)
|
||||||
2: upper_only_alpha (upper, no punc)
|
|
||||||
3: lower_only_alpha (lower, no punc)
|
|
||||||
4: upper_all (upper, with punc)
|
|
||||||
5: lower_all (lower, with punc)
|
|
||||||
|
|
||||||
When the transform of text and pre_text match, we can use the whole
|
When the transform of text and pre_text match, we can use the whole
|
||||||
pre_text as the prompt text.
|
pre_text as the prompt text.
|
||||||
@ -403,15 +325,21 @@ def triplet_text_sampling_with_context_list(
|
|||||||
pre_texts (List[str]):
|
pre_texts (List[str]):
|
||||||
A list of pre_texts, whose first item is the groundtruth
|
A list of pre_texts, whose first item is the groundtruth
|
||||||
pre_text from books.
|
pre_text from books.
|
||||||
|
context_list: Optional[str] = None,
|
||||||
|
A list of biasing words separated by space
|
||||||
|
rare_word_list: Optional[str] = None,
|
||||||
|
A list of rare-words separated by space (used as distractors)
|
||||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of ref_text, pre_text, style_text
|
||||||
Returns:
|
Returns:
|
||||||
str: A dictionary
|
str: A dictionary
|
||||||
"""
|
"""
|
||||||
# import pdb; pdb.set_trace()
|
# import pdb; pdb.set_trace()
|
||||||
assert len(texts) == len(pre_texts)
|
assert len(texts) == len(pre_texts)
|
||||||
assert len(texts) == 2
|
assert len(texts) == 2
|
||||||
|
|
||||||
if context_list is not None:
|
if context_list is not None:
|
||||||
context_list = context_list.lower()
|
context_list = context_list.lower()
|
||||||
|
|
||||||
@ -426,9 +354,13 @@ def triplet_text_sampling_with_context_list(
|
|||||||
lower_only_alpha,
|
lower_only_alpha,
|
||||||
lower_all_char,
|
lower_all_char,
|
||||||
]
|
]
|
||||||
|
|
||||||
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
|
sampling_weight = [
|
||||||
sampling_weight = [0.7, 0.3, 0.0, 0.0]
|
0.7,
|
||||||
|
0.3,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
] # Mixed-punc should have the largest sampling prob
|
||||||
|
|
||||||
total_transforms = len(transforms) # do not use the recognized trans
|
total_transforms = len(transforms) # do not use the recognized trans
|
||||||
|
|
||||||
@ -446,11 +378,10 @@ def triplet_text_sampling_with_context_list(
|
|||||||
pre_text = transforms[i_pre_text](pre_text)
|
pre_text = transforms[i_pre_text](pre_text)
|
||||||
|
|
||||||
if i_text == i_pre_text:
|
if i_text == i_pre_text:
|
||||||
style_text = gt_pre_text
|
|
||||||
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
|
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
|
||||||
else:
|
else:
|
||||||
# get the pre_text of same style as text
|
# get the pre_text of same style as text
|
||||||
# For now, do not do transform to the style text
|
# For now, **don't** do transform to the style text
|
||||||
style_text = gt_pre_text
|
style_text = gt_pre_text
|
||||||
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
|
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
|
||||||
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
|
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
|
||||||
@ -471,7 +402,7 @@ def get_pre_text_with_context_list(
|
|||||||
) -> str:
|
) -> str:
|
||||||
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
|
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
|
||||||
# By a small proportion of time, use the substring of ref_text as pre_text
|
# By a small proportion of time, use the substring of ref_text as pre_text
|
||||||
|
|
||||||
if context_list != "" and context_list is not None:
|
if context_list != "" and context_list is not None:
|
||||||
v = random.random()
|
v = random.random()
|
||||||
if v < 0.5:
|
if v < 0.5:
|
||||||
@ -489,14 +420,14 @@ def get_pre_text_with_context_list(
|
|||||||
pre_text = " ".join(pre_text)
|
pre_text = " ".join(pre_text)
|
||||||
elif v < 0.7:
|
elif v < 0.7:
|
||||||
splitted = text.split()
|
splitted = text.split()
|
||||||
sampling_weights = [len(w)**1.2 for w in splitted]
|
sampling_weights = [len(w) ** 1.2 for w in splitted]
|
||||||
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
|
sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
|
||||||
i = random.randint(1, min(len(splitted), 20))
|
i = random.randint(1, min(len(splitted), 20))
|
||||||
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
||||||
num_distractors = random.randint(0,70)
|
num_distractors = random.randint(0, 70)
|
||||||
distractors = random.sample(rare_words_list, num_distractors)
|
distractors = random.sample(rare_words_list, num_distractors)
|
||||||
splitted += distractors
|
splitted += distractors
|
||||||
random.shuffle(splitted) # shuffle the list
|
random.shuffle(splitted) # shuffle the list
|
||||||
pre_text = " ".join(splitted)
|
pre_text = " ".join(splitted)
|
||||||
else:
|
else:
|
||||||
pre_text = pre_text
|
pre_text = pre_text
|
||||||
@ -504,21 +435,21 @@ def get_pre_text_with_context_list(
|
|||||||
v = random.random()
|
v = random.random()
|
||||||
if v < 0.1:
|
if v < 0.1:
|
||||||
splitted = text.split()
|
splitted = text.split()
|
||||||
sampling_weights = [len(w)**1.2 for w in splitted]
|
sampling_weights = [len(w) ** 1.2 for w in splitted]
|
||||||
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
|
sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
|
||||||
i = random.randint(1, min(len(splitted), 20))
|
i = random.randint(1, min(len(splitted), 20))
|
||||||
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
||||||
pre_text = " ".join(splitted)
|
pre_text = " ".join(splitted)
|
||||||
num_distractors = random.randint(0,70)
|
num_distractors = random.randint(0, 70)
|
||||||
distractors = random.sample(rare_words_list, num_distractors)
|
distractors = random.sample(rare_words_list, num_distractors)
|
||||||
splitted += distractors
|
splitted += distractors
|
||||||
random.shuffle(splitted) # shuffle the list
|
random.shuffle(splitted) # shuffle the list
|
||||||
elif v < 0.2:
|
elif v < 0.2:
|
||||||
# full distractors
|
# full distractors
|
||||||
num_distractors = random.randint(5, 100)
|
num_distractors = random.randint(5, 100)
|
||||||
distractors = random.sample(rare_words_list, num_distractors)
|
distractors = random.sample(rare_words_list, num_distractors)
|
||||||
pre_text = " ".join(distractors)
|
pre_text = " ".join(distractors)
|
||||||
|
|
||||||
elif v < 0.3:
|
elif v < 0.3:
|
||||||
pre_text = get_substring(text, min_len=15, max_len=150)
|
pre_text = get_substring(text, min_len=15, max_len=150)
|
||||||
else:
|
else:
|
||||||
@ -527,20 +458,19 @@ def get_pre_text_with_context_list(
|
|||||||
return pre_text
|
return pre_text
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_pre_text_with_context_list2(
|
def get_pre_text_with_context_list2(
|
||||||
text: str,
|
text: str,
|
||||||
pre_text: str,
|
pre_text: str,
|
||||||
context_list: str,
|
context_list: str,
|
||||||
rare_words_list: List[str] = None,
|
rare_words_list: List[str] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
|
# Get the pre_text, either the ground truth preceding text or
|
||||||
|
# a list of words consisting of biasing words and distrators
|
||||||
# By a small proportion of time, use the substring of ref_text as pre_text
|
# By a small proportion of time, use the substring of ref_text as pre_text
|
||||||
|
|
||||||
if context_list != "" and context_list is not None:
|
if context_list != "" and context_list is not None:
|
||||||
v = random.random()
|
v = random.random()
|
||||||
if v < 0.4:
|
if v < 0.4:
|
||||||
# correct + distractors
|
|
||||||
# sample distractors
|
# sample distractors
|
||||||
num_distractors = random.randint(50, 100)
|
num_distractors = random.randint(50, 100)
|
||||||
distractors = random.sample(rare_words_list, num_distractors)
|
distractors = random.sample(rare_words_list, num_distractors)
|
||||||
@ -554,14 +484,16 @@ def get_pre_text_with_context_list2(
|
|||||||
pre_text = " ".join(pre_text)
|
pre_text = " ".join(pre_text)
|
||||||
elif v < 0.55:
|
elif v < 0.55:
|
||||||
splitted = text.split()
|
splitted = text.split()
|
||||||
sampling_weights = [len(w)**1.2 for w in splitted]
|
sampling_weights = [
|
||||||
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
|
len(w) ** 1.2 for w in splitted
|
||||||
|
] # longer words with higher weights
|
||||||
|
sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
|
||||||
i = random.randint(1, min(len(splitted), 20))
|
i = random.randint(1, min(len(splitted), 20))
|
||||||
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
||||||
num_distractors = random.randint(50,100)
|
num_distractors = random.randint(50, 100)
|
||||||
distractors = random.sample(rare_words_list, num_distractors)
|
distractors = random.sample(rare_words_list, num_distractors)
|
||||||
splitted += distractors
|
splitted += distractors
|
||||||
random.shuffle(splitted) # shuffle the list
|
random.shuffle(splitted) # shuffle the list
|
||||||
pre_text = " ".join(splitted)
|
pre_text = " ".join(splitted)
|
||||||
else:
|
else:
|
||||||
pre_text = pre_text
|
pre_text = pre_text
|
||||||
@ -569,21 +501,20 @@ def get_pre_text_with_context_list2(
|
|||||||
v = random.random()
|
v = random.random()
|
||||||
if v < 0.3:
|
if v < 0.3:
|
||||||
splitted = text.split()
|
splitted = text.split()
|
||||||
sampling_weights = [len(w)**1.2 for w in splitted]
|
sampling_weights = [len(w) ** 1.2 for w in splitted]
|
||||||
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
|
sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
|
||||||
i = random.randint(1, min(len(splitted), 20))
|
i = random.randint(1, min(len(splitted), 20))
|
||||||
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
||||||
pre_text = " ".join(splitted)
|
pre_text = " ".join(splitted)
|
||||||
num_distractors = random.randint(50,100)
|
num_distractors = random.randint(50, 100)
|
||||||
distractors = random.sample(rare_words_list, num_distractors)
|
distractors = random.sample(rare_words_list, num_distractors)
|
||||||
splitted += distractors
|
splitted += distractors
|
||||||
random.shuffle(splitted) # shuffle the list
|
random.shuffle(splitted) # shuffle the list
|
||||||
elif v < 0.4:
|
elif v < 0.4:
|
||||||
# full distractors
|
# full distractors
|
||||||
num_distractors = random.randint(5, 100)
|
num_distractors = random.randint(5, 100)
|
||||||
distractors = random.sample(rare_words_list, num_distractors)
|
distractors = random.sample(rare_words_list, num_distractors)
|
||||||
pre_text = " ".join(distractors)
|
pre_text = " ".join(distractors)
|
||||||
|
|
||||||
elif v < 0.6:
|
elif v < 0.6:
|
||||||
pre_text = get_substring(text, min_len=15, max_len=150)
|
pre_text = get_substring(text, min_len=15, max_len=150)
|
||||||
else:
|
else:
|
||||||
@ -592,160 +523,6 @@ def get_pre_text_with_context_list2(
|
|||||||
return pre_text
|
return pre_text
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def joint_triplet_text_sampling(
|
|
||||||
texts: List[str],
|
|
||||||
pre_texts: List[str],
|
|
||||||
transforms: Optional[List[Callable[[str], str]]] = None,
|
|
||||||
min_len_style: Optional[int] = 80,
|
|
||||||
) -> Dict[str, str]:
|
|
||||||
"""This function generates a triplet of
|
|
||||||
(pre_text, style_text, ref_text). The style of pre_text, style_text
|
|
||||||
and ref_text should **always** match.
|
|
||||||
Suppose we have 3 different transforms A,B,C, and the groundtruth
|
|
||||||
text and pre_text are referred to as text and pre_text.
|
|
||||||
The following three tuples are all valid:
|
|
||||||
|
|
||||||
(A(pre_text), A(style_text), A(text))
|
|
||||||
(B(pre_text), B(style_text), B(text))
|
|
||||||
(C(pre_text), C(style_text), C(text))
|
|
||||||
...
|
|
||||||
|
|
||||||
If transforms is not given, the following pre-defined transforms
|
|
||||||
are available:
|
|
||||||
0: original (normal case, with punc)
|
|
||||||
1: recog (upper, no punc)
|
|
||||||
2: upper_only_alpha (upper, no punc)
|
|
||||||
3: lower_only_alpha (lower, no punc)
|
|
||||||
4: upper_all (upper, with punc)
|
|
||||||
5: lower_all (lower, with punc)
|
|
||||||
|
|
||||||
When the transform of text and pre_text match, we can use the whole
|
|
||||||
pre_text as the prompt text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts (List[str]):
|
|
||||||
A list of ref_texts whose first item is the ground truth
|
|
||||||
text from books.
|
|
||||||
pre_texts (List[str]):
|
|
||||||
A list of pre_texts, whose first item is the groundtruth
|
|
||||||
pre_text from books.
|
|
||||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: A dictionary
|
|
||||||
"""
|
|
||||||
# import pdb; pdb.set_trace()
|
|
||||||
assert len(texts) == len(pre_texts)
|
|
||||||
assert len(texts) == 2
|
|
||||||
|
|
||||||
# we assume the first item to be ground truth
|
|
||||||
gt_text = texts[0]
|
|
||||||
gt_pre_text = pre_texts[0]
|
|
||||||
|
|
||||||
if transforms is None:
|
|
||||||
transforms = [
|
|
||||||
lambda x: x, # return it self
|
|
||||||
upper_only_alpha,
|
|
||||||
lower_only_alpha,
|
|
||||||
lower_all_char,
|
|
||||||
]
|
|
||||||
|
|
||||||
sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
|
|
||||||
|
|
||||||
total_transforms = len(transforms) # do not use the recognized trans
|
|
||||||
|
|
||||||
# Select a transformation randomly
|
|
||||||
i_text = np.random.choice(total_transforms, 1, p=sampling_weight)[0]
|
|
||||||
|
|
||||||
# get the normalized text and pre_text
|
|
||||||
text = transforms[i_text](gt_text)
|
|
||||||
pre_text = transforms[i_text](gt_pre_text)
|
|
||||||
|
|
||||||
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"text": train_text_normalization(text),
|
|
||||||
"pre_text": train_text_normalization(pre_text),
|
|
||||||
"style_text": train_text_normalization(style_text),
|
|
||||||
"transform_ids": i_text,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def triplet_style_text_sampling(
|
|
||||||
texts: List[str],
|
|
||||||
pre_texts: List[str],
|
|
||||||
transforms: Optional[List[Callable[[str], str]]] = None,
|
|
||||||
min_len_style: Optional[int] = 80,
|
|
||||||
) -> Dict[str, str]:
|
|
||||||
"""This function generates a triplet of
|
|
||||||
(pre_text, style_text, ref_text). The style of style_text and ref_text
|
|
||||||
should always match, whereas the style of pre_text is fixed to mixed-trans.
|
|
||||||
Suppose we have 3 different transforms A,B,C, and the groundtruth
|
|
||||||
text and pre_text are referred to as text and pre_text.
|
|
||||||
The following three tuples are all valid:
|
|
||||||
|
|
||||||
(gt_pre_text, B(style_text), B(text))
|
|
||||||
(gt_pre_text, C(style_text), C(text))
|
|
||||||
(gt_pre_text, A(style_text), A(text))
|
|
||||||
...
|
|
||||||
|
|
||||||
If transforms is not given, the following pre-defined transforms
|
|
||||||
are available:
|
|
||||||
0: original (normal case, with punc)
|
|
||||||
1: recog (upper, no punc)
|
|
||||||
2: upper_only_alpha (upper, no punc)
|
|
||||||
3: lower_only_alpha (lower, no punc)
|
|
||||||
4: upper_all (upper, with punc)
|
|
||||||
5: lower_all (lower, with punc)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts (List[str]):
|
|
||||||
A list of ref_texts whose first item is the ground truth
|
|
||||||
text from books.
|
|
||||||
pre_texts (List[str]):
|
|
||||||
A list of pre_texts, whose first item is the groundtruth
|
|
||||||
pre_text from books.
|
|
||||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: A dictionary
|
|
||||||
"""
|
|
||||||
# import pdb; pdb.set_trace()
|
|
||||||
assert len(texts) == len(pre_texts)
|
|
||||||
assert len(texts) == 2
|
|
||||||
|
|
||||||
# we assume the first item to be ground truth
|
|
||||||
gt_text = texts[0]
|
|
||||||
gt_pre_text = pre_texts[0]
|
|
||||||
|
|
||||||
if transforms is None:
|
|
||||||
transforms = [
|
|
||||||
lambda x: x, # return it self
|
|
||||||
upper_only_alpha,
|
|
||||||
lower_only_alpha,
|
|
||||||
lower_all_char,
|
|
||||||
]
|
|
||||||
|
|
||||||
sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
|
|
||||||
total_transforms = len(transforms) # do not use the recognized trans
|
|
||||||
|
|
||||||
# Select a transformation randomly
|
|
||||||
t_id = np.random.choice(total_transforms, 1, p=sampling_weight)[0]
|
|
||||||
|
|
||||||
# get the normalized text
|
|
||||||
text = transforms[t_id](gt_text)
|
|
||||||
# get the original un-processed style text
|
|
||||||
style_text = get_substring(gt_pre_text, min_len=min_len_style, max_len=150)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"text": train_text_normalization(text),
|
|
||||||
"pre_text": train_text_normalization(gt_pre_text),
|
|
||||||
"style_text": train_text_normalization(style_text),
|
|
||||||
"transform_ids": t_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def naive_triplet_text_sampling(
|
def naive_triplet_text_sampling(
|
||||||
texts: List[str],
|
texts: List[str],
|
||||||
pre_texts: List[str],
|
pre_texts: List[str],
|
||||||
@ -753,18 +530,13 @@ def naive_triplet_text_sampling(
|
|||||||
rare_word_list: List[str] = None,
|
rare_word_list: List[str] = None,
|
||||||
min_len_style: Optional[int] = 120,
|
min_len_style: Optional[int] = 120,
|
||||||
):
|
):
|
||||||
|
# The most simplest text sampling function, used only for
|
||||||
|
# evaluation, use a fixed sentence as the style text
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"text": train_text_normalization(texts[0]),
|
"text": train_text_normalization(texts[0]),
|
||||||
"pre_text": train_text_normalization(pre_texts[0]),
|
"pre_text": train_text_normalization(pre_texts[0]),
|
||||||
#"pre_text": "HELLO IT THIS ENOUGH FOR THE MODEL TO LEARN THE STYLE",
|
|
||||||
#"pre_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
|
|
||||||
#"pre_text": "Hello, my friend. "*50,
|
|
||||||
#"style_text": train_text_normalization(pre_texts[0][:200]),
|
|
||||||
"style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related. What do you think?",
|
"style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related. What do you think?",
|
||||||
#"style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
|
|
||||||
#"style_text": "Mixed-case English transcription, with punctuation.",
|
|
||||||
# "style_text": train_text_normalization(get_substring(pre_texts[0], min_len=min_len_style)),
|
|
||||||
"transform_ids": 0,
|
"transform_ids": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -775,11 +547,11 @@ def random_shuffle_subset(
|
|||||||
p_mask: float = 0.05,
|
p_mask: float = 0.05,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Randomly shuffle the subset by probability p, which means that p% of the samples
|
Randomly shuffle the subset by probability `p`, which means that p% of the samples
|
||||||
in the original batch are shuffled, the others are kept in the original order.
|
in the original batch are shuffled, the others are kept in the original order.
|
||||||
|
|
||||||
With a probability of p_mask, replace the original string with an empty string.
|
With a probability of `p_mask`, replace the original string with an empty string.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
num_to_shuffle = int(len(data) * p)
|
num_to_shuffle = int(len(data) * p)
|
||||||
@ -789,9 +561,9 @@ def random_shuffle_subset(
|
|||||||
|
|
||||||
for id, item in zip(id_to_shuffle, item_to_shuffle):
|
for id, item in zip(id_to_shuffle, item_to_shuffle):
|
||||||
data[id] = item
|
data[id] = item
|
||||||
|
|
||||||
# Randomly mask a proportion of the data to empty string
|
# Randomly mask a proportion of the data to empty string
|
||||||
if p_mask > 0:
|
if p_mask > 0:
|
||||||
for i in range(len(data)):
|
for i in range(len(data)):
|
||||||
if random.random() < p_mask:
|
if random.random() < p_mask:
|
||||||
data[i] = ""
|
data[i] = ""
|
||||||
@ -809,16 +581,6 @@ if __name__ == "__main__":
|
|||||||
"EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg?",
|
"EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg?",
|
||||||
"EE FF GG EE FF GG EE FF GG EE FF GG EE FF GG",
|
"EE FF GG EE FF GG EE FF GG EE FF GG EE FF GG",
|
||||||
]
|
]
|
||||||
# for i in range(10):
|
for i in range(10):
|
||||||
# print(f"Run: {i}")
|
print(f"Run: {i}")
|
||||||
# print(triplet_text_sampling(texts, pre_texts))
|
print(triplet_text_sampling(texts, pre_texts))
|
||||||
|
|
||||||
import time
|
|
||||||
start = time.time()
|
|
||||||
data = [str(i) for i in range(30)]
|
|
||||||
random.shuffle(data)
|
|
||||||
print(data)
|
|
||||||
for i in range(1):
|
|
||||||
shuffled = random_shuffle_subset(data=data, p=0.4, p_mask=0.1)
|
|
||||||
print(shuffled)
|
|
||||||
print((time.time() - start)/100)
|
|
||||||
|
@ -1,477 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
import warnings
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
import k2
|
|
||||||
import kaldifeat
|
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
from lhotse import load_manifest, Fbank
|
|
||||||
|
|
||||||
from beam_search import (
|
|
||||||
beam_search,
|
|
||||||
fast_beam_search_one_best,
|
|
||||||
greedy_search,
|
|
||||||
greedy_search_batch,
|
|
||||||
modified_beam_search,
|
|
||||||
)
|
|
||||||
from text_normalization import (
|
|
||||||
ref_text_normalization,
|
|
||||||
remove_non_alphabetic,
|
|
||||||
upper_only_alpha,
|
|
||||||
upper_all_char,
|
|
||||||
lower_all_char,
|
|
||||||
lower_only_alpha,
|
|
||||||
train_text_normalization,
|
|
||||||
)
|
|
||||||
from train_subformer_with_style import (
|
|
||||||
add_model_arguments,
|
|
||||||
get_params,
|
|
||||||
get_tokenizer,
|
|
||||||
get_transducer_model,
|
|
||||||
_encode_text_as_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
|
||||||
average_checkpoints,
|
|
||||||
average_checkpoints_with_averaged_model,
|
|
||||||
find_checkpoints,
|
|
||||||
load_checkpoint,
|
|
||||||
)
|
|
||||||
from icefall.utils import (
|
|
||||||
AttributeDict,
|
|
||||||
setup_logger,
|
|
||||||
store_transcripts,
|
|
||||||
str2bool,
|
|
||||||
write_error_stats,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_parser():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--epoch",
|
|
||||||
type=int,
|
|
||||||
default=30,
|
|
||||||
help="""It specifies the checkpoint to use for decoding.
|
|
||||||
Note: Epoch counts from 1.
|
|
||||||
You can specify --avg to use more checkpoints for model averaging.""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--iter",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="""If positive, --epoch is ignored and it
|
|
||||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
|
||||||
You can specify --avg to use more checkpoints for model averaging.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--avg",
|
|
||||||
type=int,
|
|
||||||
default=9,
|
|
||||||
help="Number of checkpoints to average. Automatically select "
|
|
||||||
"consecutive checkpoints before the checkpoint specified by "
|
|
||||||
"'--epoch' and '--iter'",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--exp-dir",
|
|
||||||
type=str,
|
|
||||||
default="pruned_transducer_stateless7/exp",
|
|
||||||
help="The experiment dir",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--bpe-model",
|
|
||||||
type=str,
|
|
||||||
default="data/lang_bpe_500/bpe.model",
|
|
||||||
help="""Path to bpe.model.""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--method",
|
|
||||||
type=str,
|
|
||||||
default="greedy_search",
|
|
||||||
help="""Possible values are:
|
|
||||||
- greedy_search
|
|
||||||
- beam_search
|
|
||||||
- modified_beam_search
|
|
||||||
- fast_beam_search
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--beam-size",
|
|
||||||
type=int,
|
|
||||||
default=4,
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--manifest-dir",
|
|
||||||
type=str,
|
|
||||||
default="data/long_audios/long_audio_pomonastravels_combined.jsonl.gz",
|
|
||||||
help="""This is the manfiest for long audio transcription.
|
|
||||||
It is intended to be sored, i.e first sort by recording ID and then sort by
|
|
||||||
start timestamp"""
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--segment-length",
|
|
||||||
type=float,
|
|
||||||
default=30.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--use-pre-text",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="Whether use pre-text when decoding the current chunk"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--use-style-prompt",
|
|
||||||
type=str2bool,
|
|
||||||
default=True,
|
|
||||||
help="Use style prompt when evaluation"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--pre-text-transform",
|
|
||||||
type=str,
|
|
||||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
|
||||||
default="mixed-punc",
|
|
||||||
help="The style of content prompt, i.e pre_text"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--style-text-transform",
|
|
||||||
type=str,
|
|
||||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
|
||||||
default="mixed-punc",
|
|
||||||
help="The style of style prompt, i.e style_text"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--num-history",
|
|
||||||
type=int,
|
|
||||||
default=2,
|
|
||||||
help="How many previous chunks to look if using pre-text for decoding"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--use-gt-pre-text",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="Whether use gt pre text when using content prompt",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--post-normalization",
|
|
||||||
type=str2bool,
|
|
||||||
default=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
add_model_arguments(parser)
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
|
||||||
"""Apply transform to a list of text. By default, the text are in
|
|
||||||
ground truth format, i.e mixed-punc.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text (List[str]): Input text string
|
|
||||||
transform (str): Transform to be applied
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: _description_
|
|
||||||
"""
|
|
||||||
if transform == "mixed-punc":
|
|
||||||
return text
|
|
||||||
elif transform == "upper-no-punc":
|
|
||||||
return [upper_only_alpha(s) for s in text]
|
|
||||||
elif transform == "lower-no-punc":
|
|
||||||
return [lower_only_alpha(s) for s in text]
|
|
||||||
elif transform == "lower-punc":
|
|
||||||
return [lower_all_char(s) for s in text]
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unseen transform: {transform}")
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def main():
|
|
||||||
parser = get_parser()
|
|
||||||
args = parser.parse_args()
|
|
||||||
args.exp_dir = Path(args.exp_dir)
|
|
||||||
|
|
||||||
params = get_params()
|
|
||||||
|
|
||||||
params.update(vars(args))
|
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
|
||||||
sp.load(params.bpe_model)
|
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
|
||||||
params.vocab_size = sp.get_piece_size()
|
|
||||||
|
|
||||||
params.res_dir = params.exp_dir / "long_audio_transcribe"
|
|
||||||
params.res_dir.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
if params.iter > 0:
|
|
||||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
|
||||||
else:
|
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
|
||||||
|
|
||||||
if "beam_search" in params.method:
|
|
||||||
params.suffix += (
|
|
||||||
f"-{params.method}-beam-size-{params.beam_size}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if params.use_pre_text:
|
|
||||||
if params.use_gt_pre_text:
|
|
||||||
params.suffix += f"-use-gt-pre-text-{params.pre_text_transform}-history-{params.num_history}"
|
|
||||||
else:
|
|
||||||
params.suffix += f"-pre-text-{params.pre_text_transform}-history-{params.num_history}"
|
|
||||||
|
|
||||||
|
|
||||||
book_name = params.manifest_dir.split('/')[-1].replace(".jsonl.gz", "")
|
|
||||||
setup_logger(f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info")
|
|
||||||
logging.info("Decoding started")
|
|
||||||
|
|
||||||
device = torch.device("cpu")
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device("cuda", 0)
|
|
||||||
|
|
||||||
logging.info(f"device: {device}")
|
|
||||||
|
|
||||||
logging.info("Creating model")
|
|
||||||
model = get_transducer_model(params)
|
|
||||||
text_sp = spm.SentencePieceProcessor()
|
|
||||||
text_sp.load(params.text_encoder_bpe_model)
|
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
|
||||||
|
|
||||||
if params.iter > 0:
|
|
||||||
filenames = find_checkpoints(
|
|
||||||
params.exp_dir, iteration=-params.iter
|
|
||||||
)[: params.avg + 1]
|
|
||||||
if len(filenames) == 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"No checkpoints found for"
|
|
||||||
f" --iter {params.iter}, --avg {params.avg}"
|
|
||||||
)
|
|
||||||
elif len(filenames) < params.avg + 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
|
||||||
f" --iter {params.iter}, --avg {params.avg}"
|
|
||||||
)
|
|
||||||
filename_start = filenames[-1]
|
|
||||||
filename_end = filenames[0]
|
|
||||||
logging.info(
|
|
||||||
"Calculating the averaged model over iteration checkpoints"
|
|
||||||
f" from {filename_start} (excluded) to {filename_end}"
|
|
||||||
)
|
|
||||||
model.to(device)
|
|
||||||
model.load_state_dict(
|
|
||||||
average_checkpoints_with_averaged_model(
|
|
||||||
filename_start=filename_start,
|
|
||||||
filename_end=filename_end,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert params.avg > 0, params.avg
|
|
||||||
start = params.epoch - params.avg
|
|
||||||
assert start >= 1, start
|
|
||||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
|
||||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
|
||||||
logging.info(
|
|
||||||
f"Calculating the averaged model over epoch range from "
|
|
||||||
f"{start} (excluded) to {params.epoch}"
|
|
||||||
)
|
|
||||||
model.to(device)
|
|
||||||
model.load_state_dict(
|
|
||||||
average_checkpoints_with_averaged_model(
|
|
||||||
filename_start=filename_start,
|
|
||||||
filename_end=filename_end,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
model.to(device)
|
|
||||||
model.eval()
|
|
||||||
model.device = device
|
|
||||||
|
|
||||||
# load manifest
|
|
||||||
manifest = load_manifest(params.manifest_dir)
|
|
||||||
|
|
||||||
results = []
|
|
||||||
count = 0
|
|
||||||
|
|
||||||
last_recording = ""
|
|
||||||
last_end = -1
|
|
||||||
history = []
|
|
||||||
num_pre_texts = []
|
|
||||||
|
|
||||||
for cut in manifest:
|
|
||||||
if cut.has_features:
|
|
||||||
feat = cut.load_features()
|
|
||||||
feat_lens = cut.num_frames
|
|
||||||
else:
|
|
||||||
feat = cut.compute_features(extractor=Fbank())
|
|
||||||
feat_lens = feat.shape[0]
|
|
||||||
|
|
||||||
|
|
||||||
cur_recording = cut.recording.id
|
|
||||||
|
|
||||||
if cur_recording != last_recording:
|
|
||||||
last_recording = cur_recording
|
|
||||||
history = [] # clean history
|
|
||||||
last_end = -1
|
|
||||||
logging.info(f"Moving on to the next recording")
|
|
||||||
else:
|
|
||||||
if cut.start < last_end - 0.2: # overlap exits
|
|
||||||
logging.warning(f"An overlap exists between current cut and last cut")
|
|
||||||
logging.warning("Skipping this cut!")
|
|
||||||
continue
|
|
||||||
if cut.start > last_end + 10:
|
|
||||||
logging.warning(f"Large time gap between the current and previous utterance: {cut.start - last_end}.")
|
|
||||||
|
|
||||||
# prepare input
|
|
||||||
x = torch.tensor(feat, device=device).unsqueeze(0)
|
|
||||||
x_lens = torch.tensor([feat_lens,], device=device)
|
|
||||||
|
|
||||||
if params.use_pre_text:
|
|
||||||
if params.num_history > 0:
|
|
||||||
pre_texts = history[-params.num_history:]
|
|
||||||
else:
|
|
||||||
pre_texts = []
|
|
||||||
assert len(pre_texts) <= params.num_history
|
|
||||||
num_pre_texts.append(len(pre_texts))
|
|
||||||
pre_texts = [train_text_normalization(" ".join(pre_texts))]
|
|
||||||
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
|
|
||||||
style_texts = [fixed_sentence]
|
|
||||||
|
|
||||||
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
|
|
||||||
if params.use_style_prompt:
|
|
||||||
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
|
|
||||||
|
|
||||||
# encode pre_text
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.simplefilter("ignore")
|
|
||||||
|
|
||||||
pre_texts, pre_texts_lens, style_text_lens = _encode_text_as_tokens(
|
|
||||||
pre_texts=pre_texts,
|
|
||||||
style_texts=style_texts,
|
|
||||||
bpe_model=text_sp,
|
|
||||||
device=device,
|
|
||||||
max_tokens=1500,
|
|
||||||
)
|
|
||||||
if params.num_history > 5:
|
|
||||||
logging.info(f"Shape of encoded texts: {pre_texts.shape} ")
|
|
||||||
|
|
||||||
memory, memory_key_padding_mask = model.encode_text(
|
|
||||||
text=pre_texts,
|
|
||||||
style_lens=style_text_lens,
|
|
||||||
text_lens=pre_texts_lens,
|
|
||||||
) # (T,B,C)
|
|
||||||
else:
|
|
||||||
memory = None
|
|
||||||
memory_key_padding_mask = None
|
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.simplefilter("ignore")
|
|
||||||
encoder_out, encoder_out_lens = model.encode_audio(
|
|
||||||
feature=x,
|
|
||||||
feature_lens=x_lens,
|
|
||||||
memory=memory,
|
|
||||||
memory_key_padding_mask=memory_key_padding_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
if params.method == "greedy_search":
|
|
||||||
hyp_tokens = greedy_search_batch(
|
|
||||||
model=model,
|
|
||||||
encoder_out=encoder_out,
|
|
||||||
encoder_out_lens=encoder_out_lens,
|
|
||||||
)
|
|
||||||
elif params.method == "modified_beam_search":
|
|
||||||
hyp_tokens = modified_beam_search(
|
|
||||||
model=model,
|
|
||||||
encoder_out=encoder_out,
|
|
||||||
encoder_out_lens=encoder_out_lens,
|
|
||||||
beam=params.beam_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
hyp = sp.decode(hyp_tokens)[0] # in string format
|
|
||||||
ref_text = ref_text_normalization(cut.supervisions[0].texts[0]) # required to match the training
|
|
||||||
|
|
||||||
# extend the history, the history here is in original format
|
|
||||||
if params.use_gt_pre_text:
|
|
||||||
history.append(ref_text)
|
|
||||||
else:
|
|
||||||
history.append(hyp)
|
|
||||||
last_end = cut.end # update the last end timestamp
|
|
||||||
|
|
||||||
# append the current decoding result
|
|
||||||
hyp = hyp.split()
|
|
||||||
ref = ref_text.split()
|
|
||||||
results.append((cut.id, ref, hyp))
|
|
||||||
|
|
||||||
count += 1
|
|
||||||
if count % 100 == 0:
|
|
||||||
logging.info(f"Cuts processed until now: {count}/{len(manifest)}")
|
|
||||||
logging.info(f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}")
|
|
||||||
logging.info(f"A total of {count} cuts")
|
|
||||||
logging.info(f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}")
|
|
||||||
|
|
||||||
results = sorted(results)
|
|
||||||
recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt"
|
|
||||||
store_transcripts(filename=recog_path, texts=results)
|
|
||||||
logging.info(f"The transcripts are stored in {recog_path}")
|
|
||||||
|
|
||||||
errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt"
|
|
||||||
with open(errs_filename, "w") as f:
|
|
||||||
wer = write_error_stats(
|
|
||||||
f, f"long-audio-{params.method}", results, enable_log=True, compute_CER=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
|
||||||
|
|
||||||
if params.post_normalization:
|
|
||||||
params.suffix += "-post-normalization"
|
|
||||||
|
|
||||||
new_res = []
|
|
||||||
for item in results:
|
|
||||||
id, ref, hyp = item
|
|
||||||
hyp = upper_only_alpha(" ".join(hyp)).split()
|
|
||||||
ref = upper_only_alpha(" ".join(ref)).split()
|
|
||||||
new_res.append((id,ref,hyp))
|
|
||||||
|
|
||||||
new_res = sorted(new_res)
|
|
||||||
recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
|
|
||||||
store_transcripts(filename=recog_path, texts=new_res)
|
|
||||||
logging.info(f"The transcripts are stored in {recog_path}")
|
|
||||||
|
|
||||||
errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
|
|
||||||
with open(errs_filename, "w") as f:
|
|
||||||
wer = write_error_stats(
|
|
||||||
f, f"long-audio-{params.method}", new_res, enable_log=True, compute_CER=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
|
||||||
if __name__=="__main__":
|
|
||||||
main()
|
|
Loading…
x
Reference in New Issue
Block a user