mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
* Add PromptASR with BERT as text encoder * Support using word-list based content prompts for context biasing * Upload the pretrained models to huggingface * Add usage example
587 lines
22 KiB
Python
587 lines
22 KiB
Python
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
|
|
#
|
|
# See ../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import random
|
|
from typing import Callable, Dict, List, Optional, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
from lhotse import validate
|
|
from lhotse.cut import CutSet
|
|
from lhotse.dataset import K2SpeechRecognitionDataset
|
|
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
|
|
from lhotse.utils import compute_num_frames, ifnone
|
|
from text_normalization import (
|
|
lower_all_char,
|
|
lower_only_alpha,
|
|
remove_non_alphabetic,
|
|
train_text_normalization,
|
|
upper_all_char,
|
|
upper_only_alpha,
|
|
)
|
|
from torch.utils.data.dataloader import DataLoader, default_collate
|
|
|
|
|
|
class PromptASRDataset(torch.utils.data.Dataset):
    """This is a dataset for Prompt ASR. It supports the following features:
    1. Select a tuple of (text, pre_text, style_text) randomly from a
    list of texts as supervisions.

    This is a map-style ``torch.utils.data.Dataset`` whose "index" is a whole
    ``CutSet``: ``__getitem__`` turns that ``CutSet`` into one collated batch.
    """

    def __init__(
        self,
        return_cuts: bool = False,
        cut_transforms: Optional[List[Callable[[CutSet], CutSet]]] = None,
        input_transforms: Optional[
            List[Callable[[torch.Tensor], torch.Tensor]]
        ] = None,
        input_strategy: BatchIO = PrecomputedFeatures(),
        text_sampling_func: Optional[Callable[..., Dict]] = None,
        rare_word_list: Optional[List[str]] = None,
    ):
        """
        Icefall PromptASR dataset constructor. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py
        for more details.

        :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
            objects used to create that batch.
        :param cut_transforms: A list of transforms to be applied on each sampled batch,
            before converting cuts to an input representation (audio/features).
            Examples: cut concatenation, noise cuts mixing, etc.
        :param input_transforms: A list of transforms to be applied on each sampled batch,
            after the cuts are converted to audio/features.
            Examples: normalization, SpecAugment, etc.
        :param input_strategy: Converts cuts into a collated batch of audio/features.
            By default, reads pre-computed features from disk.
        :param text_sampling_func: Called once per supervision with the keyword
            arguments ``texts``, ``pre_texts``, ``context_list`` and
            ``rare_word_list``; must return a dict that ``default_collate`` can
            batch. When ``None``, the first entries of ``supervision.texts`` and
            ``supervision.pre_texts`` are used directly (the pre_text also serves
            as the style_text) and ``transform_ids`` is 0.
        :param rare_word_list: Forwarded unchanged to ``text_sampling_func`` as
            ``rare_word_list``.
        """
        super().__init__()
        # Initialize the fields
        self.return_cuts = return_cuts
        self.cut_transforms = ifnone(cut_transforms, [])
        self.input_transforms = ifnone(input_transforms, [])
        self.input_strategy = input_strategy

        # a text sampling function
        self.text_sampling_func = text_sampling_func
        self.rare_word_list = rare_word_list

    def _text_fields(self, supervision) -> Dict:
        """Return the text entries (text, pre_text, style_text, transform_ids) of one supervision."""
        if self.text_sampling_func is None:
            return {
                "text": train_text_normalization(supervision.texts[0]),
                "pre_text": train_text_normalization(supervision.pre_texts[0]),
                "style_text": train_text_normalization(supervision.pre_texts[0]),
                "transform_ids": 0,
            }

        return self.text_sampling_func(
            texts=supervision.texts,
            pre_texts=supervision.pre_texts,
            # The biasing list is optional; it lives in the custom fields.
            context_list=supervision.context_list
            if "context_list" in supervision.custom
            else None,
            rare_word_list=self.rare_word_list,
        )

    def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
        """
        Collate the given cuts into one batch.

        :param cuts: The cuts that make up the batch.
        :return: A dict with two entries: ``"inputs"`` (the collated and
            transformed audio/features) and ``"supervisions"`` (the collated
            text fields, updated with whatever
            ``input_strategy.supervision_intervals`` returns and, when
            ``return_cuts`` is set, a ``"cut"`` list holding one cut per
            supervision).
        """
        validate_for_asr(cuts)

        # Sort the cuts by duration so that the first one determines the batch time dimensions.
        cuts = cuts.sort_by_duration(ascending=False)

        # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
        # the supervision boundaries.
        for tnfm in self.cut_transforms:
            cuts = tnfm(cuts)

        # Sort the cuts again after transforms
        cuts = cuts.sort_by_duration(ascending=False)

        # Get a tensor with batched feature matrices, shape (B, T, F)
        # Collation performs auto-padding, if necessary.
        input_tpl = self.input_strategy(cuts)
        if len(input_tpl) == 3:
            # An input strategy with fault tolerant audio reading mode.
            # "cuts" may be a subset of the original "cuts" variable,
            # that only has cuts for which we successfully read the audio.
            inputs, _, cuts = input_tpl
        else:
            inputs, _ = input_tpl

        # Get a dict of tensors that encode the positional information about supervisions
        # in the batch of feature matrices. The tensors are named "sequence_idx",
        # "start_frame/sample" and "num_frames/samples".
        supervision_intervals = self.input_strategy.supervision_intervals(cuts)

        # Apply all available transforms on the inputs, i.e. either audio or features.
        # This could be feature extraction, global MVN, SpecAugment, etc.
        segments = torch.stack(list(supervision_intervals.values()), dim=1)
        for tnfm in self.input_transforms:
            inputs = tnfm(inputs, supervision_segments=segments)

        # One dict of text fields per supervision, in cut order, collated
        # into a single dict of batched fields.
        supervisions = default_collate(
            [
                self._text_fields(supervision)
                for cut in cuts
                for supervision in cut.supervisions
            ]
        )
        # Update the 'supervisions' field with sequence_idx and start/num frames/samples
        supervisions.update(supervision_intervals)
        if self.return_cuts:
            # Repeat each cut once per supervision so the list lines up with
            # the other per-supervision fields.
            supervisions["cut"] = [cut for cut in cuts for _ in cut.supervisions]

        return {"inputs": inputs, "supervisions": supervisions}
