mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
add documentation to different text sampling function
This commit is contained in:
parent
6579800720
commit
93461fb77e
@ -1,23 +1,38 @@
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
import random
|
||||
import numpy as np
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
|
||||
#
|
||||
# See ../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from lhotse import validate
|
||||
from lhotse.cut import CutSet
|
||||
from lhotse.dataset import K2SpeechRecognitionDataset
|
||||
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
|
||||
from lhotse.utils import compute_num_frames, ifnone
|
||||
from torch.utils.data.dataloader import DataLoader, default_collate
|
||||
|
||||
from text_normalization import (
|
||||
remove_non_alphabetic,
|
||||
upper_only_alpha,
|
||||
lower_only_alpha,
|
||||
upper_all_char,
|
||||
lower_all_char,
|
||||
lower_only_alpha,
|
||||
remove_non_alphabetic,
|
||||
train_text_normalization,
|
||||
upper_all_char,
|
||||
upper_only_alpha,
|
||||
)
|
||||
from torch.utils.data.dataloader import DataLoader, default_collate
|
||||
|
||||
|
||||
class PromptASRDataset(torch.utils.data.Dataset):
|
||||
@ -34,7 +49,7 @@ class PromptASRDataset(torch.utils.data.Dataset):
|
||||
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
|
||||
input_strategy: BatchIO = PrecomputedFeatures(),
|
||||
text_sampling_func: Optional[Callable[[List[str]], str]] = None,
|
||||
rare_word_list: Optional[List[str]] = None
|
||||
rare_word_list: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Icefall ASR IterableDataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py
|
||||
@ -63,9 +78,7 @@ class PromptASRDataset(torch.utils.data.Dataset):
|
||||
self.text_sampling_func = text_sampling_func
|
||||
self.rare_word_list = rare_word_list
|
||||
|
||||
def __getitem__(
|
||||
self, cuts: CutSet
|
||||
) -> Dict[str, Union[torch.Tensor, List[str]]]:
|
||||
def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
|
||||
"""
|
||||
Return a new batch, with the batch size automatically determined using the constraints
|
||||
of max_frames and max_cuts.
|
||||
@ -112,15 +125,15 @@ class PromptASRDataset(torch.utils.data.Dataset):
|
||||
self.text_sampling_func(
|
||||
texts=supervision.texts,
|
||||
pre_texts=supervision.pre_texts,
|
||||
context_list=supervision.context_list if "context_list" in supervision.custom else None,
|
||||
context_list=supervision.context_list
|
||||
if "context_list" in supervision.custom
|
||||
else None,
|
||||
rare_word_list=self.rare_word_list,
|
||||
)
|
||||
if self.text_sampling_func is not None
|
||||
else {
|
||||
"text": train_text_normalization(supervision.texts[0]),
|
||||
"pre_text": train_text_normalization(
|
||||
supervision.pre_texts[0]
|
||||
),
|
||||
"pre_text": train_text_normalization(supervision.pre_texts[0]),
|
||||
"style_text": train_text_normalization(
|
||||
supervision.pre_texts[0]
|
||||
),
|
||||
@ -192,27 +205,22 @@ def triplet_text_sampling(
|
||||
rare_word_list: Optional[List[str]] = None,
|
||||
transforms: Optional[List[Callable[[str], str]]] = None,
|
||||
min_len_style: Optional[int] = 80,
|
||||
) -> Dict[str, str]:
|
||||
) -> Dict[str, str, str]:
|
||||
"""This function generates a triplet of
|
||||
(pre_text, style_text, ref_text). The style of style_text and ref_text
|
||||
should always match, whereas the style of pre_text is arbitrary.
|
||||
Suppose we have 3 different transforms A,B,C, and the groundtruth
|
||||
text and pre_text are referred to as text and pre_text.
|
||||
The following three tuples are all valid:
|
||||
should **always** match, whereas the style of pre_text is arbitrary.
|
||||
Suppose we have 2 different transforms A,B, and the preceding text is
|
||||
referred to as pre_text. The following three tuples are all valid:
|
||||
|
||||
(A(pre_text), B(style_text), B(text))
|
||||
(A(pre_text), C(style_text), C(text))
|
||||
(A(pre_text), A(style_text), A(text))
|
||||
...
|
||||
(A(pre_text), A(style_text), A(ref_text))
|
||||
(A(pre_text), B(style_text), B(ref_text))
|
||||
(A(pre_text), A(style_text), A(ref_text))
|
||||
(B(pre_text), B(style_text), B(ref_text))
|
||||
|
||||
If transforms is not given, the following pre-defined transforms
|
||||
are available:
|
||||
0: original (normal case, with punc)
|
||||
1: recog (upper, no punc)
|
||||
2: upper_only_alpha (upper, no punc)
|
||||
3: lower_only_alpha (lower, no punc)
|
||||
4: upper_all (upper, with punc)
|
||||
5: lower_all (lower, with punc)
|
||||
0: original (mixed-cased, with punc)
|
||||
1: upper_only_alpha (upper-cased, no punc)
|
||||
|
||||
When the transform of text and pre_text match, we can use the whole
|
||||
pre_text as the prompt text.
|
||||
@ -224,12 +232,15 @@ def triplet_text_sampling(
|
||||
pre_texts (List[str]):
|
||||
A list of pre_texts, whose first item is the groundtruth
|
||||
pre_text from books.
|
||||
context_list: Optional[str] = None,
|
||||
A list of biasing words separated by space
|
||||
rare_word_list: Optional[str] = None,
|
||||
A list of rare-words separated by space (used as distractors)
|
||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
||||
|
||||
Returns:
|
||||
str: A dictionary
|
||||
A dictionary of ref_text, pre_text, style_text
|
||||
"""
|
||||
# import pdb; pdb.set_trace()
|
||||
assert len(texts) == len(pre_texts)
|
||||
assert len(texts) == 2
|
||||
|
||||
@ -244,13 +255,17 @@ def triplet_text_sampling(
|
||||
lower_only_alpha,
|
||||
lower_all_char,
|
||||
]
|
||||
|
||||
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
|
||||
sampling_weight = [0.7, 0.3, 0.0, 0.0]
|
||||
|
||||
sampling_weight = [
|
||||
0.7,
|
||||
0.3,
|
||||
0.0,
|
||||
0.0,
|
||||
] # Mixed-punc should have the largest sampling prob
|
||||
|
||||
total_transforms = len(transforms) # do not use the recognized trans
|
||||
|
||||
# Select a transformation randomly
|
||||
# Randomly sample transforms
|
||||
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
|
||||
|
||||
# get the normalized text and pre_text
|
||||
@ -261,97 +276,7 @@ def triplet_text_sampling(
|
||||
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
|
||||
else:
|
||||
# get the pre_text of same style as text
|
||||
# For now, do not do transform to the style text
|
||||
style_text = gt_pre_text
|
||||
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
|
||||
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
|
||||
|
||||
return {
|
||||
"text": train_text_normalization(text),
|
||||
"pre_text": train_text_normalization(pre_text),
|
||||
"style_text": train_text_normalization(style_text),
|
||||
"transform_ids": i_text,
|
||||
}
|
||||
|
||||
|
||||
|
||||
def triplet_text_sampling2(
|
||||
texts: List[str],
|
||||
pre_texts: List[str],
|
||||
context_list: Optional[str] = None,
|
||||
rare_word_list: Optional[List[str]] = None,
|
||||
transforms: Optional[List[Callable[[str], str]]] = None,
|
||||
min_len_style: Optional[int] = 80,
|
||||
) -> Dict[str, str]:
|
||||
"""This function generates a triplet of
|
||||
(pre_text, style_text, ref_text). The style of style_text and ref_text
|
||||
should always match, whereas the style of pre_text is arbitrary.
|
||||
Suppose we have 3 different transforms A,B,C, and the groundtruth
|
||||
text and pre_text are referred to as text and pre_text.
|
||||
The following three tuples are all valid:
|
||||
|
||||
(A(pre_text), B(style_text), B(text))
|
||||
(A(pre_text), C(style_text), C(text))
|
||||
(A(pre_text), A(style_text), A(text))
|
||||
...
|
||||
|
||||
If transforms is not given, the following pre-defined transforms
|
||||
are available:
|
||||
0: original (normal case, with punc)
|
||||
1: recog (upper, no punc)
|
||||
2: upper_only_alpha (upper, no punc)
|
||||
3: lower_only_alpha (lower, no punc)
|
||||
4: upper_all (upper, with punc)
|
||||
5: lower_all (lower, with punc)
|
||||
|
||||
When the transform of text and pre_text match, we can use the whole
|
||||
pre_text as the prompt text.
|
||||
|
||||
Args:
|
||||
texts (List[str]):
|
||||
A list of ref_texts whose first item is the ground truth
|
||||
text from books.
|
||||
pre_texts (List[str]):
|
||||
A list of pre_texts, whose first item is the groundtruth
|
||||
pre_text from books.
|
||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
||||
|
||||
Returns:
|
||||
str: A dictionary
|
||||
"""
|
||||
# import pdb; pdb.set_trace()
|
||||
assert len(texts) == len(pre_texts)
|
||||
assert len(texts) == 2
|
||||
|
||||
# we assume the first item to be ground truth
|
||||
gt_text = texts[0]
|
||||
gt_pre_text = pre_texts[0]
|
||||
|
||||
if transforms is None:
|
||||
transforms = [
|
||||
lambda x: x, # return it self
|
||||
upper_only_alpha,
|
||||
lower_only_alpha,
|
||||
lower_all_char,
|
||||
]
|
||||
|
||||
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
|
||||
sampling_weight = [0.7, 0.3, 0.0, 0.0]
|
||||
|
||||
total_transforms = len(transforms) # do not use the recognized trans
|
||||
|
||||
# Select a transformation randomly
|
||||
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
|
||||
|
||||
# get the normalized text and pre_text
|
||||
text = transforms[i_text](gt_text)
|
||||
pre_text = transforms[i_pre_text](gt_pre_text)
|
||||
|
||||
if i_text == i_pre_text:
|
||||
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
|
||||
else:
|
||||
# get the pre_text of same style as text
|
||||
# For now, do not do transform to the style text
|
||||
# For now, **don't** do transform to the style text, because we do it after the dataloader
|
||||
style_text = gt_pre_text
|
||||
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
|
||||
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
|
||||
@ -373,25 +298,22 @@ def triplet_text_sampling_with_context_list(
|
||||
min_len_style: Optional[int] = 80,
|
||||
) -> Dict[str, str]:
|
||||
"""This function generates a triplet of
|
||||
(pre_text, style_text, ref_text). The style of style_text and ref_text
|
||||
should always match, whereas the style of pre_text is arbitrary.
|
||||
Suppose we have 3 different transforms A,B,C, and the groundtruth
|
||||
text and pre_text are referred to as text and pre_text.
|
||||
The following three tuples are all valid:
|
||||
(pre_text, style_text, ref_text). The pre_text is either the preceding text
|
||||
or a list of words (context words + distractors).
|
||||
The style of style_text and ref_text should **always** match, whereas
|
||||
the style of pre_text is arbitrary.
|
||||
Suppose we have 2 different transforms A,B, and the preceding text is
|
||||
referred to as pre_text. The following three tuples are all valid:
|
||||
|
||||
(A(pre_text), B(style_text), B(text))
|
||||
(A(pre_text), C(style_text), C(text))
|
||||
(A(pre_text), A(style_text), A(text))
|
||||
...
|
||||
(A(pre_text), A(style_text), A(ref_text))
|
||||
(A(pre_text), B(style_text), B(ref_text))
|
||||
(A(pre_text), A(style_text), A(ref_text))
|
||||
(B(pre_text), B(style_text), B(ref_text))
|
||||
|
||||
If transforms is not given, the following pre-defined transforms
|
||||
are available:
|
||||
0: original (normal case, with punc)
|
||||
1: recog (upper, no punc)
|
||||
2: upper_only_alpha (upper, no punc)
|
||||
3: lower_only_alpha (lower, no punc)
|
||||
4: upper_all (upper, with punc)
|
||||
5: lower_all (lower, with punc)
|
||||
0: original (mixed-cased, with punc)
|
||||
1: upper_only_alpha (upper-cased, no punc)
|
||||
|
||||
When the transform of text and pre_text match, we can use the whole
|
||||
pre_text as the prompt text.
|
||||
@ -403,15 +325,21 @@ def triplet_text_sampling_with_context_list(
|
||||
pre_texts (List[str]):
|
||||
A list of pre_texts, whose first item is the groundtruth
|
||||
pre_text from books.
|
||||
context_list: Optional[str] = None,
|
||||
A list of biasing words separated by space
|
||||
rare_word_list: Optional[str] = None,
|
||||
A list of rare-words separated by space (used as distractors)
|
||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
||||
|
||||
Returns:
|
||||
A dictionary of ref_text, pre_text, style_text
|
||||
Returns:
|
||||
str: A dictionary
|
||||
"""
|
||||
# import pdb; pdb.set_trace()
|
||||
assert len(texts) == len(pre_texts)
|
||||
assert len(texts) == 2
|
||||
|
||||
|
||||
if context_list is not None:
|
||||
context_list = context_list.lower()
|
||||
|
||||
@ -426,9 +354,13 @@ def triplet_text_sampling_with_context_list(
|
||||
lower_only_alpha,
|
||||
lower_all_char,
|
||||
]
|
||||
|
||||
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
|
||||
sampling_weight = [0.7, 0.3, 0.0, 0.0]
|
||||
|
||||
sampling_weight = [
|
||||
0.7,
|
||||
0.3,
|
||||
0.0,
|
||||
0.0,
|
||||
] # Mixed-punc should have the largest sampling prob
|
||||
|
||||
total_transforms = len(transforms) # do not use the recognized trans
|
||||
|
||||
@ -446,11 +378,10 @@ def triplet_text_sampling_with_context_list(
|
||||
pre_text = transforms[i_pre_text](pre_text)
|
||||
|
||||
if i_text == i_pre_text:
|
||||
style_text = gt_pre_text
|
||||
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
|
||||
else:
|
||||
# get the pre_text of same style as text
|
||||
# For now, do not do transform to the style text
|
||||
# For now, **don't** do transform to the style text
|
||||
style_text = gt_pre_text
|
||||
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
|
||||
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
|
||||
@ -471,7 +402,7 @@ def get_pre_text_with_context_list(
|
||||
) -> str:
|
||||
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
|
||||
# By a small proportion of time, use the substring of ref_text as pre_text
|
||||
|
||||
|
||||
if context_list != "" and context_list is not None:
|
||||
v = random.random()
|
||||
if v < 0.5:
|
||||
@ -489,14 +420,14 @@ def get_pre_text_with_context_list(
|
||||
pre_text = " ".join(pre_text)
|
||||
elif v < 0.7:
|
||||
splitted = text.split()
|
||||
sampling_weights = [len(w)**1.2 for w in splitted]
|
||||
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
|
||||
sampling_weights = [len(w) ** 1.2 for w in splitted]
|
||||
sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
|
||||
i = random.randint(1, min(len(splitted), 20))
|
||||
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
||||
num_distractors = random.randint(0,70)
|
||||
num_distractors = random.randint(0, 70)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
splitted += distractors
|
||||
random.shuffle(splitted) # shuffle the list
|
||||
random.shuffle(splitted) # shuffle the list
|
||||
pre_text = " ".join(splitted)
|
||||
else:
|
||||
pre_text = pre_text
|
||||
@ -504,21 +435,21 @@ def get_pre_text_with_context_list(
|
||||
v = random.random()
|
||||
if v < 0.1:
|
||||
splitted = text.split()
|
||||
sampling_weights = [len(w)**1.2 for w in splitted]
|
||||
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
|
||||
sampling_weights = [len(w) ** 1.2 for w in splitted]
|
||||
sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
|
||||
i = random.randint(1, min(len(splitted), 20))
|
||||
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
||||
pre_text = " ".join(splitted)
|
||||
num_distractors = random.randint(0,70)
|
||||
num_distractors = random.randint(0, 70)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
splitted += distractors
|
||||
random.shuffle(splitted) # shuffle the list
|
||||
random.shuffle(splitted) # shuffle the list
|
||||
elif v < 0.2:
|
||||
# full distractors
|
||||
num_distractors = random.randint(5, 100)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
pre_text = " ".join(distractors)
|
||||
|
||||
pre_text = " ".join(distractors)
|
||||
|
||||
elif v < 0.3:
|
||||
pre_text = get_substring(text, min_len=15, max_len=150)
|
||||
else:
|
||||
@ -527,20 +458,19 @@ def get_pre_text_with_context_list(
|
||||
return pre_text
|
||||
|
||||
|
||||
|
||||
def get_pre_text_with_context_list2(
|
||||
text: str,
|
||||
pre_text: str,
|
||||
context_list: str,
|
||||
rare_words_list: List[str] = None,
|
||||
) -> str:
|
||||
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
|
||||
# Get the pre_text, either the ground truth preceding text or
|
||||
# a list of words consisting of biasing words and distrators
|
||||
# By a small proportion of time, use the substring of ref_text as pre_text
|
||||
|
||||
|
||||
if context_list != "" and context_list is not None:
|
||||
v = random.random()
|
||||
if v < 0.4:
|
||||
# correct + distractors
|
||||
# sample distractors
|
||||
num_distractors = random.randint(50, 100)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
@ -554,14 +484,16 @@ def get_pre_text_with_context_list2(
|
||||
pre_text = " ".join(pre_text)
|
||||
elif v < 0.55:
|
||||
splitted = text.split()
|
||||
sampling_weights = [len(w)**1.2 for w in splitted]
|
||||
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
|
||||
sampling_weights = [
|
||||
len(w) ** 1.2 for w in splitted
|
||||
] # longer words with higher weights
|
||||
sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
|
||||
i = random.randint(1, min(len(splitted), 20))
|
||||
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
||||
num_distractors = random.randint(50,100)
|
||||
num_distractors = random.randint(50, 100)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
splitted += distractors
|
||||
random.shuffle(splitted) # shuffle the list
|
||||
random.shuffle(splitted) # shuffle the list
|
||||
pre_text = " ".join(splitted)
|
||||
else:
|
||||
pre_text = pre_text
|
||||
@ -569,21 +501,20 @@ def get_pre_text_with_context_list2(
|
||||
v = random.random()
|
||||
if v < 0.3:
|
||||
splitted = text.split()
|
||||
sampling_weights = [len(w)**1.2 for w in splitted]
|
||||
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
|
||||
sampling_weights = [len(w) ** 1.2 for w in splitted]
|
||||
sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
|
||||
i = random.randint(1, min(len(splitted), 20))
|
||||
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
||||
pre_text = " ".join(splitted)
|
||||
num_distractors = random.randint(50,100)
|
||||
num_distractors = random.randint(50, 100)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
splitted += distractors
|
||||
random.shuffle(splitted) # shuffle the list
|
||||
random.shuffle(splitted) # shuffle the list
|
||||
elif v < 0.4:
|
||||
# full distractors
|
||||
num_distractors = random.randint(5, 100)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
pre_text = " ".join(distractors)
|
||||
|
||||
pre_text = " ".join(distractors)
|
||||
elif v < 0.6:
|
||||
pre_text = get_substring(text, min_len=15, max_len=150)
|
||||
else:
|
||||
@ -592,160 +523,6 @@ def get_pre_text_with_context_list2(
|
||||
return pre_text
|
||||
|
||||
|
||||
|
||||
def joint_triplet_text_sampling(
|
||||
texts: List[str],
|
||||
pre_texts: List[str],
|
||||
transforms: Optional[List[Callable[[str], str]]] = None,
|
||||
min_len_style: Optional[int] = 80,
|
||||
) -> Dict[str, str]:
|
||||
"""This function generates a triplet of
|
||||
(pre_text, style_text, ref_text). The style of pre_text, style_text
|
||||
and ref_text should **always** match.
|
||||
Suppose we have 3 different transforms A,B,C, and the groundtruth
|
||||
text and pre_text are referred to as text and pre_text.
|
||||
The following three tuples are all valid:
|
||||
|
||||
(A(pre_text), A(style_text), A(text))
|
||||
(B(pre_text), B(style_text), B(text))
|
||||
(C(pre_text), C(style_text), C(text))
|
||||
...
|
||||
|
||||
If transforms is not given, the following pre-defined transforms
|
||||
are available:
|
||||
0: original (normal case, with punc)
|
||||
1: recog (upper, no punc)
|
||||
2: upper_only_alpha (upper, no punc)
|
||||
3: lower_only_alpha (lower, no punc)
|
||||
4: upper_all (upper, with punc)
|
||||
5: lower_all (lower, with punc)
|
||||
|
||||
When the transform of text and pre_text match, we can use the whole
|
||||
pre_text as the prompt text.
|
||||
|
||||
Args:
|
||||
texts (List[str]):
|
||||
A list of ref_texts whose first item is the ground truth
|
||||
text from books.
|
||||
pre_texts (List[str]):
|
||||
A list of pre_texts, whose first item is the groundtruth
|
||||
pre_text from books.
|
||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
||||
|
||||
Returns:
|
||||
str: A dictionary
|
||||
"""
|
||||
# import pdb; pdb.set_trace()
|
||||
assert len(texts) == len(pre_texts)
|
||||
assert len(texts) == 2
|
||||
|
||||
# we assume the first item to be ground truth
|
||||
gt_text = texts[0]
|
||||
gt_pre_text = pre_texts[0]
|
||||
|
||||
if transforms is None:
|
||||
transforms = [
|
||||
lambda x: x, # return it self
|
||||
upper_only_alpha,
|
||||
lower_only_alpha,
|
||||
lower_all_char,
|
||||
]
|
||||
|
||||
sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
|
||||
|
||||
total_transforms = len(transforms) # do not use the recognized trans
|
||||
|
||||
# Select a transformation randomly
|
||||
i_text = np.random.choice(total_transforms, 1, p=sampling_weight)[0]
|
||||
|
||||
# get the normalized text and pre_text
|
||||
text = transforms[i_text](gt_text)
|
||||
pre_text = transforms[i_text](gt_pre_text)
|
||||
|
||||
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
|
||||
|
||||
return {
|
||||
"text": train_text_normalization(text),
|
||||
"pre_text": train_text_normalization(pre_text),
|
||||
"style_text": train_text_normalization(style_text),
|
||||
"transform_ids": i_text,
|
||||
}
|
||||
|
||||
|
||||
def triplet_style_text_sampling(
|
||||
texts: List[str],
|
||||
pre_texts: List[str],
|
||||
transforms: Optional[List[Callable[[str], str]]] = None,
|
||||
min_len_style: Optional[int] = 80,
|
||||
) -> Dict[str, str]:
|
||||
"""This function generates a triplet of
|
||||
(pre_text, style_text, ref_text). The style of style_text and ref_text
|
||||
should always match, whereas the style of pre_text is fixed to mixed-trans.
|
||||
Suppose we have 3 different transforms A,B,C, and the groundtruth
|
||||
text and pre_text are referred to as text and pre_text.
|
||||
The following three tuples are all valid:
|
||||
|
||||
(gt_pre_text, B(style_text), B(text))
|
||||
(gt_pre_text, C(style_text), C(text))
|
||||
(gt_pre_text, A(style_text), A(text))
|
||||
...
|
||||
|
||||
If transforms is not given, the following pre-defined transforms
|
||||
are available:
|
||||
0: original (normal case, with punc)
|
||||
1: recog (upper, no punc)
|
||||
2: upper_only_alpha (upper, no punc)
|
||||
3: lower_only_alpha (lower, no punc)
|
||||
4: upper_all (upper, with punc)
|
||||
5: lower_all (lower, with punc)
|
||||
|
||||
Args:
|
||||
texts (List[str]):
|
||||
A list of ref_texts whose first item is the ground truth
|
||||
text from books.
|
||||
pre_texts (List[str]):
|
||||
A list of pre_texts, whose first item is the groundtruth
|
||||
pre_text from books.
|
||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
||||
|
||||
Returns:
|
||||
str: A dictionary
|
||||
"""
|
||||
# import pdb; pdb.set_trace()
|
||||
assert len(texts) == len(pre_texts)
|
||||
assert len(texts) == 2
|
||||
|
||||
# we assume the first item to be ground truth
|
||||
gt_text = texts[0]
|
||||
gt_pre_text = pre_texts[0]
|
||||
|
||||
if transforms is None:
|
||||
transforms = [
|
||||
lambda x: x, # return it self
|
||||
upper_only_alpha,
|
||||
lower_only_alpha,
|
||||
lower_all_char,
|
||||
]
|
||||
|
||||
sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
|
||||
total_transforms = len(transforms) # do not use the recognized trans
|
||||
|
||||
# Select a transformation randomly
|
||||
t_id = np.random.choice(total_transforms, 1, p=sampling_weight)[0]
|
||||
|
||||
# get the normalized text
|
||||
text = transforms[t_id](gt_text)
|
||||
# get the original un-processed style text
|
||||
style_text = get_substring(gt_pre_text, min_len=min_len_style, max_len=150)
|
||||
|
||||
return {
|
||||
"text": train_text_normalization(text),
|
||||
"pre_text": train_text_normalization(gt_pre_text),
|
||||
"style_text": train_text_normalization(style_text),
|
||||
"transform_ids": t_id,
|
||||
}
|
||||
|
||||
|
||||
def naive_triplet_text_sampling(
|
||||
texts: List[str],
|
||||
pre_texts: List[str],
|
||||
@ -753,18 +530,13 @@ def naive_triplet_text_sampling(
|
||||
rare_word_list: List[str] = None,
|
||||
min_len_style: Optional[int] = 120,
|
||||
):
|
||||
# The most simplest text sampling function, used only for
|
||||
# evaluation, use a fixed sentence as the style text
|
||||
|
||||
return {
|
||||
"text": train_text_normalization(texts[0]),
|
||||
"pre_text": train_text_normalization(pre_texts[0]),
|
||||
#"pre_text": "HELLO IT THIS ENOUGH FOR THE MODEL TO LEARN THE STYLE",
|
||||
#"pre_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
|
||||
#"pre_text": "Hello, my friend. "*50,
|
||||
#"style_text": train_text_normalization(pre_texts[0][:200]),
|
||||
"style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related. What do you think?",
|
||||
#"style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
|
||||
#"style_text": "Mixed-case English transcription, with punctuation.",
|
||||
# "style_text": train_text_normalization(get_substring(pre_texts[0], min_len=min_len_style)),
|
||||
"transform_ids": 0,
|
||||
}
|
||||
|
||||
@ -775,11 +547,11 @@ def random_shuffle_subset(
|
||||
p_mask: float = 0.05,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Randomly shuffle the subset by probability p, which means that p% of the samples
|
||||
Randomly shuffle the subset by probability `p`, which means that p% of the samples
|
||||
in the original batch are shuffled, the others are kept in the original order.
|
||||
|
||||
With a probability of p_mask, replace the original string with an empty string.
|
||||
|
||||
|
||||
With a probability of `p_mask`, replace the original string with an empty string.
|
||||
|
||||
"""
|
||||
|
||||
num_to_shuffle = int(len(data) * p)
|
||||
@ -789,9 +561,9 @@ def random_shuffle_subset(
|
||||
|
||||
for id, item in zip(id_to_shuffle, item_to_shuffle):
|
||||
data[id] = item
|
||||
|
||||
|
||||
# Randomly mask a proportion of the data to empty string
|
||||
if p_mask > 0:
|
||||
if p_mask > 0:
|
||||
for i in range(len(data)):
|
||||
if random.random() < p_mask:
|
||||
data[i] = ""
|
||||
@ -809,16 +581,6 @@ if __name__ == "__main__":
|
||||
"EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg?",
|
||||
"EE FF GG EE FF GG EE FF GG EE FF GG EE FF GG",
|
||||
]
|
||||
# for i in range(10):
|
||||
# print(f"Run: {i}")
|
||||
# print(triplet_text_sampling(texts, pre_texts))
|
||||
|
||||
import time
|
||||
start = time.time()
|
||||
data = [str(i) for i in range(30)]
|
||||
random.shuffle(data)
|
||||
print(data)
|
||||
for i in range(1):
|
||||
shuffled = random_shuffle_subset(data=data, p=0.4, p_mask=0.1)
|
||||
print(shuffled)
|
||||
print((time.time() - start)/100)
|
||||
for i in range(10):
|
||||
print(f"Run: {i}")
|
||||
print(triplet_text_sampling(texts, pre_texts))
|
||||
|
@ -1,477 +0,0 @@
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from tqdm import tqdm
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from lhotse import load_manifest, Fbank
|
||||
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from text_normalization import (
|
||||
ref_text_normalization,
|
||||
remove_non_alphabetic,
|
||||
upper_only_alpha,
|
||||
upper_all_char,
|
||||
lower_all_char,
|
||||
lower_only_alpha,
|
||||
train_text_normalization,
|
||||
)
|
||||
from train_subformer_with_style import (
|
||||
add_model_arguments,
|
||||
get_params,
|
||||
get_tokenizer,
|
||||
get_transducer_model,
|
||||
_encode_text_as_tokens,
|
||||
)
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--manifest-dir",
|
||||
type=str,
|
||||
default="data/long_audios/long_audio_pomonastravels_combined.jsonl.gz",
|
||||
help="""This is the manfiest for long audio transcription.
|
||||
It is intended to be sored, i.e first sort by recording ID and then sort by
|
||||
start timestamp"""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--segment-length",
|
||||
type=float,
|
||||
default=30.0,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-pre-text",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether use pre-text when decoding the current chunk"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-style-prompt",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Use style prompt when evaluation"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pre-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of content prompt, i.e pre_text"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--style-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of style prompt, i.e style_text"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-history",
|
||||
type=int,
|
||||
default=2,
|
||||
help="How many previous chunks to look if using pre-text for decoding"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-gt-pre-text",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether use gt pre text when using content prompt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--post-normalization",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
||||
"""Apply transform to a list of text. By default, the text are in
|
||||
ground truth format, i.e mixed-punc.
|
||||
|
||||
Args:
|
||||
text (List[str]): Input text string
|
||||
transform (str): Transform to be applied
|
||||
|
||||
Returns:
|
||||
List[str]: _description_
|
||||
"""
|
||||
if transform == "mixed-punc":
|
||||
return text
|
||||
elif transform == "upper-no-punc":
|
||||
return [upper_only_alpha(s) for s in text]
|
||||
elif transform == "lower-no-punc":
|
||||
return [lower_only_alpha(s) for s in text]
|
||||
elif transform == "lower-punc":
|
||||
return [lower_all_char(s) for s in text]
|
||||
else:
|
||||
raise NotImplementedError(f"Unseen transform: {transform}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
params.res_dir = params.exp_dir / "long_audio_transcribe"
|
||||
params.res_dir.mkdir(exist_ok=True)
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "beam_search" in params.method:
|
||||
params.suffix += (
|
||||
f"-{params.method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
|
||||
if params.use_pre_text:
|
||||
if params.use_gt_pre_text:
|
||||
params.suffix += f"-use-gt-pre-text-{params.pre_text_transform}-history-{params.num_history}"
|
||||
else:
|
||||
params.suffix += f"-pre-text-{params.pre_text_transform}-history-{params.num_history}"
|
||||
|
||||
|
||||
book_name = params.manifest_dir.split('/')[-1].replace(".jsonl.gz", "")
|
||||
setup_logger(f"{params.res_dir}/log-decode-{book_name}-{params.suffix}", log_level="info")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
text_sp = spm.SentencePieceProcessor()
|
||||
text_sp.load(params.text_encoder_bpe_model)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
# load manifest
|
||||
manifest = load_manifest(params.manifest_dir)
|
||||
|
||||
results = []
|
||||
count = 0
|
||||
|
||||
last_recording = ""
|
||||
last_end = -1
|
||||
history = []
|
||||
num_pre_texts = []
|
||||
|
||||
for cut in manifest:
|
||||
if cut.has_features:
|
||||
feat = cut.load_features()
|
||||
feat_lens = cut.num_frames
|
||||
else:
|
||||
feat = cut.compute_features(extractor=Fbank())
|
||||
feat_lens = feat.shape[0]
|
||||
|
||||
|
||||
cur_recording = cut.recording.id
|
||||
|
||||
if cur_recording != last_recording:
|
||||
last_recording = cur_recording
|
||||
history = [] # clean history
|
||||
last_end = -1
|
||||
logging.info(f"Moving on to the next recording")
|
||||
else:
|
||||
if cut.start < last_end - 0.2: # overlap exits
|
||||
logging.warning(f"An overlap exists between current cut and last cut")
|
||||
logging.warning("Skipping this cut!")
|
||||
continue
|
||||
if cut.start > last_end + 10:
|
||||
logging.warning(f"Large time gap between the current and previous utterance: {cut.start - last_end}.")
|
||||
|
||||
# prepare input
|
||||
x = torch.tensor(feat, device=device).unsqueeze(0)
|
||||
x_lens = torch.tensor([feat_lens,], device=device)
|
||||
|
||||
if params.use_pre_text:
|
||||
if params.num_history > 0:
|
||||
pre_texts = history[-params.num_history:]
|
||||
else:
|
||||
pre_texts = []
|
||||
assert len(pre_texts) <= params.num_history
|
||||
num_pre_texts.append(len(pre_texts))
|
||||
pre_texts = [train_text_normalization(" ".join(pre_texts))]
|
||||
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it is fully not related."
|
||||
style_texts = [fixed_sentence]
|
||||
|
||||
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
|
||||
if params.use_style_prompt:
|
||||
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
|
||||
|
||||
# encode pre_text
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
pre_texts, pre_texts_lens, style_text_lens = _encode_text_as_tokens(
|
||||
pre_texts=pre_texts,
|
||||
style_texts=style_texts,
|
||||
bpe_model=text_sp,
|
||||
device=device,
|
||||
max_tokens=1500,
|
||||
)
|
||||
if params.num_history > 5:
|
||||
logging.info(f"Shape of encoded texts: {pre_texts.shape} ")
|
||||
|
||||
memory, memory_key_padding_mask = model.encode_text(
|
||||
text=pre_texts,
|
||||
style_lens=style_text_lens,
|
||||
text_lens=pre_texts_lens,
|
||||
) # (T,B,C)
|
||||
else:
|
||||
memory = None
|
||||
memory_key_padding_mask = None
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
encoder_out, encoder_out_lens = model.encode_audio(
|
||||
feature=x,
|
||||
feature_lens=x_lens,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
|
||||
if params.method == "greedy_search":
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
|
||||
hyp = sp.decode(hyp_tokens)[0] # in string format
|
||||
ref_text = ref_text_normalization(cut.supervisions[0].texts[0]) # required to match the training
|
||||
|
||||
# extend the history, the history here is in original format
|
||||
if params.use_gt_pre_text:
|
||||
history.append(ref_text)
|
||||
else:
|
||||
history.append(hyp)
|
||||
last_end = cut.end # update the last end timestamp
|
||||
|
||||
# append the current decoding result
|
||||
hyp = hyp.split()
|
||||
ref = ref_text.split()
|
||||
results.append((cut.id, ref, hyp))
|
||||
|
||||
count += 1
|
||||
if count % 100 == 0:
|
||||
logging.info(f"Cuts processed until now: {count}/{len(manifest)}")
|
||||
logging.info(f"Averaged context numbers of last 100 samples is: {sum(num_pre_texts[-100:])/100}")
|
||||
logging.info(f"A total of {count} cuts")
|
||||
logging.info(f"Averaged context numbers of whole set is: {sum(num_pre_texts)/len(num_pre_texts)}")
|
||||
|
||||
results = sorted(results)
|
||||
recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}.txt"
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"long-audio-{params.method}", results, enable_log=True, compute_CER=False,
|
||||
)
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
if params.post_normalization:
|
||||
params.suffix += "-post-normalization"
|
||||
|
||||
new_res = []
|
||||
for item in results:
|
||||
id, ref, hyp = item
|
||||
hyp = upper_only_alpha(" ".join(hyp)).split()
|
||||
ref = upper_only_alpha(" ".join(ref)).split()
|
||||
new_res.append((id,ref,hyp))
|
||||
|
||||
new_res = sorted(new_res)
|
||||
recog_path = params.res_dir / f"recogs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
|
||||
store_transcripts(filename=recog_path, texts=new_res)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
errs_filename = params.res_dir / f"errs-long-audio-{params.method}-{params.suffix}-post-normalization.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"long-audio-{params.method}", new_res, enable_log=True, compute_CER=False,
|
||||
)
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
if __name__=="__main__":
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user