add script for generating context list for each utterance

marcoyang1998 2023-09-19 17:44:52 +08:00
parent 8401f26342
commit bea1bd295f
3 changed files with 157 additions and 84 deletions

View File

@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+# This is the preparation recipe for PromptASR: https://arxiv.org/pdf/2309.07414
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+stage=-1
+stop_stage=100
+
+manifest_dir=data/fbank
+subset=medium
+topk=10000
+
+. shared/parse_options.sh || exit 1
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download the meta biasing list for LibriSpeech"
+  mkdir -p data/context_biasing
+  cd data/context_biasing
+  git clone https://github.com/facebookresearch/fbai-speech.git
+  cd ../..
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Add rare-words for context biasing to the manifest"
+  python zipformer_prompt_asr/utils.py \
+    --manifest-dir $manifest_dir \
+    --subset $subset \
+    --top-k $topk
+fi
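
Note: Stage 1 is a thin wrapper around the helpers this commit adds to zipformer_prompt_asr/utils.py. A minimal sketch of the equivalent direct invocation (assuming zipformer_prompt_asr/ is on PYTHONPATH; values mirror the script's defaults):

    # Sketch only: call the Stage 1 helpers directly instead of via the shell script.
    from utils import add_context_list_to_manifest, check, get_rare_words

    subset = "medium"          # same as $subset above
    top_k = 10000              # same as $topk above
    manifest_dir = "data/fbank"  # same as $manifest_dir above

    # 1) Split the vocabulary: the top-k most frequent words are common, the rest are rare.
    get_rare_words(subset=subset, top_k=top_k)
    # 2) Attach a per-utterance context list of rare words to the Libriheavy manifest.
    add_context_list_to_manifest(manifest_dir=manifest_dir, subset=subset, top_k=top_k)
    # 3) Print simple statistics about the generated context lists.
    check(manifest_dir=manifest_dir, subset=subset, top_k=top_k)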

View File

@@ -23,7 +23,7 @@ from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
-from dataset2 import PromptASRDataset
+from dataset import PromptASRDataset
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
     CutConcatenate,
@@ -73,11 +73,15 @@ class LibriHeavyAsrDataModule:
         if args.use_context_list:
             from dataset2 import PromptASRDataset
 
             assert args.rare_word_file is not None
-            with open(args.rare_word_file, 'r') as f:
-                self.rare_word_list = f.read().lower().split()  # Use lower-cased for easier style transform
+            with open(args.rare_word_file, "r") as f:
+                self.rare_word_list = (
+                    f.read().lower().split()
+                )  # Use lower-cased for easier style transform
         else:
             from dataset import PromptASRDataset
 
             self.rare_word_list = None
 
     @classmethod
@@ -211,9 +215,11 @@ class LibriHeavyAsrDataModule:
         )
 
         group.add_argument(
-            "--min-count",
+            "--top-k",
             type=int,
-            default=7,
+            default=10000,
+            help="""The top-k words are identified as common words,
+            the rest as rare words""",
         )
 
         group.add_argument(
@@ -257,13 +263,9 @@ class LibriHeavyAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -285,9 +287,7 @@ class LibriHeavyAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -332,9 +332,7 @@ class LibriHeavyAsrDataModule:
         # Drop feats to be on the safe side.
         train = PromptASRDataset(
             cut_transforms=transforms,
-            input_strategy=OnTheFlyFeatures(
-                Fbank(FbankConfig(num_mel_bins=80))
-            ),
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
             input_transforms=input_transforms,
             return_cuts=self.args.return_cuts,
             text_sampling_func=text_sampling_func,
@@ -379,7 +377,8 @@ class LibriHeavyAsrDataModule:
         return train_dl
 
-    def valid_dataloaders(self,
+    def valid_dataloaders(
+        self,
         cuts_valid: CutSet,
         text_sampling_func: Callable[[List[str]], str] = None,
     ) -> DataLoader:
@@ -401,9 +400,7 @@ class LibriHeavyAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = PromptASRDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
                 text_sampling_func=text_sampling_func,
                 rare_word_list=self.rare_word_list,
@@ -460,16 +457,16 @@ class LibriHeavyAsrDataModule:
         if self.args.use_context_list:
             path = (
                 self.args.manifest_dir
-                / f"libriheavy_cuts_{self.args.subset}_with_context_list_min_count_{self.args.min_count}.jsonl.gz"
+                / f"libriheavy_cuts_{self.args.subset}_with_context_list_topk_{self.args.top_k}.jsonl.gz"
             )
         elif self.args.with_decoding:
             path = (
-                self.args.manifest_dir / f"libriheavy_cuts_{self.args.subset}_with_decoding.jsonl.gz"
+                self.args.manifest_dir
+                / f"libriheavy_cuts_{self.args.subset}_with_decoding.jsonl.gz"
             )
         else:
             path = (
-                self.args.manifest_dir
-                / f"libriheavy_cuts_{self.args.subset}.jsonl.gz"
+                self.args.manifest_dir / f"libriheavy_cuts_{self.args.subset}.jsonl.gz"
             )
 
         logging.info(f"Loading manifest from {path}.")
@@ -514,20 +511,6 @@ class LibriHeavyAsrDataModule:
             self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
         )
 
-    @lru_cache()
-    def npr1_dev_cuts(self) -> CutSet:
-        logging.info("About to get npr1 dev cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "npr1_cuts_dev.jsonl.gz"
-        )
-
-    @lru_cache()
-    def npr1_test_cuts(self) -> CutSet:
-        logging.info("About to get npr1 test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "npr1_cuts_test.jsonl.gz"
-        )
-
     @lru_cache()
     def long_audio_cuts(self) -> CutSet:
         logging.info("About to get long audio cuts")

View File

@@ -1,3 +1,4 @@
+import argparse
 import ast
 import glob
 import logging
@@ -12,9 +13,34 @@ from text_normalization import remove_non_alphabetic
 from tqdm import tqdm
 
 
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--manifest-dir",
+        type=str,
+        default="data/fbank",
+        help="Where the manifests are stored",
+    )
+
+    parser.add_argument(
+        "--subset", type=str, default="medium", help="Which subset to work with"
+    )
+
+    parser.add_argument(
+        "--top-k",
+        type=int,
+        default=10000,
+        help="How many of the most frequent words to keep as common words",
+    )
+
+    return parser
+
+
 def get_facebook_biasing_list(
     test_set: str,
-    use_distractors: bool = False,
     num_distractors: int = 100,
 ) -> Dict:
+    # Get the biasing list from the meta paper: https://arxiv.org/pdf/2104.02194.pdf
@@ -28,9 +54,9 @@ def get_facebook_biasing_list(
             raise ValueError(f"Unseen test set {test_set}")
     else:
         if test_set == "test-clean":
-            biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv"
+            biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv"
         elif test_set == "test-other":
-            biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv"
+            biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv"
         else:
             raise ValueError(f"Unseen test set {test_set}")
@@ -41,7 +67,7 @@ def get_facebook_biasing_list(
     output = dict()
     for line in data:
         id, _, l1, l2 = line.split("\t")
-        if use_distractors:
+        if num_distractors > 0:  # use distractors
             biasing_list = ast.literal_eval(l2)
         else:
             biasing_list = ast.literal_eval(l1)
@@ -66,20 +92,25 @@ def brian_biasing_list(level: str):
     return biasing_dict
 
 
-def get_rare_words(subset: str, min_count: int):
+def get_rare_words(
+    subset: str = "medium",
+    top_k: int = 10000,
+    # min_count: int = 10000,
+):
     """Get a list of rare words appearing less than `min_count` times
 
     Args:
         subset: The dataset
-        min_count (int): Count of appearance
+        top_k (int): The number of most frequent words kept as common words
     """
     txt_path = f"data/tmp/transcript_words_{subset}.txt"
-    rare_word_file = f"data/context_biasing/{subset}_rare_words_{min_count}.txt"
+    rare_word_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt"
 
     if os.path.exists(rare_word_file):
         print("File exists, do not proceed!")
         return
 
-    print("Finding rare words in the manifest.")
+    print("---Identifying rare words in the manifest---")
     count_file = f"data/tmp/transcript_words_{subset}_count.txt"
     if not os.path.exists(count_file):
         with open(txt_path, "r") as file:
@@ -94,9 +125,12 @@ def get_rare_words(subset: str, min_count: int):
                 else:
                     word_count[w] += 1
 
+        word_count = list(word_count.items())  # convert to a list of tuple
+        word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True)
+
         with open(count_file, "w") as fout:
-            for w in word_count:
-                fout.write(f"{w}\t{word_count[w]}\n")
+            for w, count in word_count:
+                fout.write(f"{w}\t{count}\n")
     else:
         word_count = {}
         with open(count_file, "r") as fin:
@@ -106,42 +140,45 @@ def get_rare_words(subset: str, min_count: int):
     print(f"A total of {len(word_count)} words appeared!")
 
     rare_words = []
-    for k in word_count:
-        if int(word_count[k]) <= min_count:
-            rare_words.append(k + "\n")
-    print(f"A total of {len(rare_words)} appeared <= {min_count} times")
+    for word, count in word_count[top_k:]:
+        rare_words.append(word + "\n")
+    print(f"A total of {len(rare_words)} words are identified as rare words.")
 
     with open(rare_word_file, "w") as f:
         f.writelines(rare_words)
 
 
-def add_context_list_to_manifest(subset: str, min_count: int):
+def add_context_list_to_manifest(
+    manifest_dir: str,
+    subset: str = "medium",
+    top_k: int = 10000,
+):
     """Generate a context list of rare words for each utterance in the manifest
 
     Args:
+        manifest_dir: Where to store the manifest with context list
         subset (str): Subset
-        min_count (int): The min appearances
+        top_k (int): The number of most frequent words kept as common words
     """
-    rare_words_file = f"data/context_biasing/{subset}_rare_words_{min_count}.txt"
-    manifest_dir = f"data/fbank/libriheavy_cuts_{subset}.jsonl.gz"
-
-    target_manifest_dir = manifest_dir.replace(
-        ".jsonl.gz", f"_with_context_list_min_count_{min_count}.jsonl.gz"
+    orig_manifest_dir = f"{manifest_dir}/libriheavy_cuts_{subset}.jsonl.gz"
+    target_manifest_dir = orig_manifest_dir.replace(
+        ".jsonl.gz", f"_with_context_list_topk_{top_k}.jsonl.gz"
     )
+
     if os.path.exists(target_manifest_dir):
         print(f"Target file exits at {target_manifest_dir}!")
         return
 
-    print(f"Reading rare words from {rare_words_file}")
+    rare_words_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt"
+    print(f"---Reading rare words from {rare_words_file}---")
     with open(rare_words_file, "r") as f:
         rare_words = f.read()
     rare_words = rare_words.split("\n")
     rare_words = set(rare_words)
     print(f"A total of {len(rare_words)} rare words!")
 
-    cuts = load_manifest_lazy(manifest_dir)
-    print(f"Loaded manifest from {manifest_dir}")
+    cuts = load_manifest_lazy(orig_manifest_dir)
+    print(f"Loaded manifest from {orig_manifest_dir}")
 
     def _add_context(c: Cut):
         splits = (
@@ -157,15 +194,21 @@ def add_context_list_to_manifest(subset: str, min_count: int):
         return c
 
     cuts = cuts.map(_add_context)
+    print(f"---Saving manifest with context list to {target_manifest_dir}---")
     cuts.to_file(target_manifest_dir)
-    print(f"Saved manifest with context list to {target_manifest_dir}")
+    print("Finished")
 
 
-def check(subset: str, min_count: int):
-    # Used to show how many samples in the training set have a context list
-    print("Calculating the stats over the manifest")
-    manifest_dir = f"data/fbank/libriheavy_cuts_{subset}_with_context_list_min_count_{min_count}.jsonl.gz"
+def check(
+    manifest_dir: str,
+    subset: str = "medium",
+    top_k: int = 10000,
+):
+    # Show how many samples in the training set have a context list
+    # and the average length of context list
+    print("--- Calculating the stats over the manifest ---")
+    manifest_dir = f"{manifest_dir}/libriheavy_cuts_{subset}_with_context_list_topk_{top_k}.jsonl.gz"
 
     cuts = load_manifest_lazy(manifest_dir)
     total_cuts = len(cuts)
     has_context_list = [c.supervisions[0].context_list != "" for c in cuts]
@@ -378,8 +421,19 @@ def write_error_stats(
 if __name__ == "__main__":
-    subset = "medium"
-    min_count = 10
-    get_rare_words(subset, min_count)
-    add_context_list_to_manifest(subset=subset, min_count=min_count)
-    check(subset=subset, min_count=min_count)
+    parser = get_parser()
+    args = parser.parse_args()
+
+    manifest_dir = args.manifest_dir
+    subset = args.subset
+    top_k = args.top_k
+
+    get_rare_words(subset=subset, top_k=top_k)
+    add_context_list_to_manifest(
+        manifest_dir=manifest_dir,
+        subset=subset,
+        top_k=top_k,
+    )
+    check(
+        manifest_dir=manifest_dir,
+        subset=subset,
+        top_k=top_k,
+    )
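
Note: for readers new to the recipe, the idea behind get_rare_words and _add_context above can be summarised in a few lines. The sketch below is a simplification with illustrative function names; word matching details (casing, punctuation handling via remove_non_alphabetic) are glossed over and it is not the recipe's exact code:

    from collections import Counter
    from typing import List, Set

    def split_rare_words(words: List[str], top_k: int = 10000) -> Set[str]:
        # Words outside the top_k most frequent are treated as rare.
        ranked = [w for w, _ in Counter(words).most_common()]
        return set(ranked[top_k:])

    def context_list_for_utterance(text: str, rare_words: Set[str]) -> str:
        # The per-utterance context list is the rare words occurring in the transcript.
        return " ".join(w for w in text.lower().split() if w in rare_words)

    corpus = "the cat sat on the mat near the xylophone".split()
    rare = split_rare_words(corpus, top_k=3)
    print(context_list_for_utterance("THE XYLOPHONE SAT ON THE MAT", rare))  # -> xylophone on mat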