add script for generating context list for each utterance

marcoyang1998 2023-09-19 17:44:52 +08:00
parent 8401f26342
commit bea1bd295f
3 changed files with 157 additions and 84 deletions

View File

@ -0,0 +1,36 @@
#!/usr/bin/env bash
set -eou pipefail
# This is the preparation recipe for PromptASR: https://arxiv.org/pdf/2309.07414
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
stage=-1
stop_stage=100
manifest_dir=data/fbank
subset=medium
topk=10000
. shared/parse_options.sh || exit 1
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download the meta biasing list for LibriSpeech"
mkdir -p data/context_biasing
cd data/context_biasing
git clone https://github.com/facebookresearch/fbai-speech.git
cd ../..
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Add rare-words for context biasing to the manifest"
python zipformer_prompt_asr/utils.py \
--manifest-dir $manifest_dir \
--subset $subset \
--top-k $topk
fi
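
For reference, a typical run of this recipe might look as follows; the script's filename is hypothetical, and shared/parse_options.sh maps the dashed flags onto the variables defined above:

```bash
# Hypothetical filename for this script; adjust to its actual location in the recipe.
./prepare_prompt_asr.sh --stage 0 --stop-stage 1 \
  --manifest-dir data/fbank \
  --subset medium \
  --topk 10000
```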

View File

@ -23,7 +23,7 @@ from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
import torch
from dataset2 import PromptASRDataset
from dataset import PromptASRDataset
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
@ -70,14 +70,18 @@ class LibriHeavyAsrDataModule:
def __init__(self, args: argparse.Namespace):
self.args = args
if args.use_context_list:
from dataset2 import PromptASRDataset
assert args.rare_word_file is not None
with open(args.rare_word_file, 'r') as f:
self.rare_word_list = f.read().lower().split() # Use lower-cased for easier style transform
with open(args.rare_word_file, "r") as f:
self.rare_word_list = (
f.read().lower().split()
) # Use lower-cased text for easier style transform
else:
from dataset import PromptASRDataset
self.rare_word_list = None
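
As a small aside on the assumed file format: the rare-word file is read in one pass, lower-cased, and split on whitespace, so one word per line (or any whitespace separation) works:

```python
# Illustration of how the file contents become rare_word_list; the file itself
# is produced by zipformer_prompt_asr/utils.py (see below).
text = "Abernathy\nZanzibar\nQuixotic\n"
rare_word_list = text.lower().split()
assert rare_word_list == ["abernathy", "zanzibar", "quixotic"]
```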
@classmethod
@ -202,20 +206,22 @@ class LibriHeavyAsrDataModule:
default="small",
help="Select the Libriheavy subset (small|medium|large)",
)
group.add_argument(
"--use-context-list",
type=str2bool,
default=False,
help="Use the context list of libri heavy",
)
group.add_argument(
"--min-count",
"--topk-k",
type=int,
default=7,
default=10000,
help="""The top-k words are identified as common words,
the rest as rare words""",
)
group.add_argument(
"--with-decoding",
type=str2bool,
@ -227,12 +233,12 @@ class LibriHeavyAsrDataModule:
"--random-left-padding",
type=str2bool,
)
group.add_argument(
"--rare-word-file",
type=str,
)
group.add_argument(
"--long-audio-cuts",
type=str,
@ -257,13 +263,9 @@ class LibriHeavyAsrDataModule:
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(
self.args.manifest_dir / "musan_cuts.jsonl.gz"
)
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
)
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@ -285,9 +287,7 @@ class LibriHeavyAsrDataModule:
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
)
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
@ -332,9 +332,7 @@ class LibriHeavyAsrDataModule:
# Drop feats to be on the safe side.
train = PromptASRDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
text_sampling_func=text_sampling_func,
@ -379,9 +377,10 @@ class LibriHeavyAsrDataModule:
return train_dl
def valid_dataloaders(self,
def valid_dataloaders(
self,
cuts_valid: CutSet,
text_sampling_func: Callable[[List[str]], str] = None,
text_sampling_func: Callable[[List[str]], str] = None,
) -> DataLoader:
transforms = []
if self.args.random_left_padding:
@ -401,9 +400,7 @@ class LibriHeavyAsrDataModule:
if self.args.on_the_fly_feats:
validate = PromptASRDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
text_sampling_func=text_sampling_func,
rare_word_list=self.rare_word_list,
@ -460,16 +457,16 @@ class LibriHeavyAsrDataModule:
if self.args.use_context_list:
path = (
self.args.manifest_dir
/ f"libriheavy_cuts_{self.args.subset}_with_context_list_min_count_{self.args.min_count}.jsonl.gz"
/ f"libriheavy_cuts_{self.args.subset}_with_context_list_topk_{self.args.top_k}.jsonl.gz"
)
elif self.args.with_decoding:
path = (
self.args.manifest_dir / f"libriheavy_cuts_{self.args.subset}_with_decoding.jsonl.gz"
self.args.manifest_dir
/ f"libriheavy_cuts_{self.args.subset}_with_decoding.jsonl.gz"
)
else:
path = (
self.args.manifest_dir
/ f"libriheavy_cuts_{self.args.subset}.jsonl.gz"
self.args.manifest_dir / f"libriheavy_cuts_{self.args.subset}.jsonl.gz"
)
logging.info(f"Loading manifest from {path}.")
@ -513,21 +510,7 @@ class LibriHeavyAsrDataModule:
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
)
@lru_cache()
def npr1_dev_cuts(self) -> CutSet:
logging.info("About to get npr1 dev cuts")
return load_manifest_lazy(
self.args.manifest_dir / "npr1_cuts_dev.jsonl.gz"
)
@lru_cache()
def npr1_test_cuts(self) -> CutSet:
logging.info("About to get npr1 test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "npr1_cuts_test.jsonl.gz"
)
@lru_cache()
def long_audio_cuts(self) -> CutSet:
logging.info("About to get long audio cuts")
@ -535,11 +518,11 @@ class LibriHeavyAsrDataModule:
self.args.long_audio_cuts,
)
return cuts
@lru_cache()
def test_dev_cuts(self) -> CutSet:
logging.info("About to get test dev cuts")
cuts = load_manifest_lazy(
self.args.manifest_dir / "libriheavy_cuts_test_dev.jsonl.gz"
)
return cuts
return cuts

View File

@ -1,3 +1,4 @@
import argparse
import ast
import glob
import logging
@ -12,9 +13,34 @@ from text_normalization import remove_non_alphabetic
from tqdm import tqdm
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--manifest-dir",
type=str,
default="data/fbank",
help="Where are the manifest stored",
)
parser.add_argument(
"--subset", type=str, default="medium", help="Which subset to work with"
)
parser.add_argument(
"--top-k",
type=int,
default=10000,
help="How many words to keep",
)
return parser
def get_facebook_biasing_list(
test_set: str,
use_distractors: bool = False,
num_distractors: int = 100,
) -> Dict:
# Get the biasing list from the meta paper: https://arxiv.org/pdf/2104.02194.pdf
@ -28,9 +54,9 @@ def get_facebook_biasing_list(
raise ValueError(f"Unseen test set {test_set}")
else:
if test_set == "test-clean":
biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv"
biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv"
elif test_set == "test-other":
biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv"
biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv"
else:
raise ValueError(f"Unseen test set {test_set}")
@ -41,7 +67,7 @@ def get_facebook_biasing_list(
output = dict()
for line in data:
id, _, l1, l2 = line.split("\t")
if use_distractors:
if num_distractors > 0: # use distractors
biasing_list = ast.literal_eval(l2)
else:
biasing_list = ast.literal_eval(l1)
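
To illustrate the parsing above, a sketch of one TSV line; the exact column contents are an assumption, only the utterance id and the two python-literal list columns matter here:

```python
import ast

# Hypothetical line in the assumed layout: id, reference text,
# rare-word list, rare-word list plus distractors.
num_distractors = 100
line = "1089-134686-0000\t<ref text>\t['BRAHAM']\t['BRAHAM', 'APPLE', 'CIDER']\n"
utt_id, _, l1, l2 = line.strip().split("\t")
biasing_list = ast.literal_eval(l2 if num_distractors > 0 else l1)
# -> ['BRAHAM', 'APPLE', 'CIDER']
```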
@ -66,20 +92,25 @@ def brian_biasing_list(level: str):
return biasing_dict
def get_rare_words(subset: str, min_count: int):
def get_rare_words(
subset: str = "medium",
top_k: int = 10000,
):
"""Get a list of rare words appearing less than `min_count` times
Args:
subset: The dataset
min_count (int): Count of appearance
top_k (int): Number of most frequent words treated as common
"""
txt_path = f"data/tmp/transcript_words_{subset}.txt"
rare_word_file = f"data/context_biasing/{subset}_rare_words_{min_count}.txt"
rare_word_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt"
if os.path.exists(rare_word_file):
print("File exists, do not proceed!")
return
print("Finding rare words in the manifest.")
print("---Identifying rare words in the manifest---")
count_file = f"data/tmp/transcript_words_{subset}_count.txt"
if not os.path.exists(count_file):
with open(txt_path, "r") as file:
@ -94,9 +125,12 @@ def get_rare_words(subset: str, min_count: int):
else:
word_count[w] += 1
word_count = list(word_count.items()) # convert to a list of tuple
word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True)
with open(count_file, "w") as fout:
for w in word_count:
fout.write(f"{w}\t{word_count[w]}\n")
for w, count in word_count:
fout.write(f"{w}\t{count}\n")
else:
word_count = {}
with open(count_file, "r") as fin:
@ -106,42 +140,45 @@ def get_rare_words(subset: str, min_count: int):
print(f"A total of {len(word_count)} words appeared!")
rare_words = []
for k in word_count:
if int(word_count[k]) <= min_count:
rare_words.append(k + "\n")
print(f"A total of {len(rare_words)} appeared <= {min_count} times")
for word, count in word_count[top_k:]:
rare_words.append(word + "\n")
print(f"A total of {len(rare_words)} are identified as rare words.")
with open(rare_word_file, "w") as f:
f.writelines(rare_words)
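
A toy illustration of the selection rule used above: after sorting by descending count, everything past the first top_k entries is written out as a rare word.

```python
# Made-up counts; with top_k = 2 only the tail of the sorted list is rare.
word_count = [("the", 120), ("of", 95), ("cat", 7), ("abernathy", 1)]
top_k = 2
rare_words = [w + "\n" for w, _ in word_count[top_k:]]
# -> ["cat\n", "abernathy\n"]
```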
def add_context_list_to_manifest(subset: str, min_count: int):
def add_context_list_to_manifest(
manifest_dir: str,
subset: str = "medium",
top_k: int = 10000,
):
"""Generate a context list of rare words for each utterance in the manifest
Args:
manifest_dir: Where the manifests are stored
subset (str): Subset
min_count (int): The min appearances
top_k (int): Number of most frequent words treated as common
"""
rare_words_file = f"data/context_biasing/{subset}_rare_words_{min_count}.txt"
manifest_dir = f"data/fbank/libriheavy_cuts_{subset}.jsonl.gz"
target_manifest_dir = manifest_dir.replace(
".jsonl.gz", f"_with_context_list_min_count_{min_count}.jsonl.gz"
orig_manifest_dir = f"{manifest_dir}/libriheavy_cuts_{subset}.jsonl.gz"
target_manifest_dir = orig_manifest_dir.replace(
".jsonl.gz", f"_with_context_list_topk_{top_k}.jsonl.gz"
)
if os.path.exists(target_manifest_dir):
print(f"Target file exits at {target_manifest_dir}!")
return
print(f"Reading rare words from {rare_words_file}")
rare_words_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt"
print(f"---Reading rare words from {rare_words_file}---")
with open(rare_words_file, "r") as f:
rare_words = f.read()
rare_words = rare_words.split("\n")
rare_words = set(rare_words)
print(f"A total of {len(rare_words)} rare words!")
cuts = load_manifest_lazy(manifest_dir)
print(f"Loaded manifest from {manifest_dir}")
cuts = load_manifest_lazy(orig_manifest_dir)
print(f"Loaded manifest from {orig_manifest_dir}")
def _add_context(c: Cut):
splits = (
@ -157,15 +194,21 @@ def add_context_list_to_manifest(subset: str, min_count: int):
return c
cuts = cuts.map(_add_context)
print(f"---Saving manifest with context list to {target_manifest_dir}---")
cuts.to_file(target_manifest_dir)
print(f"Saved manifest with context list to {target_manifest_dir}")
print("Finished")
def check(subset: str, min_count: int):
# Used to show how many samples in the training set have a context list
print("Calculating the stats over the manifest")
manifest_dir = f"data/fbank/libriheavy_cuts_{subset}_with_context_list_min_count_{min_count}.jsonl.gz"
def check(
manifest_dir: str,
subset: str = "medium",
top_k: int = 10000,
):
# Show how many samples in the training set have a context list
# and the average length of context list
print("--- Calculating the stats over the manifest ---")
manifest_dir = f"{manifest_dir}/libriheavy_cuts_{subset}_with_context_list_topk_{top_k}.jsonl.gz"
cuts = load_manifest_lazy(manifest_dir)
total_cuts = len(cuts)
has_context_list = [c.supervisions[0].context_list != "" for c in cuts]
@ -378,8 +421,19 @@ def write_error_stats(
if __name__ == "__main__":
subset = "medium"
min_count = 10
get_rare_words(subset, min_count)
add_context_list_to_manifest(subset=subset, min_count=min_count)
check(subset=subset, min_count=min_count)
parser = get_parser()
args = parser.parse_args()
manifest_dir = args.manifest_dir
subset = args.subset
top_k = args.top_k
get_rare_words(subset=subset, top_k=top_k)
add_context_list_to_manifest(
manifest_dir=manifest_dir,
subset=subset,
top_k=top_k,
)
check(
manifest_dir=manifest_dir,
subset=subset,
top_k=top_k,
)
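
For completeness, a sketch of how the artifacts produced here plug back into LibriHeavyAsrDataModule; the field names mirror the data-module arguments above, but the real training script builds a much fuller namespace:

```python
from argparse import Namespace
from pathlib import Path

# Sketch only: the values must match those used during preparation.
args = Namespace(
    manifest_dir=Path("data/fbank"),
    subset="medium",
    use_context_list=True,
    top_k=10000,
    rare_word_file="data/context_biasing/medium_rare_words_topk_10000.txt",
)
# With these settings the data module reads
# data/fbank/libriheavy_cuts_medium_with_context_list_topk_10000.jsonl.gz
# and loads the rare-word list that is passed to PromptASRDataset.
```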