Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-08 08:34:19 +00:00
add script for generating context list for each utterance

Commit bea1bd295f (parent 8401f26342)

New executable file: egs/libriheavy/ASR/prepare_prompt_asr.sh (36 lines)
egs/libriheavy/ASR/prepare_prompt_asr.sh (new executable file)

@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+# This is the preparation recipe for PromptASR: https://arxiv.org/pdf/2309.07414
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+stage=-1
+stop_stage=100
+manifest_dir=data/fbank
+subset=medium
+topk=10000
+
+. shared/parse_options.sh || exit 1
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download the meta biasing list for LibriSpeech"
+  mkdir -p data/context_biasing
+  cd data/context_biasing
+  git clone https://github.com/facebookresearch/fbai-speech.git
+  cd ../..
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Add rare-words for context biasing to the manifest"
+  python zipformer_prompt_asr/utils.py \
+    --manifest-dir $manifest_dir \
+    --subset $subset \
+    --top-k $topk
+
+fi
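Illustrative only, not part of the commit: the two stages above can also be driven from Python, for example from a notebook. The flag names mirror the variables defined in the script; running from egs/libriheavy/ASR and the usual icefall behaviour of shared/parse_options.sh are assumptions.

import subprocess

# Run Stage 0 (download the biasing list) through Stage 1 (add rare words to the manifest).
subprocess.run(
    [
        "./prepare_prompt_asr.sh",
        "--stage", "0",
        "--stop-stage", "1",
        "--subset", "medium",
        "--topk", "10000",
    ],
    check=True,  # raise CalledProcessError if a stage fails
)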
Changes to the LibriHeavy ASR data module (class LibriHeavyAsrDataModule):

@@ -23,7 +23,7 @@ from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional

 import torch
-from dataset2 import PromptASRDataset
+from dataset import PromptASRDataset
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
     CutConcatenate,
@@ -73,11 +73,15 @@ class LibriHeavyAsrDataModule:

         if args.use_context_list:
+            from dataset2 import PromptASRDataset
+
             assert args.rare_word_file is not None
-            with open(args.rare_word_file, 'r') as f:
-                self.rare_word_list = f.read().lower().split() # Use lower-cased for easier style transform
+            with open(args.rare_word_file, "r") as f:
+                self.rare_word_list = (
+                    f.read().lower().split()
+                )  # Use lower-cased for easier style transform
         else:
+            from dataset import PromptASRDataset
+
             self.rare_word_list = None

     @classmethod
@@ -211,9 +215,11 @@ class LibriHeavyAsrDataModule:
         )

         group.add_argument(
-            "--min-count",
+            "--topk-k",
             type=int,
-            default=7,
+            default=10000,
             help="""The top-k words are identified as common words,
             the rest as rare words""",
         )

         group.add_argument(
@@ -257,13 +263,9 @@ class LibriHeavyAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -285,9 +287,7 @@ class LibriHeavyAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -332,9 +332,7 @@ class LibriHeavyAsrDataModule:
             # Drop feats to be on the safe side.
             train = PromptASRDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
                 text_sampling_func=text_sampling_func,
@@ -379,7 +377,8 @@ class LibriHeavyAsrDataModule:

         return train_dl

-    def valid_dataloaders(self,
+    def valid_dataloaders(
+        self,
         cuts_valid: CutSet,
         text_sampling_func: Callable[[List[str]], str] = None,
     ) -> DataLoader:
@@ -401,9 +400,7 @@ class LibriHeavyAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = PromptASRDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
                 text_sampling_func=text_sampling_func,
                 rare_word_list=self.rare_word_list,
@@ -460,16 +457,16 @@ class LibriHeavyAsrDataModule:
         if self.args.use_context_list:
             path = (
                 self.args.manifest_dir
-                / f"libriheavy_cuts_{self.args.subset}_with_context_list_min_count_{self.args.min_count}.jsonl.gz"
+                / f"libriheavy_cuts_{self.args.subset}_with_context_list_topk_{self.args.top_k}.jsonl.gz"
             )
         elif self.args.with_decoding:
             path = (
-                self.args.manifest_dir / f"libriheavy_cuts_{self.args.subset}_with_decoding.jsonl.gz"
+                self.args.manifest_dir
+                / f"libriheavy_cuts_{self.args.subset}_with_decoding.jsonl.gz"
             )
         else:
             path = (
-                self.args.manifest_dir
-                / f"libriheavy_cuts_{self.args.subset}.jsonl.gz"
+                self.args.manifest_dir / f"libriheavy_cuts_{self.args.subset}.jsonl.gz"
             )

         logging.info(f"Loading manifest from {path}.")
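The manifest name built here has to match what utils.py writes further down in this commit. A quick illustrative check (the subset and top-k values are just the defaults assumed throughout, not fixed by the code above):

# Illustrative naming check: the data module looks for the file that utils.py produces.
subset, top_k = "medium", 10000
name = f"libriheavy_cuts_{subset}_with_context_list_topk_{top_k}.jsonl.gz"
print(name)  # libriheavy_cuts_medium_with_context_list_topk_10000.jsonl.gz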
@@ -514,20 +511,6 @@ class LibriHeavyAsrDataModule:
             self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
         )

-    @lru_cache()
-    def npr1_dev_cuts(self) -> CutSet:
-        logging.info("About to get npr1 dev cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "npr1_cuts_dev.jsonl.gz"
-        )
-
-    @lru_cache()
-    def npr1_test_cuts(self) -> CutSet:
-        logging.info("About to get npr1 test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "npr1_cuts_test.jsonl.gz"
-        )
-
     @lru_cache()
     def long_audio_cuts(self) -> CutSet:
         logging.info("About to get long audio cuts")
Changes to zipformer_prompt_asr/utils.py (the script invoked in Stage 1 above):

@@ -1,3 +1,4 @@
+import argparse
 import ast
 import glob
 import logging
@@ -12,9 +13,34 @@ from text_normalization import remove_non_alphabetic
 from tqdm import tqdm


+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--manifest-dir",
+        type=str,
+        default="data/fbank",
+        help="Where are the manifest stored",
+    )
+
+    parser.add_argument(
+        "--subset", type=str, default="medium", help="Which subset to work with"
+    )
+
+    parser.add_argument(
+        "--top-k",
+        type=int,
+        default=10000,
+        help="How many words to keep",
+    )
+
+    return parser
+
+
 def get_facebook_biasing_list(
     test_set: str,
     use_distractors: bool = False,
     num_distractors: int = 100,
 ) -> Dict:
     # Get the biasing list from the meta paper: https://arxiv.org/pdf/2104.02194.pdf
@@ -28,9 +54,9 @@ def get_facebook_biasing_list(
             raise ValueError(f"Unseen test set {test_set}")
     else:
         if test_set == "test-clean":
-            biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv"
+            biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv"
         elif test_set == "test-other":
-            biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv"
+            biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv"
         else:
             raise ValueError(f"Unseen test set {test_set}")

@@ -41,7 +67,7 @@ def get_facebook_biasing_list(
     output = dict()
     for line in data:
         id, _, l1, l2 = line.split("\t")
-        if use_distractors:
+        if num_distractors > 0:  # use distractors
             biasing_list = ast.literal_eval(l2)
         else:
             biasing_list = ast.literal_eval(l1)
@@ -66,20 +92,25 @@ def brian_biasing_list(level: str):
     return biasing_dict


-def get_rare_words(subset: str, min_count: int):
+def get_rare_words(
+    subset: str = "medium",
+    top_k: int = 10000,
+    # min_count: int = 10000,
+):
     """Get a list of rare words appearing less than `min_count` times

     Args:
         subset: The dataset
-        min_count (int): Count of appearance
+        top_k (int): How many frequent words
     """
     txt_path = f"data/tmp/transcript_words_{subset}.txt"
-    rare_word_file = f"data/context_biasing/{subset}_rare_words_{min_count}.txt"
+    rare_word_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt"

     if os.path.exists(rare_word_file):
         print("File exists, do not proceed!")
         return
-    print("Finding rare words in the manifest.")
+
+    print("---Identifying rare words in the manifest---")
     count_file = f"data/tmp/transcript_words_{subset}_count.txt"
     if not os.path.exists(count_file):
         with open(txt_path, "r") as file:
@@ -94,9 +125,12 @@ def get_rare_words(subset: str, min_count: int):
             else:
                 word_count[w] += 1

+        word_count = list(word_count.items())  # convert to a list of tuple
+        word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True)
         with open(count_file, "w") as fout:
-            for w in word_count:
-                fout.write(f"{w}\t{word_count[w]}\n")
+            for w, count in word_count:
+                fout.write(f"{w}\t{count}\n")

     else:
         word_count = {}
         with open(count_file, "r") as fin:
@@ -106,42 +140,45 @@

     print(f"A total of {len(word_count)} words appeared!")
     rare_words = []
-    for k in word_count:
-        if int(word_count[k]) <= min_count:
-            rare_words.append(k + "\n")
-    print(f"A total of {len(rare_words)} appeared <= {min_count} times")
+    for word, count in word_count[top_k:]:
+        rare_words.append(word + "\n")
+    print(f"A total of {len(rare_words)} are identified as rare words.")

     with open(rare_word_file, "w") as f:
         f.writelines(rare_words)


-def add_context_list_to_manifest(subset: str, min_count: int):
+def add_context_list_to_manifest(
+    manifest_dir: str,
+    subset: str = "medium",
+    top_k: int = 10000,
+):
     """Generate a context list of rare words for each utterance in the manifest

     Args:
+        manifest_dir: Where to store the manifest with context list
         subset (str): Subset
-        min_count (int): The min appearances
+        top_k (int): How many frequent words

     """
-    rare_words_file = f"data/context_biasing/{subset}_rare_words_{min_count}.txt"
-    manifest_dir = f"data/fbank/libriheavy_cuts_{subset}.jsonl.gz"
-
-    target_manifest_dir = manifest_dir.replace(
-        ".jsonl.gz", f"_with_context_list_min_count_{min_count}.jsonl.gz"
+    orig_manifest_dir = f"{manifest_dir}/libriheavy_cuts_{subset}.jsonl.gz"
+    target_manifest_dir = orig_manifest_dir.replace(
+        ".jsonl.gz", f"_with_context_list_topk_{top_k}.jsonl.gz"
     )
     if os.path.exists(target_manifest_dir):
         print(f"Target file exits at {target_manifest_dir}!")
         return

-    print(f"Reading rare words from {rare_words_file}")
+    rare_words_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt"
+    print(f"---Reading rare words from {rare_words_file}---")
     with open(rare_words_file, "r") as f:
         rare_words = f.read()
     rare_words = rare_words.split("\n")
     rare_words = set(rare_words)
     print(f"A total of {len(rare_words)} rare words!")

-    cuts = load_manifest_lazy(manifest_dir)
-    print(f"Loaded manifest from {manifest_dir}")
+    cuts = load_manifest_lazy(orig_manifest_dir)
+    print(f"Loaded manifest from {orig_manifest_dir}")

     def _add_context(c: Cut):
         splits = (
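In short, get_rare_words now ranks words by frequency and treats everything beyond the top_k most frequent words as rare, instead of thresholding on min_count. A condensed, self-contained sketch of that selection (illustrative only; the toy transcript and top_k value are assumptions, not the recipe's data):

from collections import Counter

transcript_lines = ["the cat sat", "the cat ran", "zymurgy is rare"]  # toy data
top_k = 2

# Count word occurrences and sort by frequency, most frequent first,
# mirroring the sorted(..., reverse=True) call in the hunks above.
counts = Counter(w for line in transcript_lines for w in line.split())
ranked = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)

common = [w for w, _ in ranked[:top_k]]  # the top-k words are kept as "common"
rare = [w for w, _ in ranked[top_k:]]    # everything else is treated as "rare"
print(common)  # e.g. ['the', 'cat']
print(rare)    # e.g. ['sat', 'ran', 'zymurgy', 'is', 'rare']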
@@ -157,15 +194,21 @@ def add_context_list_to_manifest(subset: str, min_count: int):
         return c

     cuts = cuts.map(_add_context)
+
+    print(f"---Saving manifest with context list to {target_manifest_dir}---")
     cuts.to_file(target_manifest_dir)
-    print(f"Saved manifest with context list to {target_manifest_dir}")
+    print("Finished")


-def check(subset: str, min_count: int):
-    # Used to show how many samples in the training set have a context list
-    print("Calculating the stats over the manifest")
-    manifest_dir = f"data/fbank/libriheavy_cuts_{subset}_with_context_list_min_count_{min_count}.jsonl.gz"
+def check(
+    manifest_dir: str,
+    subset: str = "medium",
+    top_k: int = 10000,
+):
+    # Show how many samples in the training set have a context list
+    # and the average length of context list
+    print("--- Calculating the stats over the manifest ---")
+
+    manifest_dir = f"{manifest_dir}/libriheavy_cuts_{subset}_with_context_list_topk_{top_k}.jsonl.gz"
     cuts = load_manifest_lazy(manifest_dir)
     total_cuts = len(cuts)
     has_context_list = [c.supervisions[0].context_list != "" for c in cuts]
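The body of _add_context is not fully shown in this diff; conceptually it attaches to each cut's supervision the rare words that occur in that utterance's text, which is the context_list field inspected by check above. A rough illustrative sketch of that idea (the helper name, the toy data, and the text handling are assumptions, not the commit's implementation):

# Illustrative sketch only: what a per-utterance context list could look like.
rare_words = {"zymurgy", "quixotic"}  # assumed to come from the rare-word file

def make_context_list(transcript: str) -> str:
    # Keep only the rare words that actually occur in this utterance.
    words = transcript.lower().split()
    return " ".join(w for w in words if w in rare_words)

print(make_context_list("a quixotic plan involving zymurgy"))  # -> "quixotic zymurgy"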
@@ -378,8 +421,19 @@ def write_error_stats(


 if __name__ == "__main__":
-    subset = "medium"
-    min_count = 10
-    get_rare_words(subset, min_count)
-    add_context_list_to_manifest(subset=subset, min_count=min_count)
-    check(subset=subset, min_count=min_count)
+    parser = get_parser()
+    args = parser.parse_args()
+    manifest_dir = args.manifest_dir
+    subset = args.subset
+    top_k = args.top_k
+    get_rare_words(subset=subset, top_k=top_k)
+    add_context_list_to_manifest(
+        manifest_dir=manifest_dir,
+        subset=subset,
+        top_k=top_k,
+    )
+    check(
+        manifest_dir=manifest_dir,
+        subset=subset,
+        top_k=top_k,
+    )
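After running the Stage 1 command from the shell script (python zipformer_prompt_asr/utils.py --manifest-dir data/fbank --subset medium --top-k 10000), the augmented manifest can be inspected roughly as follows. The path simply combines the defaults used above, and the snippet is illustrative rather than part of the commit:

from lhotse import load_manifest_lazy

# Peek at the context lists written by add_context_list_to_manifest.
path = "data/fbank/libriheavy_cuts_medium_with_context_list_topk_10000.jsonl.gz"
cuts = load_manifest_lazy(path)
for i, cut in enumerate(cuts):
    print(cut.id, "->", cut.supervisions[0].context_list)
    if i >= 4:  # only show a handful of utterances
        break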