|
|
|
|
|
|
def validate_for_asr(cuts: CutSet) -> None:
    """Check that every supervision lies within its cut.

    First runs the generic ``validate`` on ``cuts``, then asserts for each
    supervision of each cut that it neither starts before the cut nor ends
    after it, allowing a slack of ``tol`` seconds on both sides.

    Raises:
      AssertionError: if a supervision exceeds the cut boundaries by more
        than the tolerance.
    """
    validate(cuts)

    tol = 2e-3  # 2 ms of slack

    for cut in cuts:
        for sup in cut.supervisions:
            # Supervision times are relative to the cut, so a start
            # below zero means the supervision begins before the cut.
            assert -tol <= sup.start, (
                "Supervisions starting before the cut are not supported for ASR"
                f" (sup id: {sup.id}, cut id: {cut.id})"
            )

            # Supervision start time is relative to Cut ...
            # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
            #
            # 'sup.end' is the end of the supervision inside the Cut
            assert cut.duration + tol >= sup.end, (
                "Supervisions ending after the cut are not supported for ASR"
                f" (sup id: {sup.id}, cut id: {cut.id})"
            )
|
|
|
|
|
|
def get_substring(s: str, min_len: int = 40, max_len: int = 250) -> str:
    """Cut a random contiguous piece out of a string.

    Args:
      s (str): Input string
      min_len (int): Lower bound on the length of the piece; capped at
        ``len(s)`` when the input is shorter than that.
      max_len (int): Upper bound on the length of the piece.

    Returns:
      str: The sampled piece ``s[start:end]``
    """
    total = len(s)
    shortest = min(total, min_len)

    # Leave room for at least `shortest` characters after the start.
    begin = random.randint(0, total - shortest)
    sampled_stop = random.randint(begin + shortest, total)
    stop = min(begin + max_len, sampled_stop)

    return s[begin:stop]
