mirror of https://github.com/k2-fsa/icefall.git
add script for generating context list for each utterance
This commit is contained in:
parent 8401f26342
commit bea1bd295f
egs/libriheavy/ASR/prepare_prompt_asr.sh (executable file, 36 lines added)
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+# This is the preparation recipe for PromptASR: https://arxiv.org/pdf/2309.07414
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+stage=-1
+stop_stage=100
+manifest_dir=data/fbank
+subset=medium
+topk=10000
+
+. shared/parse_options.sh || exit 1
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download the meta biasing list for LibriSpeech"
+  mkdir -p data/context_biasing
+  cd data/context_biasing
+  git clone https://github.com/facebookresearch/fbai-speech.git
+  cd ../..
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Add rare-words for context biasing to the manifest"
+  python zipformer_prompt_asr/utils.py \
+    --manifest-dir $manifest_dir \
+    --subset $subset \
+    --top-k $topk
+fi
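Not part of this commit: a minimal sketch of how the manifest written by stage 1 can be inspected. The file name follows the recipe defaults above (subset=medium, topk=10000), and the per-utterance rare words end up in the supervision's context_list field, which the check() helper added further below also reads.

from lhotse import load_manifest_lazy

cuts = load_manifest_lazy(
    "data/fbank/libriheavy_cuts_medium_with_context_list_topk_10000.jsonl.gz"
)
for i, cut in enumerate(cuts):
    # context_list holds the rare words found in this utterance (it may be empty)
    print(cut.id, cut.supervisions[0].context_list)
    if i == 4:
        break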
@@ -23,7 +23,7 @@ from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
-from dataset2 import PromptASRDataset
+from dataset import PromptASRDataset
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
     CutConcatenate,
@@ -70,14 +70,18 @@ class LibriHeavyAsrDataModule:
 
     def __init__(self, args: argparse.Namespace):
         self.args = args
 
         if args.use_context_list:
             from dataset2 import PromptASRDataset
 
             assert args.rare_word_file is not None
-            with open(args.rare_word_file, 'r') as f:
-                self.rare_word_list = f.read().lower().split() # Use lower-cased for easier style transform
+            with open(args.rare_word_file, "r") as f:
+                self.rare_word_list = (
+                    f.read().lower().split()
+                )  # Use lower-cased for easier style transform
         else:
             from dataset import PromptASRDataset
 
             self.rare_word_list = None
 
     @classmethod
@@ -202,20 +206,22 @@ class LibriHeavyAsrDataModule:
             default="small",
             help="Select the Libriheavy subset (small|medium|large)",
         )
 
         group.add_argument(
             "--use-context-list",
             type=str2bool,
             default=False,
             help="Use the context list of libri heavy",
         )
 
         group.add_argument(
-            "--min-count",
+            "--top-k",
             type=int,
-            default=7,
+            default=10000,
+            help="""The top-k words are identified as common words,
+            the rest as rare words""",
         )
 
         group.add_argument(
             "--with-decoding",
             type=str2bool,
@@ -227,12 +233,12 @@ class LibriHeavyAsrDataModule:
             "--random-left-padding",
             type=str2bool,
         )
 
         group.add_argument(
             "--rare-word-file",
             type=str,
         )
 
         group.add_argument(
             "--long-audio-cuts",
             type=str,
@@ -257,13 +263,9 @@ class LibriHeavyAsrDataModule:
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
             logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(
-                self.args.manifest_dir / "musan_cuts.jsonl.gz"
-            )
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(
-                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
-                )
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -285,9 +287,7 @@ class LibriHeavyAsrDataModule:
         input_transforms = []
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
             # Set the value of num_frame_masks according to Lhotse's version.
             # In different Lhotse's versions, the default of num_frame_masks is
             # different.
@@ -332,9 +332,7 @@ class LibriHeavyAsrDataModule:
             # Drop feats to be on the safe side.
             train = PromptASRDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
                 text_sampling_func=text_sampling_func,
@@ -379,9 +377,10 @@ class LibriHeavyAsrDataModule:
 
         return train_dl
 
-    def valid_dataloaders(self,
+    def valid_dataloaders(
+        self,
         cuts_valid: CutSet,
         text_sampling_func: Callable[[List[str]], str] = None,
     ) -> DataLoader:
         transforms = []
         if self.args.random_left_padding:
@@ -401,9 +400,7 @@ class LibriHeavyAsrDataModule:
         if self.args.on_the_fly_feats:
             validate = PromptASRDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                 return_cuts=self.args.return_cuts,
                 text_sampling_func=text_sampling_func,
                 rare_word_list=self.rare_word_list,
@@ -460,16 +457,16 @@ class LibriHeavyAsrDataModule:
         if self.args.use_context_list:
             path = (
                 self.args.manifest_dir
-                / f"libriheavy_cuts_{self.args.subset}_with_context_list_min_count_{self.args.min_count}.jsonl.gz"
+                / f"libriheavy_cuts_{self.args.subset}_with_context_list_topk_{self.args.top_k}.jsonl.gz"
             )
         elif self.args.with_decoding:
             path = (
-                self.args.manifest_dir / f"libriheavy_cuts_{self.args.subset}_with_decoding.jsonl.gz"
+                self.args.manifest_dir
+                / f"libriheavy_cuts_{self.args.subset}_with_decoding.jsonl.gz"
             )
         else:
             path = (
-                self.args.manifest_dir
-                / f"libriheavy_cuts_{self.args.subset}.jsonl.gz"
+                self.args.manifest_dir / f"libriheavy_cuts_{self.args.subset}.jsonl.gz"
             )
 
         logging.info(f"Loading manifest from {path}.")
@@ -513,21 +510,7 @@ class LibriHeavyAsrDataModule:
         return load_manifest_lazy(
             self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
         )
 
-    @lru_cache()
-    def npr1_dev_cuts(self) -> CutSet:
-        logging.info("About to get npr1 dev cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "npr1_cuts_dev.jsonl.gz"
-        )
-
-    @lru_cache()
-    def npr1_test_cuts(self) -> CutSet:
-        logging.info("About to get npr1 test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "npr1_cuts_test.jsonl.gz"
-        )
-
     @lru_cache()
     def long_audio_cuts(self) -> CutSet:
         logging.info("About to get long audio cuts")
@@ -535,11 +518,11 @@ class LibriHeavyAsrDataModule:
             self.args.long_audio_cuts,
         )
         return cuts
 
     @lru_cache()
     def test_dev_cuts(self) -> CutSet:
         logging.info("About to get test dev cuts")
         cuts = load_manifest_lazy(
             self.args.manifest_dir / "libriheavy_cuts_test_dev.jsonl.gz"
         )
         return cuts
@@ -1,3 +1,4 @@
+import argparse
 import ast
 import glob
 import logging
@@ -12,9 +13,34 @@ from text_normalization import remove_non_alphabetic
 from tqdm import tqdm
 
 
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--manifest-dir",
+        type=str,
+        default="data/fbank",
+        help="Where are the manifest stored",
+    )
+
+    parser.add_argument(
+        "--subset", type=str, default="medium", help="Which subset to work with"
+    )
+
+    parser.add_argument(
+        "--top-k",
+        type=int,
+        default=10000,
+        help="How many words to keep",
+    )
+
+    return parser
+
+
 def get_facebook_biasing_list(
     test_set: str,
-    use_distractors: bool = False,
     num_distractors: int = 100,
 ) -> Dict:
     # Get the biasing list from the meta paper: https://arxiv.org/pdf/2104.02194.pdf
@@ -28,9 +54,9 @@ def get_facebook_biasing_list(
             raise ValueError(f"Unseen test set {test_set}")
     else:
         if test_set == "test-clean":
-            biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv"
+            biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-clean.biasing_{num_distractors}.tsv"
         elif test_set == "test-other":
-            biasing_file = "data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv"
+            biasing_file = f"data/context_biasing/fbai-speech/is21_deep_bias/ref/test-other.biasing_{num_distractors}.tsv"
         else:
             raise ValueError(f"Unseen test set {test_set}")
 
@@ -41,7 +67,7 @@ def get_facebook_biasing_list(
     output = dict()
     for line in data:
         id, _, l1, l2 = line.split("\t")
-        if use_distractors:
+        if num_distractors > 0:  # use distractors
             biasing_list = ast.literal_eval(l2)
         else:
             biasing_list = ast.literal_eval(l1)
@@ -66,20 +92,25 @@ def brian_biasing_list(level: str):
     return biasing_dict
 
 
-def get_rare_words(subset: str, min_count: int):
+def get_rare_words(
+    subset: str = "medium",
+    top_k: int = 10000,
+    # min_count: int = 10000,
+):
     """Get a list of rare words appearing less than `min_count` times
 
     Args:
         subset: The dataset
-        min_count (int): Count of appearance
+        top_k (int): How many frequent words
     """
     txt_path = f"data/tmp/transcript_words_{subset}.txt"
-    rare_word_file = f"data/context_biasing/{subset}_rare_words_{min_count}.txt"
+    rare_word_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt"
 
     if os.path.exists(rare_word_file):
         print("File exists, do not proceed!")
         return
-    print("Finding rare words in the manifest.")
+
+    print("---Identifying rare words in the manifest---")
     count_file = f"data/tmp/transcript_words_{subset}_count.txt"
     if not os.path.exists(count_file):
         with open(txt_path, "r") as file:
@@ -94,9 +125,12 @@ def get_rare_words(subset: str, min_count: int):
                 else:
                     word_count[w] += 1
 
+        word_count = list(word_count.items())  # convert to a list of tuple
+        word_count = sorted(word_count, key=lambda w: int(w[1]), reverse=True)
         with open(count_file, "w") as fout:
-            for w in word_count:
-                fout.write(f"{w}\t{word_count[w]}\n")
+            for w, count in word_count:
+                fout.write(f"{w}\t{count}\n")
 
     else:
         word_count = {}
         with open(count_file, "r") as fin:
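As a toy illustration (not part of the diff) of what the sorted counts above are used for: the top_k most frequent words are treated as common, and everything after that index becomes a rare word in the next hunk. The words and top_k value below are made up purely for the example.

from collections import Counter

words = "the cat sat on the mat the cat purred".split()
word_count = sorted(Counter(words).items(), key=lambda w: w[1], reverse=True)

top_k = 2  # the recipe default is 10000
common = [w for w, _ in word_count[:top_k]]  # ['the', 'cat']
rare = [w for w, _ in word_count[top_k:]]    # ['sat', 'on', 'mat', 'purred']
print(common, rare)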
@@ -106,42 +140,45 @@ def get_rare_words(subset: str, min_count: int):
 
     print(f"A total of {len(word_count)} words appeared!")
     rare_words = []
-    for k in word_count:
-        if int(word_count[k]) <= min_count:
-            rare_words.append(k + "\n")
-    print(f"A total of {len(rare_words)} appeared <= {min_count} times")
+    for word, count in word_count[top_k:]:
+        rare_words.append(word + "\n")
+    print(f"A total of {len(rare_words)} are identified as rare words.")
 
     with open(rare_word_file, "w") as f:
         f.writelines(rare_words)
 
 
-def add_context_list_to_manifest(subset: str, min_count: int):
+def add_context_list_to_manifest(
+    manifest_dir: str,
+    subset: str = "medium",
+    top_k: int = 10000,
+):
     """Generate a context list of rare words for each utterance in the manifest
 
     Args:
+        manifest_dir: Where to store the manifest with context list
         subset (str): Subset
-        min_count (int): The min appearances
+        top_k (int): How many frequent words
 
     """
-    rare_words_file = f"data/context_biasing/{subset}_rare_words_{min_count}.txt"
-    manifest_dir = f"data/fbank/libriheavy_cuts_{subset}.jsonl.gz"
-    target_manifest_dir = manifest_dir.replace(
-        ".jsonl.gz", f"_with_context_list_min_count_{min_count}.jsonl.gz"
+    orig_manifest_dir = f"{manifest_dir}/libriheavy_cuts_{subset}.jsonl.gz"
+    target_manifest_dir = orig_manifest_dir.replace(
+        ".jsonl.gz", f"_with_context_list_topk_{top_k}.jsonl.gz"
     )
     if os.path.exists(target_manifest_dir):
         print(f"Target file exits at {target_manifest_dir}!")
         return
 
-    print(f"Reading rare words from {rare_words_file}")
+    rare_words_file = f"data/context_biasing/{subset}_rare_words_topk_{top_k}.txt"
+    print(f"---Reading rare words from {rare_words_file}---")
     with open(rare_words_file, "r") as f:
         rare_words = f.read()
     rare_words = rare_words.split("\n")
     rare_words = set(rare_words)
     print(f"A total of {len(rare_words)} rare words!")
 
-    cuts = load_manifest_lazy(manifest_dir)
-    print(f"Loaded manifest from {manifest_dir}")
+    cuts = load_manifest_lazy(orig_manifest_dir)
+    print(f"Loaded manifest from {orig_manifest_dir}")
 
     def _add_context(c: Cut):
         splits = (
@@ -157,15 +194,21 @@ def add_context_list_to_manifest(subset: str, min_count: int):
         return c
 
     cuts = cuts.map(_add_context)
+    print(f"---Saving manifest with context list to {target_manifest_dir}---")
     cuts.to_file(target_manifest_dir)
-    print(f"Saved manifest with context list to {target_manifest_dir}")
+    print("Finished")
 
 
-def check(subset: str, min_count: int):
-    # Used to show how many samples in the training set have a context list
-    print("Calculating the stats over the manifest")
-    manifest_dir = f"data/fbank/libriheavy_cuts_{subset}_with_context_list_min_count_{min_count}.jsonl.gz"
+def check(
+    manifest_dir: str,
+    subset: str = "medium",
+    top_k: int = 10000,
+):
+    # Show how many samples in the training set have a context list
+    # and the average length of context list
+    print("--- Calculating the stats over the manifest ---")
+
+    manifest_dir = f"{manifest_dir}/libriheavy_cuts_{subset}_with_context_list_topk_{top_k}.jsonl.gz"
     cuts = load_manifest_lazy(manifest_dir)
     total_cuts = len(cuts)
     has_context_list = [c.supervisions[0].context_list != "" for c in cuts]
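The body of _add_context is not shown in the hunks above. As a hedged sketch of the idea only (field names are assumptions, not the committed code), it keeps the rare words that actually occur in a cut's transcript and attaches them to the supervision, which is where check() reads them back:

def _add_context_sketch(c, rare_words: set):
    # Assumed transcript field; libriheavy supervisions carry the text here.
    text = c.supervisions[0].texts[0]
    hits = [w for w in text.lower().split() if w in rare_words]
    # Store the per-utterance context list on the supervision.
    c.supervisions[0].context_list = " ".join(hits)
    return c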
@@ -378,8 +421,19 @@ def write_error_stats(
 
 
 if __name__ == "__main__":
-    subset = "medium"
-    min_count = 10
-    get_rare_words(subset, min_count)
-    add_context_list_to_manifest(subset=subset, min_count=min_count)
-    check(subset=subset, min_count=min_count)
+    parser = get_parser()
+    args = parser.parse_args()
+    manifest_dir = args.manifest_dir
+    subset = args.subset
+    top_k = args.top_k
+    get_rare_words(subset=subset, top_k=top_k)
+    add_context_list_to_manifest(
+        manifest_dir=manifest_dir,
+        subset=subset,
+        top_k=top_k,
+    )
+    check(
+        manifest_dir=manifest_dir,
+        subset=subset,
+        top_k=top_k,
+    )