further clean up

This commit is contained in:
marcoyang1998 2023-09-15 11:13:51 +08:00
parent ae2c7c73f6
commit a0fe6bcd0d
6 changed files with 247 additions and 4146 deletions

View File

@ -11,6 +11,7 @@ from lhotse.utils import compute_num_frames, ifnone
from torch.utils.data.dataloader import DataLoader, default_collate
from text_normalization import (
remove_non_alphabetic,
upper_only_alpha,
lower_only_alpha,
upper_all_char,
@ -33,6 +34,7 @@ class PromptASRDataset(torch.utils.data.Dataset):
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
input_strategy: BatchIO = PrecomputedFeatures(),
text_sampling_func: Optional[Callable[[List[str]], str]] = None,
rare_word_list: Optional[List[str]] = None
):
"""
Icefall ASR IterableDataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py
@ -59,6 +61,7 @@ class PromptASRDataset(torch.utils.data.Dataset):
# a text sampling function
self.text_sampling_func = text_sampling_func
self.rare_word_list = rare_word_list
def __getitem__(
self, cuts: CutSet
@ -109,6 +112,8 @@ class PromptASRDataset(torch.utils.data.Dataset):
self.text_sampling_func(
texts=supervision.texts,
pre_texts=supervision.pre_texts,
context_list=supervision.context_list if "context_list" in supervision.custom else None,
rare_word_list=self.rare_word_list,
)
if self.text_sampling_func is not None
else {
@ -183,6 +188,8 @@ def get_substring(s: str, min_len: int = 40, max_len: int = 250) -> str:
def triplet_text_sampling(
texts: List[str],
pre_texts: List[str],
context_list: Optional[str] = None,
rare_word_list: Optional[List[str]] = None,
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
@ -238,7 +245,8 @@ def triplet_text_sampling(
lower_all_char,
]
sampling_weight = [0.7, 0.3, 0.0, 0.0] # Mixed-punc should have the largest sampling prob
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
sampling_weight = [0.7, 0.3, 0.0, 0.0]
total_transforms = len(transforms) # do not use the recognized trans
@ -266,7 +274,8 @@ def triplet_text_sampling(
}
def multi_ref_text_triplet_text_sampling(
def triplet_text_sampling2(
texts: List[str],
pre_texts: List[str],
context_list: Optional[str] = None,
@ -310,14 +319,12 @@ def multi_ref_text_triplet_text_sampling(
Returns:
str: A dictionary
"""
assert len(texts) == 3
# import pdb; pdb.set_trace()
assert len(texts) == len(pre_texts)
assert len(texts) == 2
# we assume the first item to be ground truth, the third item to be the
# decoding results with prompts
if random.random() < 0.5:
# we assume the first item to be ground truth
gt_text = texts[0]
else:
gt_text = texts[2] # decoding res with prompt
gt_pre_text = pre_texts[0]
if transforms is None:
@ -328,7 +335,8 @@ def multi_ref_text_triplet_text_sampling(
lower_all_char,
]
sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
sampling_weight = [0.7, 0.3, 0.0, 0.0]
total_transforms = len(transforms) # do not use the recognized trans
@ -356,6 +364,234 @@ def multi_ref_text_triplet_text_sampling(
}
def triplet_text_sampling_with_context_list(
texts: List[str],
pre_texts: List[str],
context_list: str,
rare_word_list: List[str],
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
"""This function generates a triplet of
(pre_text, style_text, ref_text). The style of style_text and ref_text
should always match, whereas the style of pre_text is arbitrary.
Suppose we have 3 different transforms A,B,C, and the groundtruth
text and pre_text are referred to as text and pre_text.
The following three tuples are all valid:
(A(pre_text), B(style_text), B(text))
(A(pre_text), C(style_text), C(text))
(A(pre_text), A(style_text), A(text))
...
If transforms is not given, the following pre-defined transforms
are available:
0: original (normal case, with punc)
1: recog (upper, no punc)
2: upper_only_alpha (upper, no punc)
3: lower_only_alpha (lower, no punc)
4: upper_all (upper, with punc)
5: lower_all (lower, with punc)
When the transform of text and pre_text match, we can use the whole
pre_text as the prompt text.
Args:
texts (List[str]):
A list of ref_texts whose first item is the ground truth
text from books.
pre_texts (List[str]):
A list of pre_texts, whose first item is the ground truth
pre_text from books.
context_list (str):
A string of rare words relevant to this utterance; its words are
mixed with sampled distractors to build the content prompt.
rare_word_list (List[str]):
A list of rare words from which distractor words are sampled.
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
Returns:
Dict[str, str]: A dictionary containing the sampled triplet
"""
# import pdb; pdb.set_trace()
assert len(texts) == len(pre_texts)
assert len(texts) == 2
if context_list is not None:
context_list = context_list.lower()
# we assume the first item to be ground truth
gt_text = texts[0]
gt_pre_text = pre_texts[0]
if transforms is None:
transforms = [
lambda x: x,  # identity transform, i.e. return the input itself
upper_only_alpha,
lower_only_alpha,
lower_all_char,
]
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
sampling_weight = [0.7, 0.3, 0.0, 0.0]  # mixed-case with punctuation gets the largest sampling prob
total_transforms = len(transforms) # do not use the recognized trans
# Select a transformation randomly
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
# get the normalized text and pre_text
text = transforms[i_text](gt_text)
pre_text = get_pre_text_with_context_list2(
text=gt_pre_text,
pre_text=gt_pre_text,
context_list=context_list,
rare_words_list=rare_word_list,
)
pre_text = transforms[i_pre_text](pre_text)
if i_text == i_pre_text:
# text and pre_text now share the same transform, so a substring of
# pre_text can serve directly as the style prompt
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
else:
# get the pre_text of same style as text
# For now, do not do transform to the style text
style_text = gt_pre_text
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
return {
"text": train_text_normalization(text),
"pre_text": train_text_normalization(pre_text),
"style_text": train_text_normalization(style_text),
"transform_ids": i_text,
}
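A minimal usage sketch of the sampler above (toy strings and a made-up rare-word list, purely for illustration; real inputs come from the Libriheavy supervisions, and helpers such as upper_only_alpha are imported at the top of this module):
rare_words = [f"word{i}" for i in range(200)]  # placeholder rare-word list
sample = triplet_text_sampling_with_context_list(
    texts=["Hello, World! This is a test.", "HELLO WORLD THIS IS A TEST"],
    pre_texts=["A previous sentence, for context.", "A PREVIOUS SENTENCE FOR CONTEXT"],
    context_list="hello world",
    rare_word_list=rare_words,
)
# `sample` is a dict with keys "text", "pre_text", "style_text" and "transform_ids".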
def get_pre_text_with_context_list(
text: str,
pre_text: str,
context_list: str,
rare_words_list: List[str] = None,
) -> str:
# Build the content prompt (pre_text) for this utterance. If a non-empty
# context_list is given, mix its words with distractors drawn from the
# rare-word list; otherwise occasionally sample words or a substring from
# the reference text, or keep the original pre_text.
if context_list != "" and context_list is not None:
v = random.random()
if v < 0.5:
# correct + distractors
# sample distractors
num_distractors = random.randint(0, 50)
distractors = random.sample(rare_words_list, num_distractors)
# sample correct
correct = context_list.split()
i = random.randint(1, len(correct))
correct = random.sample(correct, i)
# combine correct and distractors
pre_text = distractors + correct
random.shuffle(pre_text)
pre_text = " ".join(pre_text)
elif v < 0.7:
splitted = text.split()
sampling_weights = [len(w)**1.2 for w in splitted]
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
i = random.randint(1, min(len(splitted), 20))
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
num_distractors = random.randint(0,70)
distractors = random.sample(rare_words_list, num_distractors)
splitted += distractors
random.shuffle(splitted) # shuffle the list
pre_text = " ".join(splitted)
else:
pre_text = pre_text
else:
v = random.random()
if v < 0.1:
splitted = text.split()
sampling_weights = [len(w)**1.2 for w in splitted]
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
i = random.randint(1, min(len(splitted), 20))
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
num_distractors = random.randint(0, 70)
distractors = random.sample(rare_words_list, num_distractors)
splitted += distractors
random.shuffle(splitted)  # shuffle the list
pre_text = " ".join(splitted)  # join only after the distractors are added
elif v < 0.2:
# full distractors
num_distractors = random.randint(5, 100)
distractors = random.sample(rare_words_list, num_distractors)
pre_text = " ".join(distractors)
elif v < 0.3:
pre_text = get_substring(text, min_len=15, max_len=150)
else:
pre_text = pre_text
return pre_text
def get_pre_text_with_context_list2(
text: str,
pre_text: str,
context_list: str,
rare_words_list: List[str] = None,
) -> str:
# Same idea as get_pre_text_with_context_list, but biased towards heavier
# distractor injection: when a context_list is given, 50-100 distractors
# from the rare-word list are mixed in; otherwise sampled words, pure
# distractors, or a substring of the reference text may serve as pre_text.
if context_list != "" and context_list is not None:
v = random.random()
if v < 0.4:
# correct + distractors
# sample distractors
num_distractors = random.randint(50, 100)
distractors = random.sample(rare_words_list, num_distractors)
# sample correct
correct = context_list.split()
i = random.randint(1, len(correct))
correct = random.sample(correct, i)
# combine correct and distractors
pre_text = distractors + correct
random.shuffle(pre_text)
pre_text = " ".join(pre_text)
elif v < 0.55:
splitted = text.split()
sampling_weights = [len(w)**1.2 for w in splitted]
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
i = random.randint(1, min(len(splitted), 20))
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
num_distractors = random.randint(50,100)
distractors = random.sample(rare_words_list, num_distractors)
splitted += distractors
random.shuffle(splitted) # shuffle the list
pre_text = " ".join(splitted)
else:
pre_text = pre_text
else:
v = random.random()
if v < 0.3:
splitted = text.split()
sampling_weights = [len(w)**1.2 for w in splitted]
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
i = random.randint(1, min(len(splitted), 20))
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
num_distractors = random.randint(50, 100)
distractors = random.sample(rare_words_list, num_distractors)
splitted += distractors
random.shuffle(splitted)  # shuffle the list
pre_text = " ".join(splitted)  # join only after the distractors are added
elif v < 0.4:
# full distractors
num_distractors = random.randint(5, 100)
distractors = random.sample(rare_words_list, num_distractors)
pre_text = " ".join(distractors)
elif v < 0.6:
pre_text = get_substring(text, min_len=15, max_len=150)
else:
pre_text = pre_text
return pre_text
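A quick sanity check of the builder above with toy inputs (illustrative only; during training it is called from triplet_text_sampling_with_context_list):
rare_words = [f"rare{i}" for i in range(150)]  # placeholder distractor pool
pt = get_pre_text_with_context_list2(
    text="the quick brown fox jumps over the lazy dog",
    pre_text="an earlier sentence that provides context",
    context_list="fox dog",
    rare_words_list=rare_words,
)
assert isinstance(pt, str)  # typically a mix of context words and distractors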
def joint_triplet_text_sampling(
texts: List[str],
@ -521,7 +757,6 @@ def naive_triplet_text_sampling(
return {
"text": train_text_normalization(texts[0]),
"pre_text": train_text_normalization(pre_texts[0]),
#"pre_text": "",
#"pre_text": "HELLO IT THIS ENOUGH FOR THE MODEL TO LEARN THE STYLE",
#"pre_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
#"pre_text": "Hello, my friend. "*50,
@ -529,7 +764,7 @@ def naive_triplet_text_sampling(
"style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related. What do you think?",
#"style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
#"style_text": "Mixed-case English transcription, with punctuation.",
#"style_text": train_text_normalization(get_substring(pre_texts[0], min_len=min_len_style)),
# "style_text": train_text_normalization(get_substring(pre_texts[0], min_len=min_len_style)),
"transform_ids": 0,
}

View File

@ -1,824 +0,0 @@
from typing import Callable, Dict, List, Optional, Union
import random
import numpy as np
import torch
from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset import K2SpeechRecognitionDataset
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import compute_num_frames, ifnone
from torch.utils.data.dataloader import DataLoader, default_collate
from text_normalization import (
remove_non_alphabetic,
upper_only_alpha,
lower_only_alpha,
upper_all_char,
lower_all_char,
train_text_normalization,
)
class PromptASRDataset(torch.utils.data.Dataset):
"""This is a dataset for Prompt ASR. It supports the following features:
1. Select a tuple of (text, pre_text, style_text) randomly from a
list of texts as supervisions.
"""
def __init__(
self,
return_cuts: bool = False,
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
input_strategy: BatchIO = PrecomputedFeatures(),
text_sampling_func: Optional[Callable[[List[str]], str]] = None,
rare_word_list: Optional[List[str]] = None
):
"""
Icefall ASR IterableDataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py
for more details.
:param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
objects used to create that batch.
:param cut_transforms: A list of transforms to be applied on each sampled batch,
before converting cuts to an input representation (audio/features).
Examples: cut concatenation, noise cuts mixing, etc.
:param input_transforms: A list of transforms to be applied on each sampled batch,
after the cuts are converted to audio/features.
Examples: normalization, SpecAugment, etc.
:param input_strategy: Converts cuts into a collated batch of audio/features.
By default, reads pre-computed features from disk.
:param text_sampling_func: Sampling a text as transcription from a list of texts.
"""
super().__init__()
# Initialize the fields
self.return_cuts = return_cuts
self.cut_transforms = ifnone(cut_transforms, [])
self.input_transforms = ifnone(input_transforms, [])
self.input_strategy = input_strategy
# a text sampling function
self.text_sampling_func = text_sampling_func
self.rare_word_list = rare_word_list
def __getitem__(
self, cuts: CutSet
) -> Dict[str, Union[torch.Tensor, List[str]]]:
"""
Return a new batch, with the batch size automatically determined using the constraints
of max_frames and max_cuts.
"""
validate_for_asr(cuts)
# Sort the cuts by duration so that the first one determines the batch time dimensions.
cuts = cuts.sort_by_duration(ascending=False)
# Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
# the supervision boundaries.
for tnfm in self.cut_transforms:
cuts = tnfm(cuts)
# Sort the cuts again after transforms
cuts = cuts.sort_by_duration(ascending=False)
# Get a tensor with batched feature matrices, shape (B, T, F)
# Collation performs auto-padding, if necessary.
input_tpl = self.input_strategy(cuts)
if len(input_tpl) == 3:
# An input strategy with fault tolerant audio reading mode.
# "cuts" may be a subset of the original "cuts" variable,
# that only has cuts for which we successfully read the audio.
inputs, _, cuts = input_tpl
else:
inputs, _ = input_tpl
# Get a dict of tensors that encode the positional information about supervisions
# in the batch of feature matrices. The tensors are named "sequence_idx",
# "start_frame/sample" and "num_frames/samples".
supervision_intervals = self.input_strategy.supervision_intervals(cuts)
# Apply all available transforms on the inputs, i.e. either audio or features.
# This could be feature extraction, global MVN, SpecAugment, etc.
segments = torch.stack(list(supervision_intervals.values()), dim=1)
for tnfm in self.input_transforms:
inputs = tnfm(inputs, supervision_segments=segments)
batch = {
"inputs": inputs,
"supervisions": default_collate(
[
self.text_sampling_func(
texts=supervision.texts,
pre_texts=supervision.pre_texts,
context_list=supervision.context_list if "context_list" in supervision.custom else None,
rare_word_list=self.rare_word_list,
)
if self.text_sampling_func is not None
else {
"text": train_text_normalization(supervision.texts[0]),
"pre_text": train_text_normalization(
supervision.pre_texts[0]
),
"style_text": train_text_normalization(
supervision.pre_texts[0]
),
"transform_ids": 0,
}
for sequence_idx, cut in enumerate(cuts)
for supervision in cut.supervisions
]
),
}
# Update the 'supervisions' field with sequence_idx and start/num frames/samples
batch["supervisions"].update(supervision_intervals)
if self.return_cuts:
batch["supervisions"]["cut"] = [
cut for cut in cuts for sup in cut.supervisions
]
has_word_alignments = all(
s.alignment is not None and "word" in s.alignment
for c in cuts
for s in c.supervisions
)
return batch
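For context, a sketch of how this dataset is typically wired up with a lhotse sampler; the manifest path and sampler settings below are placeholders rather than values used by this recipe:
from lhotse.dataset import DynamicBucketingSampler

# Placeholder manifest; any CutSet whose supervisions carry the custom
# "texts" / "pre_texts" fields expected by __getitem__ would work.
cuts = CutSet.from_file("data/fbank/libriheavy_cuts_small.jsonl.gz")
dataset = PromptASRDataset(
    return_cuts=True,
    text_sampling_func=triplet_text_sampling,
)
sampler = DynamicBucketingSampler(cuts, max_duration=200.0, shuffle=True)
dl = DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=2)
batch = next(iter(dl))
features = batch["inputs"]             # (B, T, F) padded feature matrices
texts = batch["supervisions"]["text"]  # sampled / normalized transcripts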
def validate_for_asr(cuts: CutSet) -> None:
validate(cuts)
tol = 2e-3  # 2 ms tolerance
for cut in cuts:
for supervision in cut.supervisions:
assert supervision.start >= -tol, (
f"Supervisions starting before the cut are not supported for ASR"
f" (sup id: {supervision.id}, cut id: {cut.id})"
)
# Supervision start time is relative to Cut ...
# https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
#
# 'supervision.end' is end of supervision inside the Cut
assert supervision.end <= cut.duration + tol, (
f"Supervisions ending after the cut "
f"are not supported for ASR"
f" (sup id: {supervision.id}, cut id: {cut.id})"
)
def get_substring(s: str, min_len: int = 40, max_len: int = 250) -> str:
"""A helper function that generates a random substring from a given string
Args:
s (str): Input string
min_len (int): Minimum length of the returned substring
max_len (int): Maximum length of the returned substring
Returns:
str: Returned substring
"""
min_len = min(len(s), min_len)
start = random.randint(0, len(s) - min_len)
end = min(start + max_len, random.randint(start + min_len, len(s)))
return s[start:end]
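For intuition, an illustrative call (the output varies because the cut points are random):
s = "The quick brown fox jumps over the lazy dog, again and again."
snippet = get_substring(s, min_len=10, max_len=30)
assert 10 <= len(snippet) <= 30  # e.g. "fox jumps over the lazy dog,"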
def triplet_text_sampling(
texts: List[str],
pre_texts: List[str],
context_list: Optional[str] = None,
rare_word_list: Optional[List[str]] = None,
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
"""This function generates a triplet of
(pre_text, style_text, ref_text). The style of style_text and ref_text
should always match, whereas the style of pre_text is arbitrary.
Suppose we have 3 different transforms A,B,C, and the groundtruth
text and pre_text are referred to as text and pre_text.
The following three tuples are all valid:
(A(pre_text), B(style_text), B(text))
(A(pre_text), C(style_text), C(text))
(A(pre_text), A(style_text), A(text))
...
If transforms is not given, the following pre-defined transforms
are available:
0: original (normal case, with punc)
1: recog (upper, no punc)
2: upper_only_alpha (upper, no punc)
3: lower_only_alpha (lower, no punc)
4: upper_all (upper, with punc)
5: lower_all (lower, with punc)
When the transform of text and pre_text match, we can use the whole
pre_text as the prompt text.
Args:
texts (List[str]):
A list of ref_texts whose first item is the ground truth
text from books.
pre_texts (List[str]):
A list of pre_texts, whose first item is the groundtruth
pre_text from books.
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
Returns:
str: A dictionary
"""
# import pdb; pdb.set_trace()
assert len(texts) == len(pre_texts)
assert len(texts) == 2
# we assume the first item to be ground truth
gt_text = texts[0]
gt_pre_text = pre_texts[0]
if transforms is None:
transforms = [
lambda x: x,  # identity transform, i.e. return the input itself
upper_only_alpha,
lower_only_alpha,
lower_all_char,
]
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
sampling_weight = [0.7, 0.3, 0.0, 0.0]
total_transforms = len(transforms) # do not use the recognized trans
# Select a transformation randomly
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
# get the normalized text and pre_text
text = transforms[i_text](gt_text)
pre_text = transforms[i_pre_text](gt_pre_text)
if i_text == i_pre_text:
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
else:
# get the pre_text of same style as text
# For now, do not do transform to the style text
style_text = gt_pre_text
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
return {
"text": train_text_normalization(text),
"pre_text": train_text_normalization(pre_text),
"style_text": train_text_normalization(style_text),
"transform_ids": i_text,
}
def triplet_text_sampling2(
texts: List[str],
pre_texts: List[str],
context_list: Optional[str] = None,
rare_word_list: Optional[List[str]] = None,
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
"""This function generates a triplet of
(pre_text, style_text, ref_text). The style of style_text and ref_text
should always match, whereas the style of pre_text is arbitrary.
Suppose we have 3 different transforms A,B,C, and the groundtruth
text and pre_text are referred to as text and pre_text.
The following three tuples are all valid:
(A(pre_text), B(style_text), B(text))
(A(pre_text), C(style_text), C(text))
(A(pre_text), A(style_text), A(text))
...
If transforms is not given, the following pre-defined transforms
are available:
0: original (normal case, with punc)
1: recog (upper, no punc)
2: upper_only_alpha (upper, no punc)
3: lower_only_alpha (lower, no punc)
4: upper_all (upper, with punc)
5: lower_all (lower, with punc)
When the transform of text and pre_text match, we can use the whole
pre_text as the prompt text.
Args:
texts (List[str]):
A list of ref_texts whose first item is the ground truth
text from books.
pre_texts (List[str]):
A list of pre_texts, whose first item is the groundtruth
pre_text from books.
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
Returns:
str: A dictionary
"""
# import pdb; pdb.set_trace()
assert len(texts) == len(pre_texts)
assert len(texts) == 2
# we assume the first item to be ground truth
gt_text = texts[0]
gt_pre_text = pre_texts[0]
if transforms is None:
transforms = [
lambda x: x,  # identity transform, i.e. return the input itself
upper_only_alpha,
lower_only_alpha,
lower_all_char,
]
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
sampling_weight = [0.7, 0.3, 0.0, 0.0]
total_transforms = len(transforms) # do not use the recognized trans
# Select a transformation randomly
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
# get the normalized text and pre_text
text = transforms[i_text](gt_text)
pre_text = transforms[i_pre_text](gt_pre_text)
if i_text == i_pre_text:
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
else:
# get the pre_text of same style as text
# For now, do not do transform to the style text
style_text = gt_pre_text
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
return {
"text": train_text_normalization(text),
"pre_text": train_text_normalization(pre_text),
"style_text": train_text_normalization(style_text),
"transform_ids": i_text,
}
def triplet_text_sampling_with_context_list(
texts: List[str],
pre_texts: List[str],
context_list: str,
rare_word_list: List[str],
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
"""This function generates a triplet of
(pre_text, style_text, ref_text). The style of style_text and ref_text
should always match, whereas the style of pre_text is arbitrary.
Suppose we have 3 different transforms A,B,C, and the groundtruth
text and pre_text are referred to as text and pre_text.
The following three tuples are all valid:
(A(pre_text), B(style_text), B(text))
(A(pre_text), C(style_text), C(text))
(A(pre_text), A(style_text), A(text))
...
If transforms is not given, the following pre-defined transforms
are available:
0: original (normal case, with punc)
1: recog (upper, no punc)
2: upper_only_alpha (upper, no punc)
3: lower_only_alpha (lower, no punc)
4: upper_all (upper, with punc)
5: lower_all (lower, with punc)
When the transform of text and pre_text match, we can use the whole
pre_text as the prompt text.
Args:
texts (List[str]):
A list of ref_texts whose first item is the ground truth
text from books.
pre_texts (List[str]):
A list of pre_texts, whose first item is the groundtruth
pre_text from books.
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
Returns:
str: A dictionary
"""
# import pdb; pdb.set_trace()
assert len(texts) == len(pre_texts)
assert len(texts) == 2
if context_list is not None:
context_list = context_list.lower()
# we assume the first item to be ground truth
gt_text = texts[0]
gt_pre_text = pre_texts[0]
if transforms is None:
transforms = [
lambda x: x,  # identity transform, i.e. return the input itself
upper_only_alpha,
lower_only_alpha,
lower_all_char,
]
# sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
sampling_weight = [0.7, 0.3, 0.0, 0.0]
total_transforms = len(transforms) # do not use the recognized trans
# Select a transformation randomly
i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)
# get the normalized text and pre_text
text = transforms[i_text](gt_text)
pre_text = get_pre_text_with_context_list2(
text=gt_pre_text,
pre_text=gt_pre_text,
context_list=context_list,
rare_words_list=rare_word_list,
)
pre_text = transforms[i_pre_text](pre_text)
if i_text == i_pre_text:
style_text = gt_pre_text
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
else:
# get the pre_text of same style as text
# For now, do not do transform to the style text
style_text = gt_pre_text
# style_text = pre_texts[i_text] if i_text <= 1 else transforms[i_text-2](gt_pre_text)
style_text = get_substring(style_text, min_len=min_len_style, max_len=150)
return {
"text": train_text_normalization(text),
"pre_text": train_text_normalization(pre_text),
"style_text": train_text_normalization(style_text),
"transform_ids": i_text,
}
def get_pre_text_with_context_list(
text: str,
pre_text: str,
context_list: str,
rare_words_list: List[str] = None,
) -> str:
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
# By a small proportion of time, use the substring of ref_text as pre_text
if context_list != "" and context_list is not None:
v = random.random()
if v < 0.5:
# correct + distractors
# sample distractors
num_distractors = random.randint(0, 50)
distractors = random.sample(rare_words_list, num_distractors)
# sample correct
correct = context_list.split()
i = random.randint(1, len(correct))
correct = random.sample(correct, i)
# combine correct and distractors
pre_text = distractors + correct
random.shuffle(pre_text)
pre_text = " ".join(pre_text)
elif v < 0.7:
splitted = text.split()
sampling_weights = [len(w)**1.2 for w in splitted]
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
i = random.randint(1, min(len(splitted), 20))
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
num_distractors = random.randint(0,70)
distractors = random.sample(rare_words_list, num_distractors)
splitted += distractors
random.shuffle(splitted) # shuffle the list
pre_text = " ".join(splitted)
else:
pre_text = pre_text
else:
v = random.random()
if v < 0.1:
splitted = text.split()
sampling_weights = [len(w)**1.2 for w in splitted]
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
i = random.randint(1, min(len(splitted), 20))
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
pre_text = " ".join(splitted)
num_distractors = random.randint(0,70)
distractors = random.sample(rare_words_list, num_distractors)
splitted += distractors
random.shuffle(splitted) # shuffle the list
elif v < 0.2:
# full distractors
num_distractors = random.randint(5, 100)
distractors = random.sample(rare_words_list, num_distractors)
pre_text = " ".join(distractors)
elif v < 0.3:
pre_text = get_substring(text, min_len=15, max_len=150)
else:
pre_text = pre_text
return pre_text
def get_pre_text_with_context_list2(
text: str,
pre_text: str,
context_list: str,
rare_words_list: List[str] = None,
) -> str:
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
# By a small proportion of time, use the substring of ref_text as pre_text
if context_list != "" and context_list is not None:
v = random.random()
if v < 0.4:
# correct + distractors
# sample distractors
num_distractors = random.randint(50, 100)
distractors = random.sample(rare_words_list, num_distractors)
# sample correct
correct = context_list.split()
i = random.randint(1, len(correct))
correct = random.sample(correct, i)
# combine correct and distractors
pre_text = distractors + correct
random.shuffle(pre_text)
pre_text = " ".join(pre_text)
elif v < 0.55:
splitted = text.split()
sampling_weights = [len(w)**1.2 for w in splitted]
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
i = random.randint(1, min(len(splitted), 20))
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
num_distractors = random.randint(50,100)
distractors = random.sample(rare_words_list, num_distractors)
splitted += distractors
random.shuffle(splitted) # shuffle the list
pre_text = " ".join(splitted)
else:
pre_text = pre_text
else:
v = random.random()
if v < 0.3:
splitted = text.split()
sampling_weights = [len(w)**1.2 for w in splitted]
sampling_weights = [p/sum(sampling_weights) for p in sampling_weights]
i = random.randint(1, min(len(splitted), 20))
splitted = list(np.random.choice(splitted, i, p=sampling_weights))
pre_text = " ".join(splitted)
num_distractors = random.randint(50,100)
distractors = random.sample(rare_words_list, num_distractors)
splitted += distractors
random.shuffle(splitted) # shuffle the list
elif v < 0.4:
# full distractors
num_distractors = random.randint(5, 100)
distractors = random.sample(rare_words_list, num_distractors)
pre_text = " ".join(distractors)
elif v < 0.6:
pre_text = get_substring(text, min_len=15, max_len=150)
else:
pre_text = pre_text
return pre_text
def joint_triplet_text_sampling(
texts: List[str],
pre_texts: List[str],
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
"""This function generates a triplet of
(pre_text, style_text, ref_text). The style of pre_text, style_text
and ref_text should **always** match.
Suppose we have 3 different transforms A,B,C, and the groundtruth
text and pre_text are referred to as text and pre_text.
The following three tuples are all valid:
(A(pre_text), A(style_text), A(text))
(B(pre_text), B(style_text), B(text))
(C(pre_text), C(style_text), C(text))
...
If transforms is not given, the following pre-defined transforms
are available:
0: original (normal case, with punc)
1: recog (upper, no punc)
2: upper_only_alpha (upper, no punc)
3: lower_only_alpha (lower, no punc)
4: upper_all (upper, with punc)
5: lower_all (lower, with punc)
When the transform of text and pre_text match, we can use the whole
pre_text as the prompt text.
Args:
texts (List[str]):
A list of ref_texts whose first item is the ground truth
text from books.
pre_texts (List[str]):
A list of pre_texts, whose first item is the groundtruth
pre_text from books.
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
Returns:
str: A dictionary
"""
# import pdb; pdb.set_trace()
assert len(texts) == len(pre_texts)
assert len(texts) == 2
# we assume the first item to be ground truth
gt_text = texts[0]
gt_pre_text = pre_texts[0]
if transforms is None:
transforms = [
lambda x: x,  # identity transform, i.e. return the input itself
upper_only_alpha,
lower_only_alpha,
lower_all_char,
]
sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
total_transforms = len(transforms) # do not use the recognized trans
# Select a transformation randomly
i_text = np.random.choice(total_transforms, 1, p=sampling_weight)[0]
# get the normalized text and pre_text
text = transforms[i_text](gt_text)
pre_text = transforms[i_text](gt_pre_text)
style_text = get_substring(pre_text, min_len=min_len_style, max_len=150)
return {
"text": train_text_normalization(text),
"pre_text": train_text_normalization(pre_text),
"style_text": train_text_normalization(style_text),
"transform_ids": i_text,
}
def triplet_style_text_sampling(
texts: List[str],
pre_texts: List[str],
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
"""This function generates a triplet of
(pre_text, style_text, ref_text). The style of style_text and ref_text
should always match, whereas the style of pre_text is fixed to mixed-trans.
Suppose we have 3 different transforms A,B,C, and the groundtruth
text and pre_text are referred to as text and pre_text.
The following three tuples are all valid:
(gt_pre_text, B(style_text), B(text))
(gt_pre_text, C(style_text), C(text))
(gt_pre_text, A(style_text), A(text))
...
If transforms is not given, the following pre-defined transforms
are available:
0: original (normal case, with punc)
1: recog (upper, no punc)
2: upper_only_alpha (upper, no punc)
3: lower_only_alpha (lower, no punc)
4: upper_all (upper, with punc)
5: lower_all (lower, with punc)
Args:
texts (List[str]):
A list of ref_texts whose first item is the ground truth
text from books.
pre_texts (List[str]):
A list of pre_texts, whose first item is the groundtruth
pre_text from books.
transforms (List[Callable[[str], str]]): A list of possible transforms to be applied
Returns:
str: A dictionary
"""
# import pdb; pdb.set_trace()
assert len(texts) == len(pre_texts)
assert len(texts) == 2
# we assume the first item to be ground truth
gt_text = texts[0]
gt_pre_text = pre_texts[0]
if transforms is None:
transforms = [
lambda x: x,  # identity transform, i.e. return the input itself
upper_only_alpha,
lower_only_alpha,
lower_all_char,
]
sampling_weight = [0.5, 0.2, 0.15, 0.15] # Mixed-punc should have the largest sampling prob
total_transforms = len(transforms) # do not use the recognized trans
# Select a transformation randomly
t_id = np.random.choice(total_transforms, 1, p=sampling_weight)[0]
# get the normalized text
text = transforms[t_id](gt_text)
# get the original un-processed style text
style_text = get_substring(gt_pre_text, min_len=min_len_style, max_len=150)
return {
"text": train_text_normalization(text),
"pre_text": train_text_normalization(gt_pre_text),
"style_text": train_text_normalization(style_text),
"transform_ids": t_id,
}
def naive_triplet_text_sampling(
texts: List[str],
pre_texts: List[str],
context_list: str = None,
rare_word_list: List[str] = None,
min_len_style: Optional[int] = 120,
):
return {
"text": train_text_normalization(texts[0]),
"pre_text": train_text_normalization(pre_texts[0]),
#"pre_text": "HELLO IT THIS ENOUGH FOR THE MODEL TO LEARN THE STYLE",
#"pre_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
#"pre_text": "Hello, my friend. "*50,
#"style_text": train_text_normalization(pre_texts[0][:200]),
"style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related. What do you think?",
#"style_text": "Mixed-case English transcription, with punctuation. Actually, it is fully not related.",
#"style_text": "Mixed-case English transcription, with punctuation.",
# "style_text": train_text_normalization(get_substring(pre_texts[0], min_len=min_len_style)),
"transform_ids": 0,
}
def random_shuffle_subset(
data: List[str],
p: float = 0.2,
p_mask: float = 0.05,
) -> List[str]:
"""
Randomly shuffle a subset of the batch: a fraction p of the items are
shuffled among themselves, while the rest keep their original positions.
Afterwards, each item is independently replaced by an empty string with
probability p_mask. The list is modified in place and also returned.
"""
num_to_shuffle = int(len(data) * p)
id_to_shuffle = np.random.choice(len(data), num_to_shuffle, replace=False)
item_to_shuffle = [data[id] for id in id_to_shuffle]
random.shuffle(item_to_shuffle)
for id, item in zip(id_to_shuffle, item_to_shuffle):
data[id] = item
# Randomly mask a proportion of the data to empty string
if p_mask > 0:
for i in range(len(data)):
if random.random() < p_mask:
data[i] = ""
return data
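A small illustrative run; note that the list is modified in place and some entries may be blanked out:
utt_texts = [f"utterance {i}" for i in range(10)]
out = random_shuffle_subset(utt_texts, p=0.3, p_mask=0.1)
# About 3 of the 10 entries swap positions among themselves, and each entry is
# independently replaced by "" with probability 0.1; `out` is the same list object.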
if __name__ == "__main__":
texts = [
"AA, BB, cC, dD!",
"AA BB CC DD",
]
pre_texts = [
"EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg?",
"EE FF GG EE FF GG EE FF GG EE FF GG EE FF GG",
]
# for i in range(10):
# print(f"Run: {i}")
# print(triplet_text_sampling(texts, pre_texts))
import time
start = time.time()
data = [str(i) for i in range(30)]
random.shuffle(data)
print(data)
for i in range(1):
shuffled = random_shuffle_subset(data=data, p=0.4, p_mask=0.1)
print(shuffled)
print((time.time() - start)/100)

View File

@ -1,882 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) modified beam search
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
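(3) modified beam search with content and style prompts
(an illustrative invocation; flag values are examples only)
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4 \
--use-pre-text True \
--use-style-prompt True \
--pre-text-transform mixed-punc \
--style-text-transform mixed-punc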
"""
import argparse
import logging
import math
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Callable
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriHeavyAsrDataModule
from beam_search import (
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from dataset import naive_triplet_text_sampling, random_shuffle_subset
from text_normalization import ref_text_normalization, remove_non_alphabetic, upper_only_alpha, upper_all_char, lower_all_char, lower_only_alpha
from train import (
add_model_arguments,
get_params,
get_transducer_model,
_encode_texts_as_bytes,
)
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
- modified_beam_search_lm_shallow_fusion # for rnn lm shallow fusion
- modified_beam_search_LODR
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--use-pre-text",
type=str2bool,
default=True,
help="Use pre-text is available during decoding",
)
parser.add_argument(
"--use-style-prompt",
type=str2bool,
default=True,
help="Use style prompt when evaluation"
)
parser.add_argument(
"--post-normalization",
type=str2bool,
default=True,
help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
)
parser.add_argument(
"--compute-CER",
type=str2bool,
default=True,
help="Reports CER. By default, only reports WER",
)
parser.add_argument(
"--style-text-transform",
type=str,
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
default="mixed-punc",
help="The style of style prompt, i.e style_text"
)
parser.add_argument(
"--pre-text-transform",
type=str,
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
default="mixed-punc",
help="The style of content prompt, i.e pre_text"
)
add_model_arguments(parser)
return parser
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
"""Apply transform to a list of text. By default, the text are in
ground truth format, i.e mixed-punc.
Args:
text (List[str]): Input text string
transform (str): Transform to be applied
Returns:
List[str]: _description_
"""
if transform == "mixed-punc":
return text
elif transform == "upper-no-punc":
return [upper_only_alpha(s) for s in text]
elif transform == "lower-no-punc":
return [lower_only_alpha(s) for s in text]
elif transform == "lower-punc":
return [lower_all_char(s) for s in text]
else:
raise NotImplementedError(f"Unseen transform: {transform}")
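For example, an illustrative call mapping mixed-case references to the upper-no-punc style (the exact output depends on upper_only_alpha from text_normalization):
refs = ["Hello, world!", "It's only a test."]
print(_apply_style_transform(refs, "upper-no-punc"))
# e.g. ['HELLO WORLD', 'ITS ONLY A TEST']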
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals the
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
LM:
A neural net LM for shallow fusion. Only used when `--use-shallow-fusion`
set to true.
ngram_lm:
A ngram lm. Used in LODR decoding.
ngram_lm_scale:
The scale of the ngram language model.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
batch_size = feature.size(0)
if "pre_text" in batch["supervisions"] and params.use_pre_text:
pre_texts = batch["supervisions"]["pre_text"]
else:
pre_texts = ["" for _ in range(batch_size)]
if params.use_style_prompt:
style_texts = batch["supervisions"]["style_text"]
else:
style_texts = ["" for _ in range(batch_size)] # use empty string
# Get the text embedding input
if params.use_pre_text or params.use_style_prompt:
# apply style transform to the pre_text and style_text
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
#pre_texts = [remove_non_alphabetic(t) for t in pre_texts]
#pre_texts = random_shuffle_subset(pre_texts, p=0.98, p_mask=0.0)
if params.use_style_prompt:
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
pre_texts, pre_texts_lens, style_lens = _encode_texts_as_bytes(
pre_texts,
style_texts,
device,
max_len=1200
) # note that the output pre_texts include style_text and actual pre_text
memory, memory_key_padding_mask = model.encode_text(
text=pre_texts,
text_lens=pre_texts_lens,
style_lens=style_lens,
) # (T,B,C)
else:
memory = None
memory_key_padding_mask = None
# Get the transducer encoder output
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
encoder_out, encoder_out_lens = model.encode_audio(
feature=feature,
feature_lens=feature_lens,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
hyps = []
if (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
LM:
A neural network LM, used during shallow fusion
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains three elements:
the cut id, the reference transcript, and the predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"] # By default, this should be in mixed-punc format
# the style of ref_text should match style_text
#if params.use_style_prompt:
texts = _apply_style_transform(texts, params.style_text_transform)
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_text = ref_text_normalization(
ref_text
) # remove full-width symbols & some book marks
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
test_set_cers = dict()
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
if params.compute_CER:
# Write CER statistics
recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
store_transcripts(filename=recog_path, texts=results, char_level=True)
errs_filename = (
params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
cer = write_error_stats(
f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=params.compute_CER,
)
test_set_cers[key] = cer
logging.info("Wrote detailed CER stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
if params.compute_CER:
test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tcER", file=f)
for key, val in test_set_cers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_cers:
s += "{} CER\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriHeavyAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
if "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_pre_text:
params.suffix += f"-pre-text-{params.pre_text_transform}"
if params.use_style_prompt:
params.suffix += f"-style-prompt-{params.style_text_transform}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
LM = None
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
libriheavy = LibriHeavyAsrDataModule(args)
test_cuts = libriheavy.test_cuts()
medium_test_cuts = libriheavy.medium_test_cuts()
test_clean_cuts = libriheavy.test_clean_cuts()
test_other_cuts = libriheavy.test_other_cuts()
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
test_dl = libriheavy.valid_dataloaders(test_cuts, text_sampling_func=naive_triplet_text_sampling)
medium_test_dl = libriheavy.valid_dataloaders(medium_test_cuts, text_sampling_func=naive_triplet_text_sampling)
test_clean_dl = libriheavy.test_dataloaders(test_clean_cuts)
test_other_dl = libriheavy.test_dataloaders(test_other_cuts)
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
#test_sets = ["test-clean", "test-other", "ls-test-clean", "ls-test-other"]
#test_dl = [test_clean_dl, test_other_dl, ls_test_clean_dl, ls_test_other_dl]
#test_sets = ["test-clean", "test-other"]
#test_dl = [test_clean_dl, test_other_dl]
#test_sets = ["ls-test-clean", "ls-test-other"]
#test_dl = [ls_test_clean_dl, ls_test_other_dl]
if params.use_ls_test_set:
test_sets = ["ls-test-clean", "ls-test-other"]
test_dl = [ls_test_clean_dl, ls_test_other_dl]
else:
test_sets = ["medium_test",]
test_dl = [medium_test_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
if params.post_normalization:
params.suffix += "-post-normalization"
new_res = {}
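# Post-normalization of the results (a hedged illustration, assuming
# remove_non_alphabetic strips every character outside [a-zA-Z]): both hyp
# and ref words are uppercased and stripped of punctuation/digits, e.g.
# "Mr." -> "MR" and "1906" -> "" (dropped), so the rescored WER ignores
# casing and punctuation.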
for k in results_dict:
new_ans = []
for item in results_dict[k]:
cut_id, ref, hyp = item
hyp = [remove_non_alphabetic(w.upper()) for w in hyp]
hyp = [w for w in hyp if w != ""]
ref = [remove_non_alphabetic(w.upper()) for w in ref]
ref = [w for w in ref if w != ""]
new_ans.append((cut_id, ref, hyp))
new_res[k] = new_ans
save_results(
params=params,
test_set_name=test_set,
results_dict=new_res,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@@ -1,349 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import k2
import torch
import torch.nn as nn
import random
import warnings
from encoder_interface import EncoderInterface
from icefall.utils import add_sos, make_pad_mask
from scaling import penalize_abs_values_gt, ScaledLinear
from torch import Tensor
from typing import Optional, Tuple
class PromptedTransducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder_embed: nn.Module,
encoder: EncoderInterface,
text_embed: nn.Module,
text_encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
):
"""
Args:
encoder_embed:
It is a Convolutional 2D subsampling module. It converts
an input of shape (N, T, idim) to an output of shape
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
self.encoder_embed = encoder_embed
self.encoder = encoder
self.text_embed = text_embed
self.text_encoder = text_encoder
self.decoder = decoder
self.joiner = joiner
self.simple_am_proj = ScaledLinear(
encoder_dim,
vocab_size,
initial_scale=0.25,
)
self.simple_lm_proj = ScaledLinear(
decoder_dim,
vocab_size,
initial_scale=0.25,
)
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
text: torch.Tensor,
text_lens: torch.Tensor,
style_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
use_pre_text: bool = True,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
text:
A 2-D tensor of integer dtype containing prompt text, of shape (N, T).
It is expected to contain the style prompt (first) and then the content
prompt.
text_lens:
A 1-D tensor of shape (N,). It contains the number of elements (bytes)
in `text` before padding, which will include the lengths of the
style plus the content prompt.
style_lens:
A 1-D tensor of shape (N,), containing the number of elements (bytes)
within each row of `text` that correspond to the style prompt (these
are expected to come first).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss; it means how many symbols (context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return the transducer loss.
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
if use_pre_text:
memory, memory_key_padding_mask = self.encode_text(
text,
text_lens,
style_lens
)
else:
memory = None
memory_key_padding_mask = None
encoder_out, x_lens = self.encoder(
x,
x_lens,
src_key_padding_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return (simple_loss, pruned_loss)
def _add_style_indicator(self, memory: Tensor, style_lens: Tensor):
"""
Adds to `memory` an indicator that is 1.0 for positions that correspond to
the `style prompt` and 0 elsewhere. The scale can be fixed because the
scale of the embedding vector can adjust to compensate.
Args:
memory: (memory_len, batch_size, embed_dim)
style_lens: (batch_size,), a vector of lengths of the style prompt.
"""
(memory_len, batch_size, embed_dim) = memory.shape
indicator = (
torch.arange(memory_len, device=memory.device).unsqueeze(-1)
< style_lens
)
indicator = indicator.to(memory.dtype)
extra_term = torch.zeros_like(memory)
extra_term[..., 0] += indicator
return memory + extra_term
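# Example (derived from the code above): with memory_len = 5 and
# style_lens = [2] for a single utterance, the indicator over positions is
# [1, 1, 0, 0, 0]; it is added only to channel 0 of the embeddings, marking
# which prefix of the prompt is the style prompt.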
def encode_text(
self,
text: Tensor,
text_lens: Tensor,
style_lens: Tensor,
) -> Tuple[Tensor, Tensor]:
"""Get the embeddings of text
Args:
text (Tensor): The input text data in utf-8 bytes, (N, T)
text_lens (Tensor): The length of the input text (N, ), including style_prompt
style_lens (Tensor): The length of the style prompt (N, )
Returns:
Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the
text_encoder and the attention mask
"""
text = text.t() # now (T, N)
text = self.text_embed(text) # now (T, N, C)
text_key_padding_mask = make_pad_mask(text_lens)
text = self._add_style_indicator(text, style_lens)
memory, text_lens = self.text_encoder(
text, text_lens, text_key_padding_mask
)
memory_key_padding_mask = make_pad_mask(text_lens)
return memory, memory_key_padding_mask
def encode_audio(
self,
feature: Tensor,
feature_lens: Tensor,
memory: Optional[Tensor],
memory_key_padding_mask: Optional[Tensor],
) -> Tuple[Tensor, Tensor]:
"""Encode the input audio features
Args:
feature (Tensor): Input audio (N,T,C)
feature_lens (Tensor): Length of input audio (N,)
memory (Tensor): Embeddings from the text encoder
memory_key_padding_mask (Tensor): The key padding mask of `memory`,
with True at padded positions
Returns:
Tuple[Tensor, Tensor]: The encoder output of shape (N, T, C) and
its valid lengths of shape (N,)
"""
x, x_lens = self.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(
x=x,
x_lens=x_lens,
src_key_padding_mask=src_key_padding_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return encoder_out, encoder_out_lens
Transducer = PromptedTransducer # for decoding

View File

@@ -1,369 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import k2
import torch
import torch.nn as nn
import random
import warnings
from encoder_interface import EncoderInterface
from icefall.utils import add_sos, make_pad_mask
from scaling import penalize_abs_values_gt, ScaledLinear
from torch import Tensor
from typing import Optional, Tuple
class PromptedTransducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
Note that this is a PromptedTransducer, meaning that the transducer is able to decode
with prompts.
This transducer also has a special context fuser.
"""
def __init__(
self,
encoder_embed: nn.Module,
encoder: EncoderInterface,
text_embed: nn.Module,
text_encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
context_fuser: nn.Module = None,
):
"""
Args:
encoder_embed:
It is a Convolutional 2D subsampling module. It converts
an input of shape (N, T, idim) to an output of shape
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
context_fuser:
A module that fuses the output embeddings of the text_encoder. The fused embedding
will be given to joiner before emitting symbols.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
self.encoder_embed = encoder_embed
self.encoder = encoder
self.text_embed = text_embed
self.text_encoder = text_encoder
self.decoder = decoder
self.joiner = joiner
self.simple_am_proj = ScaledLinear(
encoder_dim,
vocab_size,
initial_scale=0.25,
)
self.simple_lm_proj = ScaledLinear(
decoder_dim,
vocab_size,
initial_scale=0.25,
)
self.context_fuser = context_fuser # a module to aggregate the context
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
text: torch.Tensor,
text_lens: torch.Tensor,
style_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
use_pre_text: bool = True,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
text:
A 2-D tensor of integer dtype containing prompt text, of shape (N, T).
It is expected to contain the style prompt (first) and then the content
prompt.
text_lens:
A 1-D tensor of shape (N,). It contains the number of elements (bytes)
in `text` before padding, which will include the lengths of the
style plus the content prompt.
style_lens:
A 1-D tensor of shape (N,), containing the number of elements (bytes)
within each row of `text` that correspond to the style prompt (these
are expected to come first).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss; it means how many symbols (context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return the transducer loss.
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
if use_pre_text:
memory, memory_key_padding_mask = self.encode_text(
text,
text_lens,
style_lens
) # (T,N,C)
else:
memory = None
memory_key_padding_mask = None
encoder_out, x_lens = self.encoder(
x,
x_lens,
src_key_padding_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
if self.context_fuser is not None and memory is not None:
memory = memory.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
context = self.context_fuser(memory, padding_mask=memory_key_padding_mask)
context = self.joiner.context_proj(context)
else:
context = None
logits = self.joiner(
am_pruned,
lm_pruned,
context=context,
project_input=False,
)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return (simple_loss, pruned_loss)
def _add_style_indicator(self, memory: Tensor, style_lens: Tensor):
"""
Adds to `memory` an indicator that is 1.0 for positions that correspond to
the `style prompt` and 0 elsewhere. The scale can be fixed because the
scale of the embedding vector can adjust to compensate.
Args:
memory: (memory_len, batch_size, embed_dim)
style_lens: (batch_size,), a vector of lengths of the style prompt.
"""
(memory_len, batch_size, embed_dim) = memory.shape
indicator = (
torch.arange(memory_len, device=memory.device).unsqueeze(-1)
< style_lens
)
indicator = indicator.to(memory.dtype)
extra_term = torch.zeros_like(memory)
extra_term[..., 0] += indicator
return memory + extra_term
def encode_text(
self,
text: Tensor,
text_lens: Tensor,
style_lens: Tensor,
) -> Tuple[Tensor, Tensor]:
"""Get the embeddings of text
Args:
text (Tensor): The input text data in utf-8 bytes, (N, T)
text_lens (Tensor): The length of the input text (N, ), including style_prompt
style_lens (Tensor): The length of the style prompt (N, )
Returns:
Tuple[Tensor, Tensor]: Returns the text embeddings encoded by the
text_encoder and the attention mask
"""
text = text.t() # now (T, N)
text = self.text_embed(text) # now (T, N, C)
text_key_padding_mask = make_pad_mask(text_lens)
text = self._add_style_indicator(text, style_lens)
memory, text_lens = self.text_encoder(
text, text_lens, text_key_padding_mask
) # (T,N,C)
memory_key_padding_mask = make_pad_mask(text_lens)
return memory, memory_key_padding_mask
def encode_audio(
self,
feature: Tensor,
feature_lens: Tensor,
memory: Optional[Tensor],
memory_key_padding_mask: Optional[Tensor],
) -> Tuple[Tensor, Tensor]:
"""Encode the input audio features
Args:
feature (Tensor): Input audio (N,T,C)
feature_lens (Tensor): Length of input audio (N,)
memory (Tensor): Embeddings from the text encoder
memory_key_padding_mask (Tensor): The key padding mask of `memory`,
with True at padded positions
Returns:
Tuple[Tensor, Tensor]: The encoder output of shape (N, T, C) and
its valid lengths of shape (N,)
"""
x, x_lens = self.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(
x=x,
x_lens=x_lens,
src_key_padding_mask=src_key_padding_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return encoder_out, encoder_out_lens
Transducer = PromptedTransducer # for decoding

File diff suppressed because it is too large