|
|
|
|
|
|
def triplet_text_sampling(
    texts: List[str],
    pre_texts: List[str],
    context_list: Optional[str] = None,
    rare_word_list: Optional[List[str]] = None,
    transforms: Optional[List[Callable[[str], str]]] = None,
    min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
    """This function generates a triplet of
    (pre_text, style_text, ref_text). The style of style_text and ref_text
    should **always** match, whereas the style of pre_text is arbitrary.
    Suppose we have 2 different transforms A,B, and the preceding text is
    referred to as pre_text. The following tuples are all valid:

    (A(pre_text), A(style_text), A(ref_text))
    (A(pre_text), B(style_text), B(ref_text))
    (B(pre_text), A(style_text), A(ref_text))
    (B(pre_text), B(style_text), B(ref_text))

    If transforms is not given, the following pre-defined transforms
    are available:
        0: original (mixed-cased, with punc)
        1: upper_only_alpha (upper-cased, no punc)
        2: lower_only_alpha
        3: lower_all_char
    They are drawn with probabilities 0.7, 0.3, 0.0 and 0.0, so only the
    first two are ever selected.

    When the transform of text and pre_text match, the style_text is a
    random substring of the transformed pre_text. Otherwise it is a random
    substring of the untransformed ground-truth pre_text; the style
    transform is then left to the caller (the original note says it is
    applied after the dataloader).

    Args:
        texts (List[str]):
            A list of ref_texts whose first item is the ground truth
            text from books. Must hold exactly two items.
        pre_texts (List[str]):
            A list of pre_texts, whose first item is the groundtruth
            pre_text from books. Must have the same length as ``texts``.
        context_list (Optional[str]):
            A list of biasing words separated by space. Not used by this
            function.
        rare_word_list (Optional[List[str]]):
            A list of rare words (used as distractors). Not used by this
            function.
        transforms (List[Callable[[str], str]]): A list of possible transforms to be applied.
            A list of four transforms is sampled with the probabilities
            above; a list of any other length is sampled uniformly.
        min_len_style (Optional[int]): Passed to ``get_substring`` as the
            minimum length of the style text.

    Returns:
        A dictionary with the normalized ``text``, ``pre_text`` and
        ``style_text``, plus ``transform_ids``, the index of the transform
        applied to the text.
    """
    assert len(texts) == len(pre_texts)
    assert len(texts) == 2

    # we assume the first item to be ground truth
    gt_text = texts[0]
    gt_pre_text = pre_texts[0]

    if transforms is None:
        transforms = [
            lambda x: x,  # return it self
            upper_only_alpha,
            lower_only_alpha,
            lower_all_char,
        ]

    # Mixed-punc should have the largest sampling prob
    default_weights = [0.7, 0.3, 0.0, 0.0]

    total_transforms = len(transforms)  # do not use the recognized trans

    # The weights above only fit a list of four transforms. For any other
    # length np.random.choice would reject them, so sample uniformly instead.
    if total_transforms == len(default_weights):
        sampling_weight = default_weights
    else:
        sampling_weight = None

    # Randomly sample transforms (two independent draws)
    i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)

    # get the normalized text and pre_text
    text = transforms[i_text](gt_text)
    pre_text = transforms[i_pre_text](gt_pre_text)

    if i_text == i_pre_text:
        # pre_text already carries the style of the text
        style_source = pre_text
    else:
        # For now, **don't** do transform to the style text, because we do
        # it after the dataloader
        style_source = gt_pre_text
    style_text = get_substring(style_source, min_len=min_len_style, max_len=150)

    return {
        "text": train_text_normalization(text),
        "pre_text": train_text_normalization(pre_text),
        "style_text": train_text_normalization(style_text),
        "transform_ids": i_text,
    }
