From bea1bd295fbd6baddb071747d3436dc5ce56875b Mon Sep 17 00:00:00 2001
From: marcoyang1998
Date: Tue, 19 Sep 2023 17:44:52 +0800
Subject: [PATCH] add script for generating context list for each utterance

---
 egs/libriheavy/ASR/prepare_prompt_asr.sh       |  36 +++++
 .../zipformer_prompt_asr/asr_datamodule.py     |  81 +++++-------
 .../ASR/zipformer_prompt_asr/utils.py          | 124 +++++++++++++-----
 3 files changed, 157 insertions(+), 84 deletions(-)
 create mode 100755 egs/libriheavy/ASR/prepare_prompt_asr.sh

diff --git a/egs/libriheavy/ASR/prepare_prompt_asr.sh b/egs/libriheavy/ASR/prepare_prompt_asr.sh
new file mode 100755
index 000000000..b931cea26
--- /dev/null
+++ b/egs/libriheavy/ASR/prepare_prompt_asr.sh
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+# This is the preparation recipe for PromptASR: https://arxiv.org/pdf/2309.07414
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+stage=-1
+stop_stage=100
+manifest_dir=data/fbank
+subset=medium
+topk=10000
+
+. shared/parse_options.sh || exit 1
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download the meta biasing list for LibriSpeech"
+  mkdir -p data/context_biasing
+  cd data/context_biasing
+  git clone https://github.com/facebookresearch/fbai-speech.git
+  cd ../..
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Add rare words for context biasing to the manifest"
+  python zipformer_prompt_asr/utils.py \
+    --manifest-dir $manifest_dir \
+    --subset $subset \
+    --top-k $topk
+
+fi
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py
index 80faba038..4b4c8a785 100644
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py
@@ -23,7 +23,7 @@ from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
-from dataset2 import PromptASRDataset
+from dataset import PromptASRDataset
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
     CutConcatenate,
@@ -70,14 +70,18 @@ class LibriHeavyAsrDataModule:
 
     def __init__(self, args: argparse.Namespace):
         self.args = args
-
+
         if args.use_context_list:
             from dataset2 import PromptASRDataset
+
             assert args.rare_word_file is not None
-            with open(args.rare_word_file, 'r') as f:
-                self.rare_word_list = f.read().lower().split() # Use lower-cased for easier style transform
+            with open(args.rare_word_file, "r") as f:
+                self.rare_word_list = (
+                    f.read().lower().split()
+                )  # Use lower-cased for easier style transform
         else:
             from dataset import PromptASRDataset
+
             self.rare_word_list = None
 
     @classmethod
@@ -202,20 +206,22 @@ class LibriHeavyAsrDataModule:
             default="small",
             help="Select the Libriheavy subset (small|medium|large)",
         )
-
+
         group.add_argument(
             "--use-context-list",
             type=str2bool,
             default=False,
             help="Use the context list of libri heavy",
         )
-
+
         group.add_argument(
-            "--min-count",
+            "--top-k",
             type=int,
-            default=7,
+            default=10000,
+            help="""The top-k words are identified as common words,
+            the rest as rare words""",
         )
-
+
         group.add_argument(
             "--with-decoding",
             type=str2bool,
             default=False,
@@ -227,12 +233,12 @@ class LibriHeavyAsrDataModule:
             "--random-left-padding",
             type=str2bool,
         )
-
+
         group.add_argument(
             "--rare-word-file",
             type=str,
         )
-
+
         group.add_argument(
             "--long-audio-cuts",
             type=str,
@@ -257,13 +263,9 @@ class LibriHeavyAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -285,9 +287,7 @@ class LibriHeavyAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -332,9 +332,7 @@ class LibriHeavyAsrDataModule:
             # Drop feats to be on the safe side.
             train = PromptASRDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
                 text_sampling_func=text_sampling_func,
@@ -379,9 +377,10 @@ class LibriHeavyAsrDataModule:
 
         return train_dl
 
-    def valid_dataloaders(self,
+    def valid_dataloaders(
+        self,
         cuts_valid: CutSet,
-        text_sampling_func: Callable[[List[str]], str] = None,
+        text_sampling_func: Callable[[List[str]], str] = None,
     ) -> DataLoader:
         transforms = []
         if self.args.random_left_padding:
@@ -401,9 +400,7 @@ class LibriHeavyAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = PromptASRDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
                 text_sampling_func=text_sampling_func,
                 rare_word_list=self.rare_word_list,
@@ -460,16 +457,16 @@ class LibriHeavyAsrDataModule:
         if self.args.use_context_list:
             path = (
                 self.args.manifest_dir
-                / f"libriheavy_cuts_{self.args.subset}_with_context_list_min_count_{self.args.min_count}.jsonl.gz"
+                / f"libriheavy_cuts_{self.args.subset}_with_context_list_topk_{self.args.top_k}.jsonl.gz"
             )
         elif self.args.with_decoding:
             path = (
-                self.args.manifest_dir / f"libriheavy_cuts_{self.args.subset}_with_decoding.jsonl.gz"
+                self.args.manifest_dir
+                / f"libriheavy_cuts_{self.args.subset}_with_decoding.jsonl.gz"
             )
         else:
             path = (
-                self.args.manifest_dir
-                / f"libriheavy_cuts_{self.args.subset}.jsonl.gz"
+                self.args.manifest_dir / f"libriheavy_cuts_{self.args.subset}.jsonl.gz"
             )
 
         logging.info(f"Loading manifest from {path}.")
@@ -513,21 +510,7 @@ class LibriHeavyAsrDataModule:
         return load_manifest_lazy(
             self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
         )
-
-    @lru_cache()
-    def npr1_dev_cuts(self) -> CutSet:
-        logging.info("About to get npr1 dev cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "npr1_cuts_dev.jsonl.gz"
-        )
-
-    @lru_cache()
-    def npr1_test_cuts(self) -> CutSet:
-        logging.info("About to get npr1 test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "npr1_cuts_test.jsonl.gz"
-        )
-
+
     @lru_cache()
     def long_audio_cuts(self) -> CutSet:
         logging.info("About to get long audio cuts")
@@ -535,11 +518,11 @@ class LibriHeavyAsrDataModule:
             self.args.long_audio_cuts,
         )
         return cuts
-
+
     @lru_cache()
     def test_dev_cuts(self) -> CutSet:
         logging.info("About to get test dev cuts")
         cuts = load_manifest_lazy(
             self.args.manifest_dir / "libriheavy_cuts_test_dev.jsonl.gz"
         )
-        return cuts
\ No newline at end of file
+        return cuts
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/utils.py b/egs/libriheavy/ASR/zipformer_prompt_asr/utils.py
index 922a1d3c0..533982519 100644
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/utils.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/utils.py
@@ -1,3 +1,4 @@
+import argparse
 import ast
 import glob
 import logging
@@ -12,9 +13,34 @@ from text_normalization import remove_non_alphabetic
 from tqdm import tqdm
 
 
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--manifest-dir",
+        type=str,
+        default="data/fbank",
+        help="Where the manifests are stored",
+    )
+
+    parser.add_argument(
+        "--subset", type=str, default="medium", help="Which subset to work with"
+    )
+
+    parser.add_argument(
+        "--top-k",
+        type=int,
+        default=10000,
+        help="How many of the most frequent words to treat as common words",
+    )
+
+    return parser
+
+
 def get_facebook_biasing_list(
     test_set: str,
-    use_distractors: bool = False,
     num_distractors: int = 100,
 ) -> Dict:
     # Get the biasing list from the meta paper: https://arxiv.org/pdf/2104.02194.pdf
@@ -28,9 +54,9 @@ def get_facebook_biasing_list(
             raise ValueError(f"Unseen test set {test_set}")
     else:
         if test_set == "test-clean":
-            biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv"
+            biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv"
         elif test_set == "test-other":
-            biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv"
+            biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv"
         else:
             raise ValueError(f"Unseen test set {test_set}")
 
@@ -41,7 +67,7 @@ def get_facebook_biasing_list(
     output = dict()
     for line in data:
         id, _, l1, l2 = line.split("\t")
-        if use_distractors:
+        if num_distractors > 0:  # use distractors
             biasing_list = ast.literal_eval(l2)
         else:
             biasing_list = ast.literal_eval(l1)
@@ -66,20 +92,25 @@ def brian_biasing_list(level: str):
     return biasing_dict
 
 
-def get_rare_words(subset: str, min_count: int):
+def get_rare_words(
+    subset: str = "medium",
+    top_k: int = 10000,
+    # min_count: int = 10000,
+):
     """Get a list of rare words appearing less than `min_count` times
 
     Args:
         subset: The dataset
-        min_count (int): Count of appearance
+        top_k (int): The number of most frequent words treated as common words
 
     """
     txt_path = f"data/tmp/transcript_words_{subset}.txt"
-    rare_word_file = f"data/context_biasing/{subset}_rare_words_{min_count}.txt"
+    rare_word_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt"
 
     if os.path.exists(rare_word_file):
         print("File exists, do not proceed!")
         return
 
-    print("Finding rare words in the manifest.")
+
+    print("---Identifying rare words in the manifest---")
     count_file = f"data/tmp/transcript_words_{subset}_count.txt"
     if not os.path.exists(count_file):
         with open(txt_path, "r") as file:
@@ -94,9 +125,12 @@ def get_rare_words(subset: str, min_count: int):
                     else:
                         word_count[w] += 1
 
+        word_count = list(word_count.items())  # convert to a list of tuple
+        word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True)
         with open(count_file, "w") as fout:
-            for w in word_count:
-                fout.write(f"{w}\t{word_count[w]}\n")
+            for w, count in word_count:
+                fout.write(f"{w}\t{count}\n")
+
     else:
         word_count = {}
         with open(count_file, "r") as fin:
@@ -106,42 +140,45 @@ def get_rare_words(subset: str, min_count: int):
     print(f"A total of {len(word_count)} words appeared!")
 
     rare_words = []
-    for k in word_count:
-        if int(word_count[k]) <= min_count:
-            rare_words.append(k + "\n")
-    print(f"A total of {len(rare_words)} appeared <= {min_count} times")
+    for word, count in word_count[top_k:]:
+        rare_words.append(word + "\n")
+    print(f"A total of {len(rare_words)} words are identified as rare words.")
 
     with open(rare_word_file, "w") as f:
         f.writelines(rare_words)
 
 
-def add_context_list_to_manifest(subset: str, min_count: int):
+def add_context_list_to_manifest(
+    manifest_dir: str,
+    subset: str = "medium",
+    top_k: int = 10000,
+):
     """Generate a context list of rare words for each utterance in the manifest
 
     Args:
+        manifest_dir: Where to store the manifest with context list
        subset (str): Subset
-        min_count (int): The min appearances
+        top_k (int): The number of most frequent words treated as common words
 
     """
-    rare_words_file = f"data/context_biasing/{subset}_rare_words_{min_count}.txt"
-    manifest_dir = f"data/fbank/libriheavy_cuts_{subset}.jsonl.gz"
-
-    target_manifest_dir = manifest_dir.replace(
-        ".jsonl.gz", f"_with_context_list_min_count_{min_count}.jsonl.gz"
+    orig_manifest_dir = f"{manifest_dir}/libriheavy_cuts_{subset}.jsonl.gz"
+    target_manifest_dir = orig_manifest_dir.replace(
+        ".jsonl.gz", f"_with_context_list_topk_{top_k}.jsonl.gz"
     )
     if os.path.exists(target_manifest_dir):
         print(f"Target file exits at {target_manifest_dir}!")
         return
 
-    print(f"Reading rare words from {rare_words_file}")
+    rare_words_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt"
+    print(f"---Reading rare words from {rare_words_file}---")
     with open(rare_words_file, "r") as f:
         rare_words = f.read()
     rare_words = rare_words.split("\n")
     rare_words = set(rare_words)
     print(f"A total of {len(rare_words)} rare words!")
 
-    cuts = load_manifest_lazy(manifest_dir)
-    print(f"Loaded manifest from {manifest_dir}")
+    cuts = load_manifest_lazy(orig_manifest_dir)
+    print(f"Loaded manifest from {orig_manifest_dir}")
 
     def _add_context(c: Cut):
         splits = (
@@ -157,15 +194,21 @@ def add_context_list_to_manifest(subset: str, min_count: int):
         return c
 
     cuts = cuts.map(_add_context)
-
+    print(f"---Saving manifest with context list to {target_manifest_dir}---")
     cuts.to_file(target_manifest_dir)
-    print(f"Saved manifest with context list to {target_manifest_dir}")
+    print("Finished")
 
 
-def check(subset: str, min_count: int):
-    # Used to show how many samples in the training set have a context list
-    print("Calculating the stats over the manifest")
-    manifest_dir = f"data/fbank/libriheavy_cuts_{subset}_with_context_list_min_count_{min_count}.jsonl.gz"
+def check(
+    manifest_dir: str,
+    subset: str = "medium",
+    top_k: int = 10000,
+):
+    # Show how many samples in the training set have a context list
+    # and the average length of context list
+    print("--- Calculating the stats over the manifest ---")
+
+    manifest_dir = f"{manifest_dir}/libriheavy_cuts_{subset}_with_context_list_topk_{top_k}.jsonl.gz"
     cuts = load_manifest_lazy(manifest_dir)
     total_cuts = len(cuts)
     has_context_list = [c.supervisions[0].context_list != "" for c in cuts]
@@ -378,8 +421,19 @@ def write_error_stats(
 
 
 if __name__ == "__main__":
-    subset = "medium"
-    min_count = 10
-    get_rare_words(subset, min_count)
-    add_context_list_to_manifest(subset=subset, min_count=min_count)
-    check(subset=subset, min_count=min_count)
+    parser = get_parser()
+    args = parser.parse_args()
+    manifest_dir = args.manifest_dir
+    subset = args.subset
+    top_k = args.top_k
+    get_rare_words(subset=subset, top_k=top_k)
+    add_context_list_to_manifest(
+        manifest_dir=manifest_dir,
+        subset=subset,
+        top_k=top_k,
+    )
+    check(
+        manifest_dir=manifest_dir,
+        subset=subset,
+        top_k=top_k,
+    )
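
Usage sketch (not part of the patch): the commands below show how the new preparation script and data-module options might be combined. prepare_prompt_asr.sh, --subset, --topk, --use-context-list, --top-k, and --rare-word-file are defined in this patch; the training entry point zipformer_prompt_asr/train.py and the mapping of --stop-stage/--topk onto the stop_stage/topk variables by shared/parse_options.sh are assumptions and may differ in the actual recipe.

# Stage 0-1: download the biasing lists and attach a context list of rare
# words to each utterance of the chosen Libriheavy subset.
./prepare_prompt_asr.sh --stage 0 --stop-stage 1 --subset medium --topk 10000

# Train with the generated context lists (train.py is assumed; only the
# flags below come from this patch).
python ./zipformer_prompt_asr/train.py \
  --subset medium \
  --use-context-list True \
  --top-k 10000 \
  --rare-word-file data/context_biasing/medium_rare_words_topk_10000.txt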