Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-08 16:44:20 +00:00)

Commit a0fe6bcd0d: further clean up
Parent: ae2c7c73f6
@@ -11,6 +11,7 @@ from lhotse.utils import compute_num_frames, ifnone
from torch.utils.data.dataloader import DataLoader, default_collate

from text_normalization import (
    remove_non_alphabetic,
    upper_only_alpha,
    lower_only_alpha,
    upper_all_char,
@@ -33,6 +34,7 @@ class PromptASRDataset(torch.utils.data.Dataset):
        input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
        input_strategy: BatchIO = PrecomputedFeatures(),
        text_sampling_func: Optional[Callable[[List[str]], str]] = None,
        rare_word_list: Optional[List[str]] = None
    ):
        """
        Icefall ASR IterableDataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py
@@ -59,6 +61,7 @@ class PromptASRDataset(torch.utils.data.Dataset):

        # a text sampling function
        self.text_sampling_func = text_sampling_func
        self.rare_word_list = rare_word_list

    def __getitem__(
        self, cuts: CutSet
@@ -109,6 +112,8 @@ class PromptASRDataset(torch.utils.data.Dataset):
                    self.text_sampling_func(
                        texts=supervision.texts,
                        pre_texts=supervision.pre_texts,
                        context_list=supervision.context_list if "context_list" in supervision.custom else None,
                        rare_word_list=self.rare_word_list,
                    )
                    if self.text_sampling_func is not None
                    else {
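For context, a minimal usage sketch (not part of this commit): the `text_sampling_func` and `rare_word_list` arguments shown in the constructor above are forwarded to every supervision when a batch is built. The word-list file name and surrounding setup below are assumptions for illustration only.

# --- illustrative sketch, not part of this commit ---
# Wiring PromptASRDataset with a sampling function; "rare_words.txt" is a
# hypothetical file with one word per line.
from dataset import PromptASRDataset, triplet_text_sampling

with open("rare_words.txt") as f:
    rare_words = [w.strip() for w in f]

train_dataset = PromptASRDataset(
    return_cuts=True,
    text_sampling_func=triplet_text_sampling,
    rare_word_list=rare_words,
)
# A Lhotse sampler then yields CutSets that are passed to train_dataset[cuts],
# which returns {"inputs": ..., "supervisions": {"text", "pre_text", "style_text", ...}}.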
@@ -183,6 +188,8 @@ def get_substring(s: str, min_len: int = 40, max_len: int = 250) -> str:
def triplet_text_sampling(
    texts: List[str],
    pre_texts: List[str],
    context_list: Optional[str] = None,
    rare_word_list: Optional[List[str]] = None,
    transforms: Optional[List[Callable[[str], str]]] = None,
    min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
@@ -238,7 +245,8 @@ def triplet_text_sampling
            lower_all_char,
        ]

    sampling_weight = [0.7, 0.3, 0.0, 0.0]  # Mixed-punc should have the largest sampling prob
    # sampling_weight = [0.5, 0.2, 0.15, 0.15]  # Mixed-punc should have the largest sampling prob
    sampling_weight = [0.7, 0.3, 0.0, 0.0]

    total_transforms = len(transforms)  # do not use the recognized trans
@@ -266,7 +274,8 @@ def triplet_text_sampling
    }


def multi_ref_text_triplet_text_sampling(

def triplet_text_sampling2(
    texts: List[str],
    pre_texts: List[str],
    context_list: Optional[str] = None,
@@ -310,14 +319,12 @@ def multi_ref_text_triplet_text_sampling(
    Returns:
        str: A dictionary
    """
    assert len(texts) == 3
    # import pdb; pdb.set_trace()
    assert len(texts) == len(pre_texts)
    assert len(texts) == 2

    # we assume the first item to be ground truth, the third item to be the
    # decoding results with prompts
    if random.random() < 0.5:
        # we assume the first item to be ground truth
        gt_text = texts[0]
    else:
        gt_text = texts[2]  # decoding res with prompt
    gt_pre_text = pre_texts[0]

    if transforms is None:
@@ -328,7 +335,8 @@ def multi_ref_text_triplet_text_sampling(
            lower_all_char,
        ]

    sampling_weight = [0.5, 0.2, 0.15, 0.15]  # Mixed-punc should have the largest sampling prob
    # sampling_weight = [0.5, 0.2, 0.15, 0.15]  # Mixed-punc should have the largest sampling prob
    sampling_weight = [0.7, 0.3, 0.0, 0.0]

    total_transforms = len(transforms)  # do not use the recognized trans
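One practical implication of this hunk, as an aside that is not part of the diff: with `sampling_weight = [0.7, 0.3, 0.0, 0.0]` over the four default transforms (identity, upper_only_alpha, lower_only_alpha, lower_all_char), only the first two transforms can ever be drawn. A minimal sketch of the selection step, using stand-in transforms:

# --- illustrative sketch, not part of this commit ---
# How the transform indices are drawn; the transforms below are stand-ins.
import numpy as np

transforms = [lambda s: s, str.upper, str.lower, str.title]  # illustrative only
sampling_weight = [0.7, 0.3, 0.0, 0.0]

# Two independent draws (one for text, one for pre_text); indices 2 and 3
# have zero probability, so the lower-cased transforms are never selected.
i_text, i_pre_text = np.random.choice(len(transforms), 2, p=sampling_weight)
print(i_text, i_pre_text)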
@@ -356,6 +364,234 @@ def multi_ref_text_triplet_text_sampling(
    }


def triplet_text_sampling_with_context_list(
    texts: List[str],
    pre_texts: List[str],
    context_list: str,
    rare_word_list: List[str],
    transforms: Optional[List[Callable[[str], str]]] = None,
    min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
    """This function generates a triplet of
    (pre_text, style_text, ref_text). The style of style_text and ref_text
    should always match, whereas the style of pre_text is arbitrary.
    Suppose we have 3 different transforms A, B, C, and the ground truth
    text and pre_text are referred to as text and pre_text.
    The following three tuples are all valid:

    (A(pre_text), B(style_text), B(text))
    (A(pre_text), C(style_text), C(text))
    (A(pre_text), A(style_text), A(text))
    ...

    If transforms is not given, the following pre-defined transforms
    are available:
        0: original (normal case, with punc)
        1: recog (upper, no punc)
        2: upper_only_alpha (upper, no punc)
        3: lower_only_alpha (lower, no punc)
        4: upper_all (upper, with punc)
        5: lower_all (lower, with punc)

    When the transform of text and pre_text match, we can use the whole
    pre_text as the prompt text.

    Args:
        texts (List[str]):
            A list of ref_texts whose first item is the ground truth
            text from books.
        pre_texts (List[str]):
            A list of pre_texts, whose first item is the ground truth
            pre_text from books.
        transforms (List[Callable[[str], str]]): A list of possible transforms to be applied

    Returns:
        Dict[str, str]: A dictionary with keys "text", "pre_text", "style_text", "transform_ids"
    """
    # import pdb; pdb.set_trace()
    assert len(texts) == len(pre_texts)
    assert len(texts) == 2

    if context_list is not None:
        context_list = context_list.lower()

    # we assume the first item to be ground truth
    gt_text = texts[0]
    gt_pre_text = pre_texts[0]

    if transforms is None:
        transforms = [
            lambda x: x,  # return itself
            upper_only_alpha,
            lower_only_alpha,
            lower_all_char,
        ]

    # sampling_weight = [0.5, 0.2, 0.15, 0.15]  # Mixed-punc should have the largest sampling prob
    sampling_weight = [0.7, 0.3, 0.0, 0.0]

    total_transforms = len(transforms)  # do not use the recognized trans

    # Select a transformation randomly
    i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)

    # get the normalized text and pre_text
    text = transforms[i_text](gt_text)
    pre_text = get_pre_text_with_context_list2(
        text=gt_pre_text,
        pre_text=gt_pre_text,
        context_list=context_list,
        rare_words_list=rare_word_list,
    )
    pre_text = transforms[i_pre_text](pre_text)

    if i_text == i_pre_text:
        style_text = gt_pre_text
        style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
    else:
        # get the pre_text of same style as text
        # For now, do not do transform to the style text
        style_text = gt_pre_text
        # style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
        style_text = get_substring(style_text, min_len=min_len_style, max_len=150)

    return {
        "text": train_text_normalization(text),
        "pre_text": train_text_normalization(pre_text),
        "style_text": train_text_normalization(style_text),
        "transform_ids": i_text,
    }


def get_pre_text_with_context_list(
    text: str,
    pre_text: str,
    context_list: str,
    rare_words_list: List[str] = None,
) -> str:
    # Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
    # By a small proportion of time, use the substring of ref_text as pre_text

    if context_list != "" and context_list is not None:
        v = random.random()
        if v < 0.5:
            # correct + distractors
            # sample distractors
            num_distractors = random.randint(0, 50)
            distractors = random.sample(rare_words_list, num_distractors)
            # sample correct
            correct = context_list.split()
            i = random.randint(1, len(correct))
            correct = random.sample(correct, i)
            # combine correct and distractors
            pre_text = distractors + correct
            random.shuffle(pre_text)
            pre_text = " ".join(pre_text)
        elif v < 0.7:
            splitted = text.split()
            sampling_weights = [len(w) ** 1.2 for w in splitted]
            sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
            i = random.randint(1, min(len(splitted), 20))
            splitted = list(np.random.choice(splitted, i, p=sampling_weights))
            num_distractors = random.randint(0, 70)
            distractors = random.sample(rare_words_list, num_distractors)
            splitted += distractors
            random.shuffle(splitted)  # shuffle the list
            pre_text = " ".join(splitted)
        else:
            pre_text = pre_text
    else:
        v = random.random()
        if v < 0.1:
            splitted = text.split()
            sampling_weights = [len(w) ** 1.2 for w in splitted]
            sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
            i = random.randint(1, min(len(splitted), 20))
            splitted = list(np.random.choice(splitted, i, p=sampling_weights))
            pre_text = " ".join(splitted)
            num_distractors = random.randint(0, 70)
            distractors = random.sample(rare_words_list, num_distractors)
            splitted += distractors
            random.shuffle(splitted)  # shuffle the list
        elif v < 0.2:
            # full distractors
            num_distractors = random.randint(5, 100)
            distractors = random.sample(rare_words_list, num_distractors)
            pre_text = " ".join(distractors)

        elif v < 0.3:
            pre_text = get_substring(text, min_len=15, max_len=150)
        else:
            pre_text = pre_text

    return pre_text


def get_pre_text_with_context_list2(
    text: str,
    pre_text: str,
    context_list: str,
    rare_words_list: List[str] = None,
) -> str:
    # Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
    # By a small proportion of time, use the substring of ref_text as pre_text

    if context_list != "" and context_list is not None:
        v = random.random()
        if v < 0.4:
            # correct + distractors
            # sample distractors
            num_distractors = random.randint(50, 100)
            distractors = random.sample(rare_words_list, num_distractors)
            # sample correct
            correct = context_list.split()
            i = random.randint(1, len(correct))
            correct = random.sample(correct, i)
            # combine correct and distractors
            pre_text = distractors + correct
            random.shuffle(pre_text)
            pre_text = " ".join(pre_text)
        elif v < 0.55:
            splitted = text.split()
            sampling_weights = [len(w) ** 1.2 for w in splitted]
            sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
            i = random.randint(1, min(len(splitted), 20))
            splitted = list(np.random.choice(splitted, i, p=sampling_weights))
            num_distractors = random.randint(50, 100)
            distractors = random.sample(rare_words_list, num_distractors)
            splitted += distractors
            random.shuffle(splitted)  # shuffle the list
            pre_text = " ".join(splitted)
        else:
            pre_text = pre_text
    else:
        v = random.random()
        if v < 0.3:
            splitted = text.split()
            sampling_weights = [len(w) ** 1.2 for w in splitted]
            sampling_weights = [p / sum(sampling_weights) for p in sampling_weights]
            i = random.randint(1, min(len(splitted), 20))
            splitted = list(np.random.choice(splitted, i, p=sampling_weights))
            pre_text = " ".join(splitted)
            num_distractors = random.randint(50, 100)
            distractors = random.sample(rare_words_list, num_distractors)
            splitted += distractors
            random.shuffle(splitted)  # shuffle the list
        elif v < 0.4:
            # full distractors
            num_distractors = random.randint(5, 100)
            distractors = random.sample(rare_words_list, num_distractors)
            pre_text = " ".join(distractors)

        elif v < 0.6:
            pre_text = get_substring(text, min_len=15, max_len=150)
        else:
            pre_text = pre_text

    return pre_text


def joint_triplet_text_sampling(
    texts: List[str],
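To make the new context-list sampling path concrete, here is a small usage sketch; every string and the dummy distractor pool below are invented for illustration, and only the call signature comes from the code above:

# --- illustrative sketch, not part of this commit ---
from dataset import triplet_text_sampling_with_context_list

rare_words = [f"rareword{i}" for i in range(200)]  # dummy distractor pool

sample = triplet_text_sampling_with_context_list(
    texts=["Hello, World!", "HELLO WORLD"],  # [ground truth, recognized]
    pre_texts=["Some preceding text, for style.", "SOME PRECEDING TEXT FOR STYLE"],
    context_list="hello world",  # biasing words, lower-cased inside the function
    rare_word_list=rare_words,
)
# sample is a dict with keys "text", "pre_text", "style_text", "transform_ids".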
@@ -521,7 +757,6 @@ def naive_triplet_text_sampling(
    return {
        "text": train_text_normalization(texts[0]),
        "pre_text": train_text_normalization(pre_texts[0]),
        #"pre_text": "",
        #"pre_text": "HELLO IT THIS ENOUGH FOR THE MODEL TO LEARN THE STYLE",
        #"pre_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
        #"pre_text": "Hello, my friend. "*50,
@@ -529,7 +764,7 @@ def naive_triplet_text_sampling(
        "style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related. What do you think?",
        #"style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
        #"style_text": "Mixed-case English transcription, with punctuation.",
        #"style_text": train_text_normalization(get_substring(pre_texts[0], min_len=min_len_style)),
        # "style_text": train_text_normalization(get_substring(pre_texts[0], min_len=min_len_style)),
        "transform_ids": 0,
    }
@ -1,824 +0,0 @@
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from lhotse import validate
|
||||
from lhotse.cut import CutSet
|
||||
from lhotse.dataset import K2SpeechRecognitionDataset
|
||||
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
|
||||
from lhotse.utils import compute_num_frames, ifnone
|
||||
from torch.utils.data.dataloader import DataLoader, default_collate
|
||||
|
||||
from text_normalization import (
|
||||
remove_non_alphabetic,
|
||||
upper_only_alpha,
|
||||
lower_only_alpha,
|
||||
upper_all_char,
|
||||
lower_all_char,
|
||||
train_text_normalization,
|
||||
)
|
||||
|
||||
|
||||
class PromptASRDataset(torch.utils.data.Dataset):
|
||||
"""This is a dataset for Prompt ASR. It supports the following features:
|
||||
1. Select a tuple of (text, pre_text, style_text) randomly from a
|
||||
list of texts as supervisions.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
return_cuts: bool = False,
|
||||
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
|
||||
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
|
||||
input_strategy: BatchIO = PrecomputedFeatures(),
|
||||
text_sampling_func: Optional[Callable[[List[str]], str]] = None,
|
||||
rare_word_list: Optional[List[str]] = None
|
||||
):
|
||||
"""
|
||||
Icefall ASR IterableDataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py
|
||||
for more details.
|
||||
|
||||
:param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
|
||||
objects used to create that batch.
|
||||
:param cut_transforms: A list of transforms to be applied on each sampled batch,
|
||||
before converting cuts to an input representation (audio/features).
|
||||
Examples: cut concatenation, noise cuts mixing, etc.
|
||||
:param input_transforms: A list of transforms to be applied on each sampled batch,
|
||||
after the cuts are converted to audio/features.
|
||||
Examples: normalization, SpecAugment, etc.
|
||||
:param input_strategy: Converts cuts into a collated batch of audio/features.
|
||||
By default, reads pre-computed features from disk.
|
||||
:param text_sampling_func: Sampling a text as transcription from a list of texts.
|
||||
"""
|
||||
super().__init__()
|
||||
# Initialize the fields
|
||||
self.return_cuts = return_cuts
|
||||
self.cut_transforms = ifnone(cut_transforms, [])
|
||||
self.input_transforms = ifnone(input_transforms, [])
|
||||
self.input_strategy = input_strategy
|
||||
|
||||
# a text sampling function
|
||||
self.text_sampling_func = text_sampling_func
|
||||
self.rare_word_list = rare_word_list
|
||||
|
||||
def __getitem__(
|
||||
self, cuts: CutSet
|
||||
) -> Dict[str, Union[torch.Tensor, List[str]]]:
|
||||
"""
|
||||
Return a new batch, with the batch size automatically determined using the constraints
|
||||
of max_frames and max_cuts.
|
||||
"""
|
||||
validate_for_asr(cuts)
|
||||
|
||||
# Sort the cuts by duration so that the first one determines the batch time dimensions.
|
||||
cuts = cuts.sort_by_duration(ascending=False)
|
||||
|
||||
# Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
|
||||
# the supervision boundaries.
|
||||
for tnfm in self.cut_transforms:
|
||||
cuts = tnfm(cuts)
|
||||
|
||||
# Sort the cuts again after transforms
|
||||
cuts = cuts.sort_by_duration(ascending=False)
|
||||
|
||||
# Get a tensor with batched feature matrices, shape (B, T, F)
|
||||
# Collation performs auto-padding, if necessary.
|
||||
input_tpl = self.input_strategy(cuts)
|
||||
if len(input_tpl) == 3:
|
||||
# An input strategy with fault tolerant audio reading mode.
|
||||
# "cuts" may be a subset of the original "cuts" variable,
|
||||
# that only has cuts for which we succesfully read the audio.
|
||||
inputs, _, cuts = input_tpl
|
||||
else:
|
||||
inputs, _ = input_tpl
|
||||
|
||||
# Get a dict of tensors that encode the positional information about supervisions
|
||||
# in the batch of feature matrices. The tensors are named "sequence_idx",
|
||||
# "start_frame/sample" and "num_frames/samples".
|
||||
supervision_intervals = self.input_strategy.supervision_intervals(cuts)
|
||||
|
||||
# Apply all available transforms on the inputs, i.e. either audio or features.
|
||||
# This could be feature extraction, global MVN, SpecAugment, etc.
|
||||
segments = torch.stack(list(supervision_intervals.values()), dim=1)
|
||||
for tnfm in self.input_transforms:
|
||||
inputs = tnfm(inputs, supervision_segments=segments)
|
||||
|
||||
batch = {
|
||||
"inputs": inputs,
|
||||
"supervisions": default_collate(
|
||||
[
|
||||
self.text_sampling_func(
|
||||
texts=supervision.texts,
|
||||
pre_texts=supervision.pre_texts,
|
||||
context_list=supervision.context_list if "context_list" in supervision.custom else None,
|
||||
rare_word_list=self.rare_word_list,
|
||||
)
|
||||
if self.text_sampling_func is not None
|
||||
else {
|
||||
"text": train_text_normalization(supervision.texts[0]),
|
||||
"pre_text": train_text_normalization(
|
||||
supervision.pre_texts[0]
|
||||
),
|
||||
"style_text": train_text_normalization(
|
||||
supervision.pre_texts[0]
|
||||
),
|
||||
"transform_ids": 0,
|
||||
}
|
||||
for sequence_idx, cut in enumerate(cuts)
|
||||
for supervision in cut.supervisions
|
||||
]
|
||||
),
|
||||
}
|
||||
# Update the 'supervisions' field with sequence_idx and start/num frames/samples
|
||||
batch["supervisions"].update(supervision_intervals)
|
||||
if self.return_cuts:
|
||||
batch["supervisions"]["cut"] = [
|
||||
cut for cut in cuts for sup in cut.supervisions
|
||||
]
|
||||
|
||||
has_word_alignments = all(
|
||||
s.alignment is not None and "word" in s.alignment
|
||||
for c in cuts
|
||||
for s in c.supervisions
|
||||
)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def validate_for_asr(cuts: CutSet) -> None:
|
||||
validate(cuts)
|
||||
tol = 2e-3 # 1ms
|
||||
for cut in cuts:
|
||||
for supervision in cut.supervisions:
|
||||
assert supervision.start >= -tol, (
|
||||
f"Supervisions starting before the cut are not supported for ASR"
|
||||
f" (sup id: {supervision.id}, cut id: {cut.id})"
|
||||
)
|
||||
|
||||
# Supervision start time is relative to Cut ...
|
||||
# https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
|
||||
#
|
||||
# 'supervision.end' is end of supervision inside the Cut
|
||||
assert supervision.end <= cut.duration + tol, (
|
||||
f"Supervisions ending after the cut "
|
||||
f"are not supported for ASR"
|
||||
f" (sup id: {supervision.id}, cut id: {cut.id})"
|
||||
)
|
||||
|
||||
|
||||
def get_substring(s: str, min_len: int = 40, max_len: int = 250) -> str:
|
||||
"""A helper function that generates a random substring from a given string
|
||||
|
||||
Args:
|
||||
s (str): Input string
|
||||
|
||||
Returns:
|
||||
str: Returned substring
|
||||
"""
|
||||
min_len = min(len(s), min_len)
|
||||
|
||||
start = random.randint(0, len(s) - min_len)
|
||||
end = min(start + max_len, random.randint(start + min_len, len(s)))
|
||||
|
||||
return s[start:end]
|
||||
|
||||
|
||||
def triplet_text_sampling(
|
||||
texts: List[str],
|
||||
pre_texts: List[str],
|
||||
context_list: Optional[str] = None,
|
||||
rare_word_list: Optional[List[str]] = None,
|
||||
transforms: Optional[List[Callable[[str], str]]] = None,
|
||||
min_len_style: Optional[int] = 80,
|
||||
) -> Dict[str, str]:
|
||||
"""This function generates a triplet of
|
||||
(pre_text, style_text, ref_text). The style of style_text and ref_text
|
||||
should always match, whereas the style of pre_text is arbitrary.
|
||||
Suppose we have 3 different transforms A,B,C, and the groundtruth
|
||||
text and pre_text are referred to as text and pre_text.
|
||||
The following three tuples are all valid:
|
||||
|
||||
(A(pre_text), B(style_text), B(text))
|
||||
(A(pre_text), C(style_text), C(text))
|
||||
(A(pre_text), A(style_text), A(text))
|
||||
...
|
||||
|
||||
If transforms is not given, the following pre-defined transforms
|
||||
are available:
|
||||
0: original (normal case, with punc)
|
||||
1: recog (upper, no punc)
|
||||
2: upper_only_alpha (upper, no punc)
|
||||
3: lower_only_alpha (lower, no punc)
|
||||
4: upper_all (upper, with punc)
|
||||
5: lower_all (lower, with punc)
|
||||
|
||||
When the transform of text and pre_text match, we can use the whole
|
||||
pre_text as the prompt text.
|
||||
|
||||
Args:
|
||||
texts (List[str]):
|
||||
A list of ref_texts whose first item is the ground truth
|
||||
text from books.
|
||||
pre_texts (List[str]):
|
||||
A list of pre_texts, whose first item is the groundtruth
|
||||
pre_text from books.
|
||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
||||
|
||||
Returns:
|
||||
str: A dictionary
|
||||
"""
|
||||
# import pdb; pdb.set_trace()
|
||||
assert len(texts) == len(pre_texts)
|
||||
assert len(texts) == 2
|
||||
|
||||
# we assume the first item to be ground truth
|
||||
gt_text = texts[0]
|
||||
gt_pre_text = pre_texts[0]
|
||||
|
||||
if transforms is None:
|
||||
transforms = [
|
||||
lambda x: x, # return it self
|
||||
upper_only_alpha,
|
||||
lower_only_alpha,
|
||||
lower_all_char,
|
||||
]
|
||||
|
||||
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
|
||||
sampling_weight = [0.7, 0.3, 0.0, 0.0]
|
||||
|
||||
total_transforms = len(transforms) # do not use the recognized trans
|
||||
|
||||
# Select a transformation randomly
|
||||
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
|
||||
|
||||
# get the normalized text and pre_text
|
||||
text = transforms[i_text](gt_text)
|
||||
pre_text = transforms[i_pre_text](gt_pre_text)
|
||||
|
||||
if i_text == i_pre_text:
|
||||
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
|
||||
else:
|
||||
# get the pre_text of same style as text
|
||||
# For now, do not do transform to the style text
|
||||
style_text = gt_pre_text
|
||||
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
|
||||
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
|
||||
|
||||
return {
|
||||
"text": train_text_normalization(text),
|
||||
"pre_text": train_text_normalization(pre_text),
|
||||
"style_text": train_text_normalization(style_text),
|
||||
"transform_ids": i_text,
|
||||
}
|
||||
|
||||
|
||||
|
||||
def triplet_text_sampling2(
|
||||
texts: List[str],
|
||||
pre_texts: List[str],
|
||||
context_list: Optional[str] = None,
|
||||
rare_word_list: Optional[List[str]] = None,
|
||||
transforms: Optional[List[Callable[[str], str]]] = None,
|
||||
min_len_style: Optional[int] = 80,
|
||||
) -> Dict[str, str]:
|
||||
"""This function generates a triplet of
|
||||
(pre_text, style_text, ref_text). The style of style_text and ref_text
|
||||
should always match, whereas the style of pre_text is arbitrary.
|
||||
Suppose we have 3 different transforms A,B,C, and the groundtruth
|
||||
text and pre_text are referred to as text and pre_text.
|
||||
The following three tuples are all valid:
|
||||
|
||||
(A(pre_text), B(style_text), B(text))
|
||||
(A(pre_text), C(style_text), C(text))
|
||||
(A(pre_text), A(style_text), A(text))
|
||||
...
|
||||
|
||||
If transforms is not given, the following pre-defined transforms
|
||||
are available:
|
||||
0: original (normal case, with punc)
|
||||
1: recog (upper, no punc)
|
||||
2: upper_only_alpha (upper, no punc)
|
||||
3: lower_only_alpha (lower, no punc)
|
||||
4: upper_all (upper, with punc)
|
||||
5: lower_all (lower, with punc)
|
||||
|
||||
When the transform of text and pre_text match, we can use the whole
|
||||
pre_text as the prompt text.
|
||||
|
||||
Args:
|
||||
texts (List[str]):
|
||||
A list of ref_texts whose first item is the ground truth
|
||||
text from books.
|
||||
pre_texts (List[str]):
|
||||
A list of pre_texts, whose first item is the groundtruth
|
||||
pre_text from books.
|
||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
||||
|
||||
Returns:
|
||||
str: A dictionary
|
||||
"""
|
||||
# import pdb; pdb.set_trace()
|
||||
assert len(texts) == len(pre_texts)
|
||||
assert len(texts) == 2
|
||||
|
||||
# we assume the first item to be ground truth
|
||||
gt_text = texts[0]
|
||||
gt_pre_text = pre_texts[0]
|
||||
|
||||
if transforms is None:
|
||||
transforms = [
|
||||
lambda x: x, # return it self
|
||||
upper_only_alpha,
|
||||
lower_only_alpha,
|
||||
lower_all_char,
|
||||
]
|
||||
|
||||
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
|
||||
sampling_weight = [0.7, 0.3, 0.0, 0.0]
|
||||
|
||||
total_transforms = len(transforms) # do not use the recognized trans
|
||||
|
||||
# Select a transformation randomly
|
||||
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
|
||||
|
||||
# get the normalized text and pre_text
|
||||
text = transforms[i_text](gt_text)
|
||||
pre_text = transforms[i_pre_text](gt_pre_text)
|
||||
|
||||
if i_text == i_pre_text:
|
||||
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
|
||||
else:
|
||||
# get the pre_text of same style as text
|
||||
# For now, do not do transform to the style text
|
||||
style_text = gt_pre_text
|
||||
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
|
||||
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
|
||||
|
||||
return {
|
||||
"text": train_text_normalization(text),
|
||||
"pre_text": train_text_normalization(pre_text),
|
||||
"style_text": train_text_normalization(style_text),
|
||||
"transform_ids": i_text,
|
||||
}
|
||||
|
||||
|
||||
def triplet_text_sampling_with_context_list(
|
||||
texts: List[str],
|
||||
pre_texts: List[str],
|
||||
context_list: str,
|
||||
rare_word_list: List[str],
|
||||
transforms: Optional[List[Callable[[str], str]]] = None,
|
||||
min_len_style: Optional[int] = 80,
|
||||
) -> Dict[str, str]:
|
||||
"""This function generates a triplet of
|
||||
(pre_text, style_text, ref_text). The style of style_text and ref_text
|
||||
should always match, whereas the style of pre_text is arbitrary.
|
||||
Suppose we have 3 different transforms A,B,C, and the groundtruth
|
||||
text and pre_text are referred to as text and pre_text.
|
||||
The following three tuples are all valid:
|
||||
|
||||
(A(pre_text), B(style_text), B(text))
|
||||
(A(pre_text), C(style_text), C(text))
|
||||
(A(pre_text), A(style_text), A(text))
|
||||
...
|
||||
|
||||
If transforms is not given, the following pre-defined transforms
|
||||
are available:
|
||||
0: original (normal case, with punc)
|
||||
1: recog (upper, no punc)
|
||||
2: upper_only_alpha (upper, no punc)
|
||||
3: lower_only_alpha (lower, no punc)
|
||||
4: upper_all (upper, with punc)
|
||||
5: lower_all (lower, with punc)
|
||||
|
||||
When the transform of text and pre_text match, we can use the whole
|
||||
pre_text as the prompt text.
|
||||
|
||||
Args:
|
||||
texts (List[str]):
|
||||
A list of ref_texts whose first item is the ground truth
|
||||
text from books.
|
||||
pre_texts (List[str]):
|
||||
A list of pre_texts, whose first item is the groundtruth
|
||||
pre_text from books.
|
||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
||||
|
||||
Returns:
|
||||
str: A dictionary
|
||||
"""
|
||||
# import pdb; pdb.set_trace()
|
||||
assert len(texts) == len(pre_texts)
|
||||
assert len(texts) == 2
|
||||
|
||||
if context_list is not None:
|
||||
context_list = context_list.lower()
|
||||
|
||||
# we assume the first item to be ground truth
|
||||
gt_text = texts[0]
|
||||
gt_pre_text = pre_texts[0]
|
||||
|
||||
if transforms is None:
|
||||
transforms = [
|
||||
lambda x: x, # return it self
|
||||
upper_only_alpha,
|
||||
lower_only_alpha,
|
||||
lower_all_char,
|
||||
]
|
||||
|
||||
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
|
||||
sampling_weight = [0.7, 0.3, 0.0, 0.0]
|
||||
|
||||
total_transforms = len(transforms) # do not use the recognized trans
|
||||
|
||||
# Select a transformation randomly
|
||||
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
|
||||
|
||||
# get the normalized text and pre_text
|
||||
text = transforms[i_text](gt_text)
|
||||
pre_text = get_pre_text_with_context_list2(
|
||||
text=gt_pre_text,
|
||||
pre_text=gt_pre_text,
|
||||
context_list=context_list,
|
||||
rare_words_list=rare_word_list,
|
||||
)
|
||||
pre_text = transforms[i_pre_text](pre_text)
|
||||
|
||||
if i_text == i_pre_text:
|
||||
style_text = gt_pre_text
|
||||
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
|
||||
else:
|
||||
# get the pre_text of same style as text
|
||||
# For now, do not do transform to the style text
|
||||
style_text = gt_pre_text
|
||||
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
|
||||
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
|
||||
|
||||
return {
|
||||
"text": train_text_normalization(text),
|
||||
"pre_text": train_text_normalization(pre_text),
|
||||
"style_text": train_text_normalization(style_text),
|
||||
"transform_ids": i_text,
|
||||
}
|
||||
|
||||
|
||||
def get_pre_text_with_context_list(
|
||||
text: str,
|
||||
pre_text: str,
|
||||
context_list: str,
|
||||
rare_words_list: List[str] = None,
|
||||
) -> str:
|
||||
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
|
||||
# By a small proportion of time, use the substring of ref_text as pre_text
|
||||
|
||||
if context_list != "" and context_list is not None:
|
||||
v = random.random()
|
||||
if v < 0.5:
|
||||
# correct + distractors
|
||||
# sample distractors
|
||||
num_distractors = random.randint(0, 50)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
# sample correct
|
||||
correct = context_list.split()
|
||||
i = random.randint(1, len(correct))
|
||||
correct = random.sample(correct, i)
|
||||
# combine correct and distractors
|
||||
pre_text = distractors + correct
|
||||
random.shuffle(pre_text)
|
||||
pre_text = " ".join(pre_text)
|
||||
elif v < 0.7:
|
||||
splitted = text.split()
|
||||
sampling_weights = [len(w)**1.2 for w in splitted]
|
||||
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
|
||||
i = random.randint(1, min(len(splitted), 20))
|
||||
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
||||
num_distractors = random.randint(0,70)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
splitted += distractors
|
||||
random.shuffle(splitted) # shuffle the list
|
||||
pre_text = " ".join(splitted)
|
||||
else:
|
||||
pre_text = pre_text
|
||||
else:
|
||||
v = random.random()
|
||||
if v < 0.1:
|
||||
splitted = text.split()
|
||||
sampling_weights = [len(w)**1.2 for w in splitted]
|
||||
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
|
||||
i = random.randint(1, min(len(splitted), 20))
|
||||
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
||||
pre_text = " ".join(splitted)
|
||||
num_distractors = random.randint(0,70)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
splitted += distractors
|
||||
random.shuffle(splitted) # shuffle the list
|
||||
elif v < 0.2:
|
||||
# full distractors
|
||||
num_distractors = random.randint(5, 100)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
pre_text = " ".join(distractors)
|
||||
|
||||
elif v < 0.3:
|
||||
pre_text = get_substring(text, min_len=15, max_len=150)
|
||||
else:
|
||||
pre_text = pre_text
|
||||
|
||||
return pre_text
|
||||
|
||||
|
||||
|
||||
def get_pre_text_with_context_list2(
|
||||
text: str,
|
||||
pre_text: str,
|
||||
context_list: str,
|
||||
rare_words_list: List[str] = None,
|
||||
) -> str:
|
||||
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
|
||||
# By a small proportion of time, use the substring of ref_text as pre_text
|
||||
|
||||
if context_list != "" and context_list is not None:
|
||||
v = random.random()
|
||||
if v < 0.4:
|
||||
# correct + distractors
|
||||
# sample distractors
|
||||
num_distractors = random.randint(50, 100)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
# sample correct
|
||||
correct = context_list.split()
|
||||
i = random.randint(1, len(correct))
|
||||
correct = random.sample(correct, i)
|
||||
# combine correct and distractors
|
||||
pre_text = distractors + correct
|
||||
random.shuffle(pre_text)
|
||||
pre_text = " ".join(pre_text)
|
||||
elif v < 0.55:
|
||||
splitted = text.split()
|
||||
sampling_weights = [len(w)**1.2 for w in splitted]
|
||||
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
|
||||
i = random.randint(1, min(len(splitted), 20))
|
||||
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
||||
num_distractors = random.randint(50,100)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
splitted += distractors
|
||||
random.shuffle(splitted) # shuffle the list
|
||||
pre_text = " ".join(splitted)
|
||||
else:
|
||||
pre_text = pre_text
|
||||
else:
|
||||
v = random.random()
|
||||
if v < 0.3:
|
||||
splitted = text.split()
|
||||
sampling_weights = [len(w)**1.2 for w in splitted]
|
||||
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
|
||||
i = random.randint(1, min(len(splitted), 20))
|
||||
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
|
||||
pre_text = " ".join(splitted)
|
||||
num_distractors = random.randint(50,100)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
splitted += distractors
|
||||
random.shuffle(splitted) # shuffle the list
|
||||
elif v < 0.4:
|
||||
# full distractors
|
||||
num_distractors = random.randint(5, 100)
|
||||
distractors = random.sample(rare_words_list, num_distractors)
|
||||
pre_text = " ".join(distractors)
|
||||
|
||||
elif v < 0.6:
|
||||
pre_text = get_substring(text, min_len=15, max_len=150)
|
||||
else:
|
||||
pre_text = pre_text
|
||||
|
||||
return pre_text
|
||||
|
||||
|
||||
|
||||
def joint_triplet_text_sampling(
|
||||
texts: List[str],
|
||||
pre_texts: List[str],
|
||||
transforms: Optional[List[Callable[[str], str]]] = None,
|
||||
min_len_style: Optional[int] = 80,
|
||||
) -> Dict[str, str]:
|
||||
"""This function generates a triplet of
|
||||
(pre_text, style_text, ref_text). The style of pre_text, style_text
|
||||
and ref_text should **always** match.
|
||||
Suppose we have 3 different transforms A,B,C, and the groundtruth
|
||||
text and pre_text are referred to as text and pre_text.
|
||||
The following three tuples are all valid:
|
||||
|
||||
(A(pre_text), A(style_text), A(text))
|
||||
(B(pre_text), B(style_text), B(text))
|
||||
(C(pre_text), C(style_text), C(text))
|
||||
...
|
||||
|
||||
If transforms is not given, the following pre-defined transforms
|
||||
are available:
|
||||
0: original (normal case, with punc)
|
||||
1: recog (upper, no punc)
|
||||
2: upper_only_alpha (upper, no punc)
|
||||
3: lower_only_alpha (lower, no punc)
|
||||
4: upper_all (upper, with punc)
|
||||
5: lower_all (lower, with punc)
|
||||
|
||||
When the transform of text and pre_text match, we can use the whole
|
||||
pre_text as the prompt text.
|
||||
|
||||
Args:
|
||||
texts (List[str]):
|
||||
A list of ref_texts whose first item is the ground truth
|
||||
text from books.
|
||||
pre_texts (List[str]):
|
||||
A list of pre_texts, whose first item is the groundtruth
|
||||
pre_text from books.
|
||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
||||
|
||||
Returns:
|
||||
str: A dictionary
|
||||
"""
|
||||
# import pdb; pdb.set_trace()
|
||||
assert len(texts) == len(pre_texts)
|
||||
assert len(texts) == 2
|
||||
|
||||
# we assume the first item to be ground truth
|
||||
gt_text = texts[0]
|
||||
gt_pre_text = pre_texts[0]
|
||||
|
||||
if transforms is None:
|
||||
transforms = [
|
||||
lambda x: x, # return it self
|
||||
upper_only_alpha,
|
||||
lower_only_alpha,
|
||||
lower_all_char,
|
||||
]
|
||||
|
||||
sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
|
||||
|
||||
total_transforms = len(transforms) # do not use the recognized trans
|
||||
|
||||
# Select a transformation randomly
|
||||
i_text = np.random.choice(total_transforms, 1, p=sampling_weight)[0]
|
||||
|
||||
# get the normalized text and pre_text
|
||||
text = transforms[i_text](gt_text)
|
||||
pre_text = transforms[i_text](gt_pre_text)
|
||||
|
||||
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
|
||||
|
||||
return {
|
||||
"text": train_text_normalization(text),
|
||||
"pre_text": train_text_normalization(pre_text),
|
||||
"style_text": train_text_normalization(style_text),
|
||||
"transform_ids": i_text,
|
||||
}
|
||||
|
||||
|
||||
def triplet_style_text_sampling(
|
||||
texts: List[str],
|
||||
pre_texts: List[str],
|
||||
transforms: Optional[List[Callable[[str], str]]] = None,
|
||||
min_len_style: Optional[int] = 80,
|
||||
) -> Dict[str, str]:
|
||||
"""This function generates a triplet of
|
||||
(pre_text, style_text, ref_text). The style of style_text and ref_text
|
||||
should always match, whereas the style of pre_text is fixed to mixed-trans.
|
||||
Suppose we have 3 different transforms A,B,C, and the groundtruth
|
||||
text and pre_text are referred to as text and pre_text.
|
||||
The following three tuples are all valid:
|
||||
|
||||
(gt_pre_text, B(style_text), B(text))
|
||||
(gt_pre_text, C(style_text), C(text))
|
||||
(gt_pre_text, A(style_text), A(text))
|
||||
...
|
||||
|
||||
If transforms is not given, the following pre-defined transforms
|
||||
are available:
|
||||
0: original (normal case, with punc)
|
||||
1: recog (upper, no punc)
|
||||
2: upper_only_alpha (upper, no punc)
|
||||
3: lower_only_alpha (lower, no punc)
|
||||
4: upper_all (upper, with punc)
|
||||
5: lower_all (lower, with punc)
|
||||
|
||||
Args:
|
||||
texts (List[str]):
|
||||
A list of ref_texts whose first item is the ground truth
|
||||
text from books.
|
||||
pre_texts (List[str]):
|
||||
A list of pre_texts, whose first item is the groundtruth
|
||||
pre_text from books.
|
||||
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
|
||||
|
||||
Returns:
|
||||
str: A dictionary
|
||||
"""
|
||||
# import pdb; pdb.set_trace()
|
||||
assert len(texts) == len(pre_texts)
|
||||
assert len(texts) == 2
|
||||
|
||||
# we assume the first item to be ground truth
|
||||
gt_text = texts[0]
|
||||
gt_pre_text = pre_texts[0]
|
||||
|
||||
if transforms is None:
|
||||
transforms = [
|
||||
lambda x: x, # return it self
|
||||
upper_only_alpha,
|
||||
lower_only_alpha,
|
||||
lower_all_char,
|
||||
]
|
||||
|
||||
sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
|
||||
total_transforms = len(transforms) # do not use the recognized trans
|
||||
|
||||
# Select a transformation randomly
|
||||
t_id = np.random.choice(total_transforms, 1, p=sampling_weight)[0]
|
||||
|
||||
# get the normalized text
|
||||
text = transforms[t_id](gt_text)
|
||||
# get the original un-processed style text
|
||||
style_text = get_substring(gt_pre_text, min_len=min_len_style, max_len=150)
|
||||
|
||||
return {
|
||||
"text": train_text_normalization(text),
|
||||
"pre_text": train_text_normalization(gt_pre_text),
|
||||
"style_text": train_text_normalization(style_text),
|
||||
"transform_ids": t_id,
|
||||
}
|
||||
|
||||
|
||||
def naive_triplet_text_sampling(
|
||||
texts: List[str],
|
||||
pre_texts: List[str],
|
||||
context_list: str = None,
|
||||
rare_word_list: List[str] = None,
|
||||
min_len_style: Optional[int] = 120,
|
||||
):
|
||||
|
||||
return {
|
||||
"text": train_text_normalization(texts[0]),
|
||||
"pre_text": train_text_normalization(pre_texts[0]),
|
||||
#"pre_text": "HELLO IT THIS ENOUGH FOR THE MODEL TO LEARN THE STYLE",
|
||||
#"pre_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
|
||||
#"pre_text": "Hello, my friend. "*50,
|
||||
#"style_text": train_text_normalization(pre_texts[0][:200]),
|
||||
"style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related. What do you think?",
|
||||
#"style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
|
||||
#"style_text": "Mixed-case English transcription, with punctuation.",
|
||||
# "style_text": train_text_normalization(get_substring(pre_texts[0], min_len=min_len_style)),
|
||||
"transform_ids": 0,
|
||||
}
|
||||
|
||||
|
||||
def random_shuffle_subset(
|
||||
data: List[str],
|
||||
p: float = 0.2,
|
||||
p_mask: float = 0.05,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Randomly shuffle the subset by probability p, which means that p% of the samples
|
||||
in the original batch are shuffled, the others are kept in the original order.
|
||||
|
||||
With a probability of p_mask, replace the original string with an empty string.
|
||||
|
||||
"""
|
||||
|
||||
num_to_shuffle = int(len(data) * p)
|
||||
id_to_shuffle = np.random.choice(len(data), num_to_shuffle, replace=False)
|
||||
item_to_shuffle = [data[id] for id in id_to_shuffle]
|
||||
random.shuffle(item_to_shuffle)
|
||||
|
||||
for id, item in zip(id_to_shuffle, item_to_shuffle):
|
||||
data[id] = item
|
||||
|
||||
# Randomly mask a proportion of the data to empty string
|
||||
if p_mask > 0:
|
||||
for i in range(len(data)):
|
||||
if random.random() < p_mask:
|
||||
data[i] = ""
|
||||
|
||||
return data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
texts = [
|
||||
"AA, BB, cC, dD!",
|
||||
"AA BB CC DD",
|
||||
]
|
||||
|
||||
pre_texts = [
|
||||
"EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg?",
|
||||
"EE FF GG EE FF GG EE FF GG EE FF GG EE FF GG",
|
||||
]
|
||||
# for i in range(10):
|
||||
# print(f"Run: {i}")
|
||||
# print(triplet_text_sampling(texts, pre_texts))
|
||||
|
||||
import time
|
||||
start = time.time()
|
||||
data = [str(i) for i in range(30)]
|
||||
random.shuffle(data)
|
||||
print(data)
|
||||
for i in range(1):
|
||||
shuffled = random_shuffle_subset(data=data, p=0.4, p_mask=0.1)
|
||||
print(shuffled)
|
||||
print((time.time() - start)/100)
|
@ -1,882 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao,
|
||||
# Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) modified beam search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Callable
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriHeavyAsrDataModule
|
||||
from beam_search import (
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from dataset import naive_triplet_text_sampling, random_shuffle_subset
|
||||
from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha, upper_all_char, lower_all_char, lower_only_alpha
|
||||
from train import (
|
||||
add_model_arguments,
|
||||
get_params,
|
||||
get_transducer_model,
|
||||
_encode_texts_as_bytes,
|
||||
)
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
- modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion
|
||||
- modified_beam_search_LODR
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-pre-text",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Use pre-text is available during decoding",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-style-prompt",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Use style prompt when evaluation"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--post-normalization",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--compute-CER",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Reports CER. By default, only reports WER",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--style-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of style prompt, i.e style_text"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pre-text-transform",
|
||||
type=str,
|
||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
||||
default="mixed-punc",
|
||||
help="The style of content prompt, i.e pre_text"
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
||||
"""Apply transform to a list of text. By default, the text are in
|
||||
ground truth format, i.e mixed-punc.
|
||||
|
||||
Args:
|
||||
text (List[str]): Input text string
|
||||
transform (str): Transform to be applied
|
||||
|
||||
Returns:
|
||||
List[str]: _description_
|
||||
"""
|
||||
if transform == "mixed-punc":
|
||||
return text
|
||||
elif transform == "upper-no-punc":
|
||||
return [upper_only_alpha(s) for s in text]
|
||||
elif transform == "lower-no-punc":
|
||||
return [lower_only_alpha(s) for s in text]
|
||||
elif transform == "lower-punc":
|
||||
return [lower_all_char(s) for s in text]
|
||||
else:
|
||||
raise NotImplementedError(f"Unseen transform: {transform}")
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
LM:
|
||||
A neural net LM for shallow fusion. Only used when `--use-shallow-fusion`
|
||||
set to true.
|
||||
ngram_lm:
|
||||
A ngram lm. Used in LODR decoding.
|
||||
ngram_lm_scale:
|
||||
The scale of the ngram language model.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
batch_size = feature.size(0)
|
||||
|
||||
if "pre_text" in batch["supervisions"] and params.use_pre_text:
|
||||
pre_texts = batch["supervisions"]["pre_text"]
|
||||
else:
|
||||
pre_texts = ["" for _ in range(batch_size)]
|
||||
|
||||
if params.use_style_prompt:
|
||||
style_texts = batch["supervisions"]["style_text"]
|
||||
else:
|
||||
style_texts = ["" for _ in range(batch_size)] # use empty string
|
||||
|
||||
# Get the text embedding input
|
||||
if params.use_pre_text or params.use_style_prompt:
|
||||
|
||||
# apply style transform to the pre_text and style_text
|
||||
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
|
||||
#pre_texts = [remove_non_alphabetic(t) for t in pre_texts]
|
||||
#pre_texts = random_shuffle_subset(pre_texts, p=0.98, p_mask=0.0)
|
||||
if params.use_style_prompt:
|
||||
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
pre_texts, pre_texts_lens, style_lens = _encode_texts_as_bytes(
|
||||
pre_texts,
|
||||
style_texts,
|
||||
device,
|
||||
max_len=1200
|
||||
) # note that the output pre_texts include style_text and actual pre_text
|
||||
|
||||
memory, memory_key_padding_mask = model.encode_text(
|
||||
text=pre_texts,
|
||||
text_lens=pre_texts_lens,
|
||||
style_lens=style_lens,
|
||||
) # (T,B,C)
|
||||
else:
|
||||
memory = None
|
||||
memory_key_padding_mask = None
|
||||
|
||||
# Get the transducer encoder output
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
encoder_out, encoder_out_lens = model.encode_audio(
|
||||
feature=feature,
|
||||
feature_lens=feature_lens,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
|
||||
hyps = []
|
||||
|
||||
if (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
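# Sketch of the structure returned above (token values are purely illustrative):
#   {"greedy_search": [["HELLO", "WORLD"], ["GOOD", "MORNING"]]}
# or, when modified_beam_search is used with a beam size of 4:
#   {"beam_size_4": [["HELLO", "WORLD"], ["GOOD", "MORNING"]]}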
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains three elements:
|
||||
the cut id, the reference transcript, and the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"] # By default, this should be in mixed-punc format
|
||||
|
||||
# the style of ref_text should match style_text
|
||||
#if params.use_style_prompt:
|
||||
texts = _apply_style_transform(texts, params.style_text_transform)
|
||||
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_text = ref_text_normalization(
|
||||
ref_text
|
||||
) # remove full-width symbols & some book marks
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
return results
|
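# Each entry appended to results[name] above is a (cut_id, ref_words, hyp_words)
# triple, e.g. ("cut-0001", ["HELLO", "WORLD"], ["HELLO", "WORD"]) -- the id and
# the words here are illustrative only.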
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
test_set_cers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
if params.compute_CER:
|
||||
# Write CER statistics
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
|
||||
store_transcripts(filename=recog_path, texts=results, char_level=True)
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
cer = write_error_stats(
|
||||
f,
|
||||
f"{test_set_name}-{key}",
|
||||
results,
|
||||
enable_log=True,
|
||||
compute_CER=params.compute_CER,
|
||||
)
|
||||
test_set_cers[key] = cer
|
||||
|
||||
logging.info("Wrote detailed CER stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
if params.compute_CER:
|
||||
test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tcER", file=f)
|
||||
for key, val in test_set_cers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_cers:
|
||||
s += "{} CER\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
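# The wer-summary and cer-summary files written above are small tab-separated
# tables, roughly of the form (numbers are made up):
#   settings        WER
#   greedy_search   3.12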
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriHeavyAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.causal:
|
||||
assert (
|
||||
"," not in params.chunk_size
|
||||
), "chunk_size should be one value in decoding."
|
||||
assert (
|
||||
"," not in params.left_context_frames
|
||||
), "left_context_frames should be one value in decoding."
|
||||
params.suffix += f"-chunk-{params.chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context_frames}"
|
||||
|
||||
if "beam_search" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_pre_text:
|
||||
params.suffix += f"-pre-text-{params.pre_text_transform}"
|
||||
|
||||
if params.use_style_prompt:
|
||||
params.suffix += f"-style-prompt-{params.style_text_transform}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
LM = None
|
||||
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
libriheavy = LibriHeavyAsrDataModule(args)
|
||||
|
||||
test_cuts = libriheavy.test_cuts()
|
||||
medium_test_cuts = libriheavy.medium_test_cuts()
|
||||
test_clean_cuts = libriheavy.test_clean_cuts()
|
||||
test_other_cuts = libriheavy.test_other_cuts()
|
||||
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
|
||||
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
|
||||
|
||||
test_dl = libriheavy.valid_dataloaders(test_cuts, text_sampling_func=naive_triplet_text_sampling)
|
||||
medium_test_dl = libriheavy.valid_dataloaders(medium_test_cuts, text_sampling_func=naive_triplet_text_sampling)
|
||||
test_clean_dl = libriheavy.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = libriheavy.test_dataloaders(test_other_cuts)
|
||||
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
|
||||
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
|
||||
|
||||
#test_sets = ["test-clean", "test-other", "ls-test-clean", "ls-test-other"]
|
||||
#test_dl = [test_clean_dl, test_other_dl, ls_test_clean_dl, ls_test_other_dl]
|
||||
|
||||
#test_sets = ["test-clean", "test-other"]
|
||||
#test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
#test_sets = ["ls-test-clean", "ls-test-other"]
|
||||
#test_dl = [ls_test_clean_dl, ls_test_other_dl]
|
||||
|
||||
if params.use_ls_test_set:
|
||||
test_sets = ["ls-test-clean", "ls-test-other"]
|
||||
test_dl = [ls_test_clean_dl, ls_test_other_dl]
|
||||
else:
|
||||
test_sets = ["medium_test",]
|
||||
test_dl = [medium_test_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
if params.post_normalization:
|
||||
params.suffix += "-post-normalization"
|
||||
|
||||
new_res = {}
|
||||
for k in results_dict:
|
||||
new_ans = []
|
||||
for item in results_dict[k]:
|
||||
id, ref, hyp = item
|
||||
hyp = [remove_non_alphabetic(w.upper()) for w in hyp]
|
||||
hyp = [w for w in hyp if w != ""]
|
||||
ref = [remove_non_alphabetic(w.upper()) for w in ref]
|
||||
ref = [w for w in ref if w != ""]
|
||||
new_ans.append((id, ref, hyp))
|
||||
new_res[k] = new_ans
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=new_res,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,349 +0,0 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import random
|
||||
import warnings
|
||||
from encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos, make_pad_mask
|
||||
from scaling import penalize_abs_values_gt, ScaledLinear
|
||||
from torch import Tensor
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class PromptedTransducer(nn.Module):
|
||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||
"Sequence Transduction with Recurrent Neural Networks"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_embed: nn.Module,
|
||||
encoder: EncoderInterface,
|
||||
text_embed: nn.Module,
|
||||
text_encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder_embed:
|
||||
It is a Convolutional 2D subsampling module. It converts
|
||||
an input of shape (N, T, idim) to an output of shape
|
||||
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
|
||||
encoder:
|
||||
It is the transcription network in the paper. It accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||
It should contain one attribute: `blank_id`.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
assert hasattr(decoder, "blank_id")
|
||||
|
||||
self.encoder_embed = encoder_embed
|
||||
self.encoder = encoder
|
||||
self.text_embed = text_embed
|
||||
self.text_encoder = text_encoder
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
self.simple_am_proj = ScaledLinear(
|
||||
encoder_dim,
|
||||
vocab_size,
|
||||
initial_scale=0.25,
|
||||
)
|
||||
self.simple_lm_proj = ScaledLinear(
|
||||
decoder_dim,
|
||||
vocab_size,
|
||||
initial_scale=0.25,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lens: torch.Tensor,
|
||||
style_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
use_pre_text: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
text:
|
||||
A 2-D tensor of integer dtype containing prompt text, of shape (N, T).
|
||||
It is expected to contain the style prompt (first) and then the content
|
||||
prompt.
|
||||
text_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of elements (bytes)
|
||||
in `text` before padding, which will include the lengths of the
|
||||
style plus the content prompt.
|
||||
style_lens:
|
||||
A 1-D tensor of shape (N,), containing the number of elements (bytes)
|
||||
within each row of `text` that correspond to the style prompt (these
|
||||
are expected to come first).
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, i.e., how many symbols (context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
|
||||
Note:
|
||||
The am_scale & lm_scale terms make the loss function take
|
||||
the form:
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||
|
||||
x, x_lens = self.encoder_embed(x, x_lens)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
if use_pre_text:
|
||||
memory, memory_key_padding_mask = self.encode_text(
|
||||
text,
|
||||
text_lens,
|
||||
style_lens
|
||||
)
|
||||
else:
|
||||
memory = None
|
||||
memory_key_padding_mask = None
|
||||
|
||||
encoder_out, x_lens = self.encoder(
|
||||
x,
|
||||
x_lens,
|
||||
src_key_padding_mask,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
# Note: y does not start with SOS
|
||||
# y_padded : [B, S]
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros(
|
||||
(encoder_out.size(0), 4),
|
||||
dtype=torch.int64,
|
||||
device=encoder_out.device,
|
||||
)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
# if self.training and random.random() < 0.25:
|
||||
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||
# if self.training and random.random() < 0.25:
|
||||
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad=px_grad,
|
||||
py_grad=py_grad,
|
||||
boundary=boundary,
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=self.joiner.encoder_proj(encoder_out),
|
||||
lm=self.joiner.decoder_proj(decoder_out),
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return (simple_loss, pruned_loss)
|
||||
|
||||
def _add_style_indicator(self, memory: Tensor, style_lens: Tensor):
|
||||
"""
|
||||
Adds to `memory` an indicator that is 1.0 for positions that correspond to
|
||||
the `style prompt` and 0 elsewhere. The scale can be fixed because the
|
||||
scale of the embedding vector can adjust to compensate.
|
||||
|
||||
Args:
|
||||
memory: (memory_len, batch_size, embed_dim)
|
||||
style_lens: (batch_size,), a vector of lengths of the style prompt.
|
||||
"""
|
||||
|
||||
(memory_len, batch_size, embed_dim) = memory.shape
|
||||
|
||||
indicator = (
|
||||
torch.arange(memory_len, device=memory.device).unsqueeze(-1)
|
||||
< style_lens
|
||||
)
|
||||
indicator = indicator.to(memory.dtype)
|
||||
|
||||
extra_term = torch.zeros_like(memory)
|
||||
extra_term[..., 0] += indicator
|
||||
|
||||
return memory + extra_term
|
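# Illustrative example: with memory_len=5, batch_size=1 and style_lens=[2],
# the indicator added to channel 0 of `memory` is [1., 1., 0., 0., 0.] for
# that batch element; all other channels are left unchanged.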
||||
|
||||
def encode_text(
|
||||
self,
|
||||
text: Tensor,
|
||||
text_lens: Tensor,
|
||||
style_lens: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Get the embeddings of text
|
||||
|
||||
Args:
|
||||
text (Tensor): The input text data in utf-8 bytes, (N, T)
|
||||
text_lens (Tensor): The length of the input text (N, ), including style_prompt
|
||||
style_lens (Tensor): The length of the style prompt (N, )
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the
|
||||
text_encoder and the attention mask
|
||||
"""
|
||||
text = text.t() # now (T, N)
|
||||
text = self.text_embed(text) # now (T, N, C)
|
||||
text_key_padding_mask = make_pad_mask(text_lens)
|
||||
|
||||
text = self._add_style_indicator(text, style_lens)
|
||||
|
||||
memory, text_lens = self.text_encoder(
|
||||
text, text_lens, text_key_padding_mask
|
||||
)
|
||||
|
||||
memory_key_padding_mask = make_pad_mask(text_lens)
|
||||
|
||||
return memory, memory_key_padding_mask
|
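# Minimal shape sketch (assuming byte-level prompts as prepared elsewhere in
# this recipe): `text` is an (N, T) integer tensor of utf-8 bytes; after the
# transpose, embedding and text_encoder it becomes `memory` of shape
# (T', N, C), with `memory_key_padding_mask` of shape (N, T') derived from the
# (possibly downsampled) text lengths.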
||||
|
||||
def encode_audio(
|
||||
self,
|
||||
feature: Tensor,
|
||||
feature_lens: Tensor,
|
||||
memory: Optional[Tensor],
|
||||
memory_key_padding_mask: Optional[Tensor],
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Encode the input audio features
|
||||
|
||||
Args:
|
||||
feature (Tensor): Input audio (N,T,C)
|
||||
feature_lens (Tensor): Length of input audio (N,)
|
||||
memory (Tensor): Embeddings from the text encoder
|
||||
memory_key_padding_mask (Tensor): The padding mask of the text embeddings (memory)
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor]: The encoder output and its lengths
|
||||
"""
|
||||
x, x_lens = self.encoder_embed(feature, feature_lens)
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
encoder_out, encoder_out_lens = self.encoder(
|
||||
x=x,
|
||||
x_lens=x_lens,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
|
||||
Transducer = PromptedTransducer # for decoding
|
@ -1,369 +0,0 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import random
|
||||
import warnings
|
||||
from encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos, make_pad_mask
|
||||
from scaling import penalize_abs_values_gt, ScaledLinear
|
||||
from torch import Tensor
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class PromptedTransducer(nn.Module):
|
||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||
"Sequence Transduction with Recurrent Neural Networks"
|
||||
Note that this is a PromptedTransducer, meaning that the transducer is able to decode
|
||||
with prompts.
|
||||
This transducer also has a special context fuser.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_embed: nn.Module,
|
||||
encoder: EncoderInterface,
|
||||
text_embed: nn.Module,
|
||||
text_encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
context_fuser: nn.Module = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder_embed:
|
||||
It is a Convolutional 2D subsampling module. It converts
|
||||
an input of shape (N, T, idim) to an output of shape
|
||||
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
|
||||
encoder:
|
||||
It is the transcription network in the paper. It accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||
It should contain one attribute: `blank_id`.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
context_fuser:
|
||||
A module that fuses the output embeddings of the text_encoder. The fused embedding
|
||||
will be given to the joiner before emitting symbols.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
assert hasattr(decoder, "blank_id")
|
||||
|
||||
self.encoder_embed = encoder_embed
|
||||
self.encoder = encoder
|
||||
self.text_embed = text_embed
|
||||
self.text_encoder = text_encoder
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
self.simple_am_proj = ScaledLinear(
|
||||
encoder_dim,
|
||||
vocab_size,
|
||||
initial_scale=0.25,
|
||||
)
|
||||
self.simple_lm_proj = ScaledLinear(
|
||||
decoder_dim,
|
||||
vocab_size,
|
||||
initial_scale=0.25,
|
||||
)
|
||||
self.context_fuser = context_fuser # a module to aggregate the context
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lens: torch.Tensor,
|
||||
style_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
use_pre_text: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
text:
|
||||
A 2-D tensor of integer dtype containing prompt text, of shape (N, T).
|
||||
It is expected to contain the style prompt (first) and then the content
|
||||
prompt.
|
||||
text_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of elements (bytes)
|
||||
in `text` before padding, which will include the lengths of the
|
||||
style plus the content prompt.
|
||||
style_lens:
|
||||
A 1-D tensor of shape (N,), containing the number of elements (bytes)
|
||||
within each row of `text` that correspond to the style prompt (these
|
||||
are expected to come first).
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, i.e., how many symbols (context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
|
||||
Note:
|
||||
The am_scale & lm_scale terms make the loss function take
|
||||
the form:
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||
|
||||
x, x_lens = self.encoder_embed(x, x_lens)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
if use_pre_text:
|
||||
memory, memory_key_padding_mask = self.encode_text(
|
||||
text,
|
||||
text_lens,
|
||||
style_lens
|
||||
) # (T,N,C)
|
||||
else:
|
||||
memory = None
|
||||
memory_key_padding_mask = None
|
||||
|
||||
encoder_out, x_lens = self.encoder(
|
||||
x,
|
||||
x_lens,
|
||||
src_key_padding_mask,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
# Note: y does not start with SOS
|
||||
# y_padded : [B, S]
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros(
|
||||
(encoder_out.size(0), 4),
|
||||
dtype=torch.int64,
|
||||
device=encoder_out.device,
|
||||
)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
# if self.training and random.random() < 0.25:
|
||||
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||
# if self.training and random.random() < 0.25:
|
||||
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad=px_grad,
|
||||
py_grad=py_grad,
|
||||
boundary=boundary,
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=self.joiner.encoder_proj(encoder_out),
|
||||
lm=self.joiner.decoder_proj(decoder_out),
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
if self.context_fuser is not None and memory is not None:
|
||||
memory = memory.permute(1, 0, 2)  # (T,N,C) -> (N,T,C)
|
||||
context = self.context_fuser(memory, padding_mask=memory_key_padding_mask)
|
||||
context = self.joiner.context_proj(context)
|
||||
else:
|
||||
context = None
|
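# Hedged note on the assumed context_fuser interface (the module itself is not
# shown in this diff): it is expected to pool the (N, T, C) memory into one
# vector per utterance, which joiner.context_proj then maps to the joiner
# dimension so it can be supplied as the extra `context` input to the joiner.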
||||
|
||||
logits = self.joiner(
|
||||
am_pruned,
|
||||
lm_pruned,
|
||||
context=context,
|
||||
project_input=False,
|
||||
)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return (simple_loss, pruned_loss)
|
||||
|
||||
def _add_style_indicator(self, memory: Tensor, style_lens: Tensor):
|
||||
"""
|
||||
Adds to `memory` an indicator that is 1.0 for positions that correspond to
|
||||
the `style prompt` and 0 elsewhere. The scale can be fixed because the
|
||||
scale of the embedding vector can adjust to compensate.
|
||||
|
||||
Args:
|
||||
memory: (memory_len, batch_size, embed_dim)
|
||||
style_lens: (batch_size,), a vector of lengths of the style prompt.
|
||||
"""
|
||||
|
||||
(memory_len, batch_size, embed_dim) = memory.shape
|
||||
|
||||
indicator = (
|
||||
torch.arange(memory_len, device=memory.device).unsqueeze(-1)
|
||||
< style_lens
|
||||
)
|
||||
indicator = indicator.to(memory.dtype)
|
||||
|
||||
extra_term = torch.zeros_like(memory)
|
||||
extra_term[..., 0] += indicator
|
||||
|
||||
return memory + extra_term
|
||||
|
||||
def encode_text(
|
||||
self,
|
||||
text: Tensor,
|
||||
text_lens: Tensor,
|
||||
style_lens: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Get the embeddings of text
|
||||
|
||||
Args:
|
||||
text (Tensor): The input text data in utf-8 bytes, (N, T)
|
||||
text_lens (Tensor): The length of the input text (N, ), including style_prompt
|
||||
style_lens (Tensor): The length of the style prompt (N, )
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the
|
||||
text_encoder and the attention mask
|
||||
"""
|
||||
text = text.t() # now (T, N)
|
||||
text = self.text_embed(text) # now (T, N, C)
|
||||
text_key_padding_mask = make_pad_mask(text_lens)
|
||||
|
||||
text = self._add_style_indicator(text, style_lens)
|
||||
|
||||
memory, text_lens = self.text_encoder(
|
||||
text, text_lens, text_key_padding_mask
|
||||
) # (T,N,C)
|
||||
|
||||
memory_key_padding_mask = make_pad_mask(text_lens)
|
||||
|
||||
return memory, memory_key_padding_mask
|
||||
|
||||
def encode_audio(
|
||||
self,
|
||||
feature: Tensor,
|
||||
feature_lens: Tensor,
|
||||
memory: Optional[Tensor],
|
||||
memory_key_padding_mask: Optional[Tensor],
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Encode the input audio features
|
||||
|
||||
Args:
|
||||
feature (Tensor): Input audio (N,T,C)
|
||||
feature_lens (Tensor): Length of input audio (N,)
|
||||
memory (Tensor): Embeddings from the text encoder
|
||||
memory_key_padding_mask (Tensor): The padding mask of the text embeddings (memory)
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor]: The encoder output and its lengths
|
||||
"""
|
||||
x, x_lens = self.encoder_embed(feature, feature_lens)
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
encoder_out, encoder_out_lens = self.encoder(
|
||||
x=x,
|
||||
x_lens=x_lens,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
|
||||
Transducer = PromptedTransducer # for decoding
|