|
|
|
|
|
|
def triplet_text_sampling_with_context_list(
    texts: List[str],
    pre_texts: List[str],
    context_list: str,
    rare_word_list: List[str],
    transforms: Optional[List[Callable[[str], str]]] = None,
    min_len_style: Optional[int] = 80,
) -> Dict[str, str]:
    """This function generates a triplet of
    (pre_text, style_text, ref_text). The pre_text is either the preceding text
    or a list of words (context words + distractors).
    The style of style_text and ref_text should **always** match, whereas
    the style of pre_text is arbitrary.
    Suppose we have 2 different transforms A,B, and the preceding text is
    referred to as pre_text. The following tuples are all valid:

    (A(pre_text), A(style_text), A(ref_text))
    (A(pre_text), B(style_text), B(ref_text))
    (B(pre_text), A(style_text), A(ref_text))
    (B(pre_text), B(style_text), B(ref_text))

    If transforms is not given, the following pre-defined transforms
    are available:
        0: original (mixed-cased, with punc)
        1: upper_only_alpha (upper-cased, no punc)
        2: lower_only_alpha
        3: lower_all_char
    They are drawn with probabilities 0.7, 0.3, 0.0 and 0.0, so only the
    first two are ever selected.

    When the transform of text and pre_text match, the style_text is a
    random substring of the transformed pre_text (which may itself be a word
    list). Otherwise it is a random substring of the untransformed
    ground-truth pre_text.

    Args:
        texts (List[str]):
            A list of ref_texts whose first item is the ground truth
            text from books. Must hold exactly two items.
        pre_texts (List[str]):
            A list of pre_texts, whose first item is the groundtruth
            pre_text from books. Must have the same length as ``texts``.
        context_list (str):
            A list of biasing words separated by space; lower-cased before
            use. May be ``None``.
        rare_word_list (List[str]):
            A list of rare words (used as distractors).
        transforms (List[Callable[[str], str]]): A list of possible transforms to be applied.
            A list of four transforms is sampled with the probabilities
            above; a list of any other length is sampled uniformly.
        min_len_style (Optional[int]): Passed to ``get_substring`` as the
            minimum length of the style text.

    Returns:
        A dictionary with the normalized ``text``, ``pre_text`` and
        ``style_text``, plus ``transform_ids``, the index of the transform
        applied to the text.
    """
    assert len(texts) == len(pre_texts)
    assert len(texts) == 2

    if context_list is not None:
        context_list = context_list.lower()

    # we assume the first item to be ground truth
    gt_text = texts[0]
    gt_pre_text = pre_texts[0]

    if transforms is None:
        transforms = [
            lambda x: x,  # return it self
            upper_only_alpha,
            lower_only_alpha,
            lower_all_char,
        ]

    # Mixed-punc should have the largest sampling prob
    default_weights = [0.7, 0.3, 0.0, 0.0]

    total_transforms = len(transforms)  # do not use the recognized trans

    # The weights above only fit a list of four transforms. For any other
    # length np.random.choice would reject them, so sample uniformly instead.
    if total_transforms == len(default_weights):
        sampling_weight = default_weights
    else:
        sampling_weight = None

    # Select the transformations randomly (two independent draws)
    i_text, i_pre_text = np.random.choice(total_transforms, 2, p=sampling_weight)

    # get the normalized text and pre_text
    text = transforms[i_text](gt_text)
    # The content prompt is built from the untransformed texts and styled
    # afterwards.
    pre_text = get_pre_text_with_context_list2(
        text=gt_text,
        pre_text=gt_pre_text,
        context_list=context_list,
        rare_words_list=rare_word_list,
    )
    pre_text = transforms[i_pre_text](pre_text)

    if i_text == i_pre_text:
        # pre_text already carries the style of the text
        style_source = pre_text
    else:
        # For now, **don't** do transform to the style text
        style_source = gt_pre_text
    style_text = get_substring(style_source, min_len=min_len_style, max_len=150)

    return {
        "text": train_text_normalization(text),
        "pre_text": train_text_normalization(pre_text),
        "style_text": train_text_normalization(style_text),
        "transform_ids": i_text,
    }
|
|
|
|
|
|
def get_pre_text_with_context_list(
    text: str,
    pre_text: str,
    context_list: str,
    rare_words_list: Optional[List[str]] = None,
) -> str:
    """Pick the content prompt for one utterance.

    The prompt is either the given preceding text or a space-separated word
    list. One uniform random number selects the variant.

    When ``context_list`` contains at least one word:
      * 50%: a random non-empty subset of the context words mixed with
        0-50 distractors, shuffled;
      * 20%: 1-20 words drawn from ``text`` (with replacement, longer words
        being more likely) mixed with 0-70 distractors, shuffled;
      * 30%: ``pre_text`` unchanged.

    Otherwise:
      * 10%: 1-20 words drawn from ``text`` as above, without distractors;
      * 10%: 5-100 distractors only;
      * 10%: a random substring of ``text`` (15-150 characters);
      * 70%: ``pre_text`` unchanged.

    Args:
      text: The reference text of the utterance.
      pre_text: The preceding text of the utterance.
      context_list: Biasing words separated by whitespace; may be empty or
        ``None``.
      rare_words_list: Pool from which distractors are sampled. When it is
        ``None`` or holds fewer words than requested, as many distractors
        as are available are used (possibly none).

    Returns:
      The selected prompt.
    """
    # Distractors are optional: without a pool none are added.
    rare_words = rare_words_list if rare_words_list is not None else []
    # A whitespace-only context list holds no words and counts as empty.
    context_words = context_list.split() if context_list else []

    def draw_distractors(low: int, high: int) -> List[str]:
        # Clamp the count so random.sample never asks for more than exists.
        wanted = min(random.randint(low, high), len(rare_words))
        return random.sample(rare_words, wanted)

    def draw_words_from_text() -> List[str]:
        words = text.split()
        if not words:
            return []
        # longer words with higher weights
        weights = [len(w) ** 1.2 for w in words]
        total = sum(weights)
        probs = [w / total for w in weights]
        count = random.randint(1, min(len(words), 20))
        return list(np.random.choice(words, count, p=probs))

    v = random.random()
    if context_words:
        if v < 0.5:
            # correct + distractors
            distractors = draw_distractors(0, 50)
            chosen = random.sample(
                context_words, random.randint(1, len(context_words))
            )
            mixed = distractors + chosen
            random.shuffle(mixed)
            pre_text = " ".join(mixed)
        elif v < 0.7:
            mixed = draw_words_from_text()
            mixed += draw_distractors(0, 70)
            random.shuffle(mixed)  # shuffle the list
            pre_text = " ".join(mixed)
        # else: keep the given pre_text
    else:
        if v < 0.1:
            # NOTE(review): this branch has never mixed distractors into the
            # returned prompt; confirm whether it should, like the
            # corresponding branch with a context list does.
            pre_text = " ".join(draw_words_from_text())
        elif v < 0.2:
            # full distractors
            pre_text = " ".join(draw_distractors(5, 100))
        elif v < 0.3:
            pre_text = get_substring(text, min_len=15, max_len=150)
        # else: keep the given pre_text

    return pre_text
|
|
|
|
|
|
def get_pre_text_with_context_list2(
    text: str,
    pre_text: str,
    context_list: str,
    rare_words_list: Optional[List[str]] = None,
) -> str:
    """Get the pre_text, either the ground truth preceding text or
    a list of words consisting of biasing words and distractors.
    By a small proportion of time, use the substring of ref_text as pre_text.

    One uniform random number selects the variant.

    When ``context_list`` contains at least one word:
      * 40%: a random non-empty subset of the context words mixed with
        50-100 distractors, shuffled;
      * 15%: 1-20 words drawn from ``text`` (with replacement, longer words
        being more likely) mixed with 50-100 distractors, shuffled;
      * 45%: ``pre_text`` unchanged.

    Otherwise:
      * 30%: 1-20 words drawn from ``text`` as above, without distractors;
      * 10%: 5-100 distractors only;
      * 20%: a random substring of ``text`` (15-150 characters);
      * 40%: ``pre_text`` unchanged.

    Args:
      text: The reference text of the utterance.
      pre_text: The preceding text of the utterance.
      context_list: Biasing words separated by whitespace; may be empty or
        ``None``.
      rare_words_list: Pool from which distractors are sampled. When it is
        ``None`` or holds fewer words than requested, as many distractors
        as are available are used (possibly none).

    Returns:
      The selected prompt.
    """
    # Distractors are optional: without a pool none are added.
    rare_words = rare_words_list if rare_words_list is not None else []
    # A whitespace-only context list holds no words and counts as empty.
    context_words = context_list.split() if context_list else []

    def draw_distractors(low: int, high: int) -> List[str]:
        # Clamp the count so random.sample never asks for more than exists.
        wanted = min(random.randint(low, high), len(rare_words))
        return random.sample(rare_words, wanted)

    def draw_words_from_text() -> List[str]:
        words = text.split()
        if not words:
            return []
        # longer words with higher weights
        weights = [len(w) ** 1.2 for w in words]
        total = sum(weights)
        probs = [w / total for w in weights]
        count = random.randint(1, min(len(words), 20))
        return list(np.random.choice(words, count, p=probs))

    v = random.random()
    if context_words:
        if v < 0.4:
            # correct + distractors
            distractors = draw_distractors(50, 100)
            chosen = random.sample(
                context_words, random.randint(1, len(context_words))
            )
            mixed = distractors + chosen
            random.shuffle(mixed)
            pre_text = " ".join(mixed)
        elif v < 0.55:
            mixed = draw_words_from_text()
            mixed += draw_distractors(50, 100)
            random.shuffle(mixed)  # shuffle the list
            pre_text = " ".join(mixed)
        # else: keep the given pre_text
    else:
        if v < 0.3:
            # NOTE(review): this branch has never mixed distractors into the
            # returned prompt; confirm whether it should, like the
            # corresponding branch with a context list does.
            pre_text = " ".join(draw_words_from_text())
        elif v < 0.4:
            # full distractors
            pre_text = " ".join(draw_distractors(5, 100))
        elif v < 0.6:
            pre_text = get_substring(text, min_len=15, max_len=150)
        # else: keep the given pre_text

    return pre_text
|
|
|
|
|
|
def naive_triplet_text_sampling(
    texts: List[str],
    pre_texts: List[str],
    context_list: str = None,
    rare_word_list: List[str] = None,
    min_len_style: Optional[int] = 120,
):
    """The simplest text sampling function, used only for evaluation.

    No sampling takes place: the first text and the first pre_text are
    normalized and returned together with a fixed sentence as the style
    text. ``context_list``, ``rare_word_list`` and ``min_len_style`` are
    not used.

    Returns:
        A dictionary with ``text``, ``pre_text``, ``style_text`` and
        ``transform_ids`` (always 0).
    """
    fixed_style = "Mixed-case English transcription, with punctuation. Actually, it is fully not related. What do you think?"

    ref_text = train_text_normalization(texts[0])
    prompt_text = train_text_normalization(pre_texts[0])

    result = {}
    result["text"] = ref_text
    result["pre_text"] = prompt_text
    result["style_text"] = fixed_style
    result["transform_ids"] = 0
    return result
|
|
|
|
|
|
def random_shuffle_subset(
    data: List[str],
    p: float = 0.2,
    p_mask: float = 0.05,
) -> List[str]:
    """
    Permute a random fraction ``p`` of the entries among themselves and
    blank out some entries.

    ``int(len(data) * p)`` distinct positions are picked and their items
    are redistributed over those same positions in random order; all other
    items keep their place. Afterwards each entry is replaced by an empty
    string with probability ``p_mask`` (skipped when ``p_mask`` is not
    positive).

    The list is modified in place and also returned.
    """
    total = len(data)
    picked = np.random.choice(total, int(total * p), replace=False)

    # Pull the selected items out, shuffle them, and write them back
    # onto the selected positions.
    moved = [data[pos] for pos in picked]
    random.shuffle(moved)
    for pos, value in zip(picked, moved):
        data[pos] = value

    if not p_mask > 0:
        return data

    # Randomly mask a proportion of the data to empty string
    for pos, _ in enumerate(data):
        if random.random() < p_mask:
            data[pos] = ""

    return data
|
|
|
|
|
|
if __name__ == "__main__":
    # Small demo: sample ten triplets from one hand-written example and
    # print them.
    demo_texts = ["AA, BB, cC, dD!", "AA BB CC DD"]
    demo_pre_texts = [
        "EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg? EE, Ff, Gg?",
        "EE FF GG EE FF GG EE FF GG EE FF GG EE FF GG",
    ]

    run = 0
    while run < 10:
        print(f"Run: {run}")
        print(triplet_text_sampling(demo_texts, demo_pre_texts))
        run += 1
|