Commit more scripts for gigaspeech kws recipe

Commit: 2addc6cba6
Parent: e257b44763
Repository: mirror of https://github.com/k2-fsa/icefall.git
@@ -30,15 +30,15 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)


-def compute_fbank_gigaspeech_dev_test():
+def compute_fbank_gigaspeech():
     in_out_dir = Path("data/fbank")
     # number of workers in dataloader
     num_workers = 20

     # number of seconds in a batch
-    batch_duration = 600
+    batch_duration = 1000

-    subsets = ("DEV", "TEST")
+    subsets = ("L", "M", "S", "XS", "DEV", "TEST")

     device = torch.device("cpu")
     if torch.cuda.is_available():
@@ -48,12 +48,12 @@ def compute_fbank_gigaspeech_dev_test():
     logging.info(f"device: {device}")

     for partition in subsets:
-        cuts_path = in_out_dir / f"cuts_{partition}.jsonl.gz"
+        cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}.jsonl.gz"
         if cuts_path.is_file():
             logging.info(f"{cuts_path} exists - skipping")
             continue

-        raw_cuts_path = in_out_dir / f"cuts_{partition}_raw.jsonl.gz"
+        raw_cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"

         logging.info(f"Loading {raw_cuts_path}")
         cut_set = CutSet.from_file(raw_cuts_path)
@@ -62,7 +62,7 @@ def compute_fbank_gigaspeech_dev_test():

         cut_set = cut_set.compute_and_store_features_batch(
             extractor=extractor,
-            storage_path=f"{in_out_dir}/feats_{partition}",
+            storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}",
             num_workers=num_workers,
             batch_duration=batch_duration,
             overwrite=True,
@@ -80,7 +80,7 @@ def main():
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)

-    compute_fbank_gigaspeech_dev_test()
+    compute_fbank_gigaspeech()


 if __name__ == "__main__":
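For orientation, here is a minimal sketch of the feature-computation loop that the hunks above rename and retarget, written only with the lhotse calls visible in the diff (CutSet.from_file, compute_and_store_features_batch, to_file); the plain Fbank() extractor and the helper name are placeholders, not necessarily what the actual script uses:

```python
from pathlib import Path

from lhotse import CutSet, Fbank


def compute_fbank_for_partition(partition: str, in_out_dir: Path = Path("data/fbank")) -> None:
    """Illustrative stand-in for the loop body above (names are hypothetical)."""
    cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}.jsonl.gz"
    if cuts_path.is_file():
        return  # features already computed, skip

    raw_cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
    cut_set = CutSet.from_file(raw_cuts_path)
    cut_set = cut_set.compute_and_store_features_batch(
        extractor=Fbank(),  # the real script may use a GPU-backed extractor instead
        storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}",
        num_workers=20,       # dataloader workers, as in the diff
        batch_duration=1000,  # seconds of audio per batch, as in the diff
        overwrite=True,
    )
    cut_set.to_file(cuts_path)
```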
@@ -99,12 +99,12 @@ def compute_fbank_gigaspeech_splits(args):
         idx = f"{i + 1}".zfill(num_digits)
         logging.info(f"Processing {idx}/{num_splits}")

-        cuts_path = output_dir / f"cuts_XL.{idx}.jsonl.gz"
+        cuts_path = output_dir / f"gigaspeech_cuts_XL.{idx}.jsonl.gz"
         if cuts_path.is_file():
             logging.info(f"{cuts_path} exists - skipping")
             continue

-        raw_cuts_path = output_dir / f"cuts_XL_raw.{idx}.jsonl.gz"
+        raw_cuts_path = output_dir / f"gigaspeech_cuts_XL_raw.{idx}.jsonl.gz"

         logging.info(f"Loading {raw_cuts_path}")
         cut_set = CutSet.from_file(raw_cuts_path)
@@ -113,7 +113,7 @@ def compute_fbank_gigaspeech_splits(args):

         cut_set = cut_set.compute_and_store_features_batch(
             extractor=extractor,
-            storage_path=f"{output_dir}/feats_XL_{idx}",
+            storage_path=f"{output_dir}/gigaspeech_feats_XL_{idx}",
             num_workers=args.num_workers,
             batch_duration=args.batch_duration,
             overwrite=True,
@@ -16,17 +16,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import argparse
 import logging
 import re
 from pathlib import Path

 from lhotse import CutSet, SupervisionSegment
 from lhotse.recipes.utils import read_manifests_if_cached
+from icefall.utils import str2bool

 # Similar text filtering and normalization procedure as in:
 # https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh


+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--perturb-speed",
+        type=str2bool,
+        default=False,
+        help="Whether to use speed perturbation.",
+    )
+
+    return parser.parse_args()
+
+
 def normalize_text(
     utt: str,
     punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
@@ -42,7 +56,7 @@ def has_no_oov(
     return oov_pattern.search(sup.text) is None


-def preprocess_giga_speech():
+def preprocess_giga_speech(args):
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
     output_dir.mkdir(exist_ok=True)
@@ -51,6 +65,10 @@ def preprocess_giga_speech():
         "DEV",
         "TEST",
         "XL",
+        "L",
+        "M",
+        "S",
+        "XS",
     )

     logging.info("Loading manifest (may take 4 minutes)")
@@ -71,7 +89,7 @@ def preprocess_giga_speech():

     for partition, m in manifests.items():
         logging.info(f"Processing {partition}")
-        raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
+        raw_cuts_path = output_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
         if raw_cuts_path.is_file():
             logging.info(f"{partition} already exists - skipping")
             continue
@@ -94,11 +112,14 @@ def preprocess_giga_speech():
         # Run data augmentation that needs to be done in the
         # time domain.
-        if partition not in ["DEV", "TEST"]:
-            logging.info(
-                f"Speed perturb for {partition} with factors 0.9 and 1.1 "
-                "(Perturbing may take 8 minutes and saving may take 20 minutes)"
-            )
-            cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+        if args.perturb_speed:
+            logging.info(
+                f"Speed perturb for {partition} with factors 0.9 and 1.1 "
+                "(Perturbing may take 8 minutes and saving may take 20 minutes)"
+            )
+            cut_set = (
+                cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+            )
         logging.info(f"Saving to {raw_cuts_path}")
         cut_set.to_file(raw_cuts_path)
@@ -107,7 +128,8 @@ def main():
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)

-    preprocess_giga_speech()
+    args = get_args()
+    preprocess_giga_speech(args)


 if __name__ == "__main__":
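As a side note, the new --perturb-speed flag gates the same lhotse idiom that the old code applied to every non-DEV/TEST partition; when enabled it roughly triples the cut set with 0.9x and 1.1x copies. A hedged sketch of just that step:

```python
from lhotse import CutSet


def maybe_perturb_speed(cut_set: CutSet, perturb_speed: bool) -> CutSet:
    """Sketch of the gated augmentation: the original cuts are concatenated with
    0.9x and 1.1x speed-perturbed copies, i.e. roughly three times the data."""
    if perturb_speed:
        cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
    return cut_set
```

Unlike the old branch, the flag applies to whichever partitions the script processes, so it is presumably left at False when only the test data needs to be prepared.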
@@ -99,7 +99,14 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
     exit 1;
   fi
   # Download XL, DEV and TEST sets by default.
-  lhotse download gigaspeech --subset auto --host tsinghua \
+  lhotse download gigaspeech --subset XL \
+    --subset L \
+    --subset M \
+    --subset S \
+    --subset XS \
+    --subset DEV \
+    --subset TEST \
+    --host tsinghua \
     $dl_dir/password $dl_dir/GigaSpeech
 fi
@@ -118,7 +125,14 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   # We assume that you have downloaded the GigaSpeech corpus
   # to $dl_dir/GigaSpeech
   mkdir -p data/manifests
-  lhotse prepare gigaspeech --subset auto -j $nj \
+  lhotse prepare gigaspeech --subset XL \
+    --subset L \
+    --subset M \
+    --subset S \
+    --subset XS \
+    --subset DEV \
+    --subset TEST \
+    -j $nj \
     $dl_dir/GigaSpeech data/manifests
 fi
@@ -139,8 +153,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
 fi

 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 4: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)"
-  python3 ./local/compute_fbank_gigaspeech_dev_test.py
+  log "Stage 4: Compute features for L, M, S, XS, DEV and TEST subsets of GigaSpeech."
+  python3 ./local/compute_fbank_gigaspeech.py
 fi

 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
@@ -176,18 +190,9 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
 fi

 if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
-  log "Stage 9: Prepare phone based lang"
+  log "Stage 9: Prepare transcript_words.txt and words.txt"
   lang_dir=data/lang_phone
   mkdir -p $lang_dir
-
-  (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
-    cat - $dl_dir/lm/lexicon.txt |
-    sort | uniq > $lang_dir/lexicon.txt
-
-  if [ ! -f $lang_dir/L_disambig.pt ]; then
-    ./local/prepare_lang.py --lang-dir $lang_dir
-  fi
-
   if [ ! -f $lang_dir/transcript_words.txt ]; then
     gunzip -c "data/manifests/gigaspeech_supervisions_XL.jsonl.gz" \
       | jq '.text' \
@@ -238,7 +243,21 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
 fi

 if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
-  log "Stage 10: Prepare BPE based lang"
+  log "Stage 10: Prepare phone based lang"
+  lang_dir=data/lang_phone
+  mkdir -p $lang_dir
+
+  (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
+    cat - $dl_dir/lm/lexicon.txt |
+    sort | uniq > $lang_dir/lexicon.txt
+
+  if [ ! -f $lang_dir/L_disambig.pt ]; then
+    ./local/prepare_lang.py --lang-dir $lang_dir
+  fi
+fi
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+  log "Stage 11: Prepare BPE based lang"

   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}
@@ -260,8 +279,8 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
   done
 fi

-if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
-  log "Stage 11: Prepare bigram P"
+if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
+  log "Stage 12: Prepare bigram P"

   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}
@@ -291,8 +310,8 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
   done
 fi

-if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
-  log "Stage 12: Prepare G"
+if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
+  log "Stage 13: Prepare G"
   # We assume you have installed kaldilm, if not, please install
   # it using: pip install kaldilm
@@ -317,8 +336,8 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
   fi
 fi

-if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
-  log "Stage 13: Compile HLG"
+if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
+  log "Stage 14: Compile HLG"
   ./local/compile_hlg.py --lang-dir data/lang_phone

   for vocab_size in ${vocab_sizes[@]}; do
@@ -105,7 +105,7 @@ class GigaSpeechAsrDataModule:
         group.add_argument(
             "--num-buckets",
             type=int,
-            default=30,
+            default=100,
             help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
@@ -312,8 +312,8 @@ class GigaSpeechAsrDataModule:
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
                 drop_last=self.args.drop_last,
-                buffer_size=self.args.num_buckets * 2000,
-                shuffle_buffer_size=self.args.num_buckets * 5000,
+                buffer_size=self.args.num_buckets * 1000,
+                shuffle_buffer_size=self.args.num_buckets * 3000,
             )
         else:
             logging.info("Using SimpleCutSampler.")
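The sampler hunk above trades a larger bucket count for smaller per-bucket buffers. A rough sketch of how those arguments reach lhotse's DynamicBucketingSampler follows; only keyword arguments visible in the diff are used, and their availability depends on the lhotse version, so treat this as an assumption rather than the data module's exact code:

```python
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler


def make_train_sampler(cuts: CutSet, max_duration: float, num_buckets: int = 100):
    # Both buffers scale with the bucket count, so raising num_buckets from 30
    # to 100 while shrinking the multipliers (2000 -> 1000, 5000 -> 3000)
    # keeps the in-memory shuffling buffers from growing too large.
    return DynamicBucketingSampler(
        cuts,
        max_duration=max_duration,
        shuffle=True,
        num_buckets=num_buckets,
        drop_last=True,
        buffer_size=num_buckets * 1000,
        shuffle_buffer_size=num_buckets * 3000,
    )
```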
@@ -447,3 +447,38 @@ class GigaSpeechAsrDataModule:
         return load_manifest_lazy(
             self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
         )
+
+    @lru_cache()
+    def libri_100_cuts(self) -> CutSet:
+        logging.info("About to get libri100 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
+        )
+
+    @lru_cache()
+    def fsc_train_cuts(self) -> CutSet:
+        logging.info("About to get fluent speech commands train cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "fluent_speech_commands_cuts_train.jsonl.gz"
+        )
+
+    @lru_cache()
+    def fsc_valid_cuts(self) -> CutSet:
+        logging.info("About to get fluent speech commands valid cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "fluent_speech_commands_cuts_valid.jsonl.gz"
+        )
+
+    @lru_cache()
+    def fsc_test_small_cuts(self) -> CutSet:
+        logging.info("About to get fluent speech commands small test cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "fluent_speech_commands_cuts_small.jsonl.gz"
+        )
+
+    @lru_cache()
+    def fsc_test_large_cuts(self) -> CutSet:
+        logging.info("About to get fluent speech commands large test cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "fluent_speech_commands_cuts_large.jsonl.gz"
+        )
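All of the new accessors follow one pattern: cache the handle with lru_cache and open the manifest lazily so cuts are streamed from the jsonl.gz file rather than loaded into memory. A stripped-down sketch of that pattern outside the data module, using one of the paths referenced above:

```python
from functools import lru_cache
from pathlib import Path

from lhotse import CutSet, load_manifest_lazy


@lru_cache()
def fsc_test_small_cuts(manifest_dir: Path = Path("data/fbank")) -> CutSet:
    # Opens the manifest lazily; cuts are materialized only when iterated.
    return load_manifest_lazy(
        manifest_dir / "fluent_speech_commands_cuts_small.jsonl.gz"
    )
```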
egs/gigaspeech/KWS/zipformer/decode-asr.py (executable file, 1065 lines; diff suppressed because it is too large)
@@ -24,11 +24,10 @@ Usage:
   --avg 15 \
   --exp-dir ./zipformer/exp \
   --max-duration 600 \
-  --decoding-method modified_beam_search \
   --keywords-file keywords.txt \
   --beam-size 4
 """


 import argparse
 import logging
 import math
@@ -163,10 +162,17 @@ def get_parser():
         help="File contains keywords.",
     )

+    parser.add_argument(
+        "--test-set",
+        type=str,
+        default="small",
+        help="small or large",
+    )
+
     parser.add_argument(
         "--keywords-score",
         type=float,
-        default=3.0,
+        default=1.5,
         help="""
         The default boosting score (token level) for keywords. it will boost the
         paths that match keywords to make them survive beam search.
@@ -176,14 +182,21 @@ def get_parser():
     parser.add_argument(
         "--keywords-threshold",
         type=float,
-        default=0.75,
+        default=0.35,
         help="The default threshold (probability) to trigger the keyword.",
     )

+    parser.add_argument(
+        "--keywords-version",
+        type=str,
+        default="",
+        help="The keywords configuration version, just to save results to different files.",
+    )
+
     parser.add_argument(
         "--num-tailing-blanks",
         type=int,
-        default=8,
+        default=1,
         help="The number of tailing blanks should have after hitting one keyword.",
     )
@@ -261,7 +274,7 @@ def decode_one_batch(
         model=model,
         encoder_out=encoder_out,
         encoder_out_lens=encoder_out_lens,
-        kws_graph=kws_graph,
+        context_graph=kws_graph,
         beam=params.beam,
         num_tailing_blanks=params.num_tailing_blanks,
         blank_penalty=params.blank_penalty,
@@ -284,6 +297,7 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     kws_graph: ContextGraph,
     keywords: Set[str],
+    test_only_keywords: bool,
 ) -> Tuple[List[Tuple[str, List[str], List[str]]], KwMetric]:
     """Decode dataset.
@@ -337,34 +351,65 @@
             ref_text = ref_text.upper()
             ref_words = ref_text.split()
             hyp_words = [x[0] for x in hyp_words]
+            # for computing WER
             this_batch.append((cut_id, ref_words, " ".join(hyp_words).split()))
-            hyp_set = set(hyp_words)
-            hyp_str = " | ".join(hyp_words)
+            hyp_set = set(hyp_words)  # each item is a keyword phrase
+            if len(hyp_words) > 1:
+                logging.warning(
+                    f"Cut {cut_id} triggers more than one keywords : {hyp_words},"
+                    f"please check the transcript to see if it really has more "
+                    f"than one keywords, if so consider splitting this audio and"
+                    f"keep only one keyword for each audio."
+                )
+            hyp_str = " | ".join(
+                hyp_words
+            )  # The triggered keywords for this utterance.
+            TP = False
+            FP = False
             for x in hyp_set:
-                assert x in keywords, x
-                if x in ref_text and x in keywords:
-                    metric["all"].TP += 1
+                assert x in keywords, x  # can only trigger keywords
+                if (test_only_keywords and x == ref_text) or (
+                    not test_only_keywords and x in ref_text
+                ):
+                    TP = True
                     metric[x].TP += 1
                     metric[x].TP_list.append(f"({ref_text} -> {x})")
-                if x not in ref_text and x in keywords:
-                    metric["all"].FP += 1
+                if (test_only_keywords and x != ref_text) or (
+                    not test_only_keywords and x not in ref_text
+                ):
+                    FP = True
                     metric[x].FP += 1
                     metric[x].FP_list.append(f"({ref_text} -> {x})")
+            if TP:
+                metric["all"].TP += 1
+            if FP:
+                metric["all"].FP += 1
+            TN = True  # all keywords are true negative then the summery is true negative.
+            FN = False
             for x in keywords:
                 if x not in ref_text and x not in hyp_set:
-                    metric["all"].TN += 1
                     metric[x].TN += 1
                     continue

-                if x in ref_text:
+                TN = False
+                if (test_only_keywords and x == ref_text) or (
+                    not test_only_keywords and x in ref_text
+                ):
                     fn = True
                     for y in hyp_set:
-                        if y in ref_text:
+                        if (test_only_keywords and y == ref_text) or (
+                            not test_only_keywords and y in ref_text
+                        ):
                             fn = False
                             break
-                    if fn and ref_text.endswith(x):
-                        metric["all"].FN += 1
+                    if fn:
+                        FN = True
                         metric[x].FN += 1
                         metric[x].FN_list.append(f"({ref_text} -> {hyp_str})")
+            if TN:
+                metric["all"].TN += 1
+            if FN:
+                metric["all"].FN += 1

         results.extend(this_batch)
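To make the new bookkeeping easier to follow, the per-utterance decision logic can be distilled into a standalone helper. The sketch below mirrors the revised loop (exact match against the reference on keyword-only test sets such as FSC, substring match otherwise) but is an illustration, not the recipe's code:

```python
from typing import Set, Tuple


def classify_utterance(
    ref_text: str,
    triggered: Set[str],   # keyword phrases emitted by the decoder for this cut
    keywords: Set[str],    # all keywords being tested
    test_only_keywords: bool,
) -> Tuple[bool, bool, bool, bool]:
    """Return per-utterance (TP, FP, TN, FN) flags, as accumulated into metric["all"]."""

    def hit(kw: str) -> bool:
        # On keyword-only sets the whole reference must equal the keyword;
        # otherwise a substring match counts.
        return kw == ref_text if test_only_keywords else kw in ref_text

    tp = any(hit(kw) for kw in triggered)
    fp = any(not hit(kw) for kw in triggered)
    # Missed detection: some keyword matches the reference, but none of the
    # triggered phrases does.
    fn = any(hit(kw) for kw in keywords) and not any(hit(kw) for kw in triggered)
    # True negative: no keyword appears in the reference and nothing was triggered.
    tn = all(kw not in ref_text and kw not in triggered for kw in keywords)
    return tp, fp, tn, fn
```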
@@ -396,16 +441,17 @@ def save_results(

     metric_filename = params.res_dir / f"metric-{test_set_name}-{params.suffix}.txt"

     print_s = ""
     with open(metric_filename, "w") as of:
         width = 10
         for key, item in sorted(
             metric.items(), key=lambda x: (x[1].FP, x[1].FN), reverse=True
         ):
             acc = (item.TP + item.TN) / (item.TP + item.TN + item.FP + item.FN)
-            precision = (item.TP + 1) / (item.TP + item.FP + 1)
-            recall = (item.TP + 1) / (item.TP + item.FN + 1)
-            fpr = (item.FP + 1) / (item.FP + item.TN + 1)
+            precision = (
+                0.0 if (item.TP + item.FP) == 0 else item.TP / (item.TP + item.FP)
+            )
+            recall = 0.0 if (item.TP + item.FN) == 0 else item.TP / (item.TP + item.FN)
+            fpr = 0.0 if (item.FP + item.TN) == 0 else item.FP / (item.FP + item.TN)
             s = f"{key}:\n"
             s += f"\t{'TP':{width}}{'FP':{width}}{'FN':{width}}{'TN':{width}}\n"
             s += f"\t{str(item.TP):{width}}{str(item.FP):{width}}{str(item.FN):{width}}{str(item.TN):{width}}\n"
@@ -414,12 +460,14 @@ def save_results(
             s += f"\tRecall(PPR): {recall:.3f}\n"
             s += f"\tFPR: {fpr:.3f}\n"
             s += f"\tF1: {2 * precision * recall / (precision + recall):.3f}\n"
-            s += f"\tTP list: {' # '.join(item.TP_list)}\n"
-            s += f"\tFP list: {' # '.join(item.FP_list)}\n"
-            s += f"\tFN list: {' # '.join(item.FN_list)}\n"
+            if key != "all":
+                s += f"\tTP list: {' # '.join(item.TP_list)}\n"
+                s += f"\tFP list: {' # '.join(item.FP_list)}\n"
+                s += f"\tFN list: {' # '.join(item.FN_list)}\n"
             of.write(s + "\n")
             if key == "all":
                 logging.info(s)
+        of.write(f"\n\n{params.keywords_config}")

     logging.info("Wrote metric stats to {}".format(metric_filename))
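The reworked summary drops the old +1 smoothing and instead guards each ratio against an empty denominator. For reference, the quantities written per keyword boil down to the following (with the same guard also extended to accuracy and F1 here, which the script itself does not add):

```python
def summarize(TP: int, FP: int, TN: int, FN: int) -> dict:
    """Accuracy, precision, recall (PPR), false-positive rate and F1 with
    zero-denominator guards, matching the guarded expressions above."""
    total = TP + TN + FP + FN
    acc = 0.0 if total == 0 else (TP + TN) / total
    precision = 0.0 if (TP + FP) == 0 else TP / (TP + FP)
    recall = 0.0 if (TP + FN) == 0 else TP / (TP + FN)
    fpr = 0.0 if (FP + TN) == 0 else FP / (FP + TN)
    f1 = 0.0 if (precision + recall) == 0 else 2 * precision * recall / (precision + recall)
    return {"acc": acc, "precision": precision, "recall": recall, "fpr": fpr, "f1": f1}
```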
||||
@ -436,10 +484,11 @@ def main():
|
||||
|
||||
params.res_dir = params.exp_dir / "kws"
|
||||
|
||||
params.suffix = params.test_set
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
params.suffix += f"-iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
params.suffix += f"-epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.causal:
|
||||
assert (
|
||||
@ -456,6 +505,7 @@ def main():
|
||||
params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}"
|
||||
if params.blank_penalty != 0:
|
||||
params.suffix += f"-blank-penalty-{params.blank_penalty}"
|
||||
params.suffix += f"-version-{params.keywords_version}"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
@ -480,8 +530,10 @@ def main():
|
||||
token_ids = []
|
||||
keywords_scores = []
|
||||
keywords_thresholds = []
|
||||
keywords_config = []
|
||||
with open(params.keywords_file, "r") as f:
|
||||
for line in f.readlines():
|
||||
keywords_config.append(line)
|
||||
score = 0
|
||||
threshold = 0
|
||||
keyword = []
|
||||
@ -501,6 +553,8 @@ def main():
|
||||
keywords_scores.append(score)
|
||||
keywords_thresholds.append(threshold)
|
||||
|
||||
params.keywords_config = "".join(keywords_config)
|
||||
|
||||
kws_graph = ContextGraph(
|
||||
context_score=params.keywords_score, ac_threshold=params.keywords_threshold
|
||||
)
|
||||
@@ -605,24 +659,17 @@ def main():
     test_cuts = gigaspeech.test_cuts()
     test_dl = gigaspeech.test_dataloaders(test_cuts)

-    def select_keyword_cuts(c: Cut):
-        text = c.supervisions[0].text
-        text = text.strip().upper()
-        return text in keywords
-
-    test_sc1_cuts = gigaspeech.test_speechcommands1_cuts()
-    test_sc2_cuts = gigaspeech.test_speechcommands2_cuts()
-
-    test_fsc_cuts = gigaspeech.test_fluent_speechcommands_cuts()
-    test_fsc_cuts = test_fsc_cuts.filter(select_keyword_cuts)
-
-    test_sc1_dl = gigaspeech.test_dataloaders(test_sc1_cuts)
-    test_sc2_dl = gigaspeech.test_dataloaders(test_sc2_cuts)
-
-    test_fsc_dl = speechcommand.test_dataloaders(test_fsc_cuts)
-
-    test_sets = ["test-fsc", "test", "test-sc1", "test-sc2"]
-    test_dls = [test_fsc_dl, test_dl, test_sc1_dl, test_sc2_dl]
+    if params.test_set == "small":
+        test_fsc_small_cuts = gigaspeech.fsc_test_small_cuts()
+        test_fsc_small_dl = gigaspeech.test_dataloaders(test_fsc_small_cuts)
+        test_sets = ["small-fsc", "test"]
+        test_dls = [test_fsc_small_dl, test_dl]
+    else:
+        assert params.test_set == "large", params.test_set
+        test_fsc_large_cuts = gigaspeech.fsc_test_large_cuts()
+        test_fsc_large_dl = gigaspeech.test_dataloaders(test_fsc_large_cuts)
+        test_sets = ["large-fsc", "test"]
+        test_dls = [test_fsc_large_dl, test_dl]

     for test_set, test_dl in zip(test_sets, test_dls):
         results, metric = decode_dataset(
@@ -632,6 +679,7 @@ def main():
             sp=sp,
             kws_graph=kws_graph,
             keywords=keywords,
+            test_only_keywords="fsc" in test_set,
         )

         save_results(
egs/gigaspeech/KWS/zipformer/finetune.py (executable file, 1461 lines; diff suppressed because it is too large)

egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py (symbolic link, 1 line)
@@ -0,0 +1 @@
+../../ASR/zipformer/gigaspeech_scoring.py
@@ -126,7 +126,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
         type=str,
-        default="2,2,3,4,3,2",
+        default="1,1,1,1,1,1",
         help="Number of zipformer encoder layers per stack, comma separated.",
     )
@@ -140,7 +140,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--feedforward-dim",
         type=str,
-        default="512,768,1024,1536,1024,768",
+        default="192,192,192,192,192,192",
         help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
     )
@@ -154,7 +154,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--encoder-dim",
         type=str,
-        default="192,256,384,512,384,256",
+        default="128,128,128,128,128,128",
         help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
     )
@@ -189,7 +189,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--encoder-unmasked-dim",
         type=str,
-        default="192,192,256,256,256,192",
+        default="128,128,128,128,128,128",
         help="Unmasked dimensions in the encoders, relates to augmentation during training. "
         "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
     )
@@ -205,14 +205,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--decoder-dim",
         type=int,
-        default=512,
+        default=320,
         help="Embedding dimension in the decoder model.",
     )

     parser.add_argument(
         "--joiner-dim",
         type=int,
-        default=512,
+        default=320,
         help="""Dimension used in the joiner model.
         Outputs from the encoder and decoder model are projected
         to this dimension before adding.
@@ -222,7 +222,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--causal",
         type=str2bool,
-        default=False,
+        default=True,
         help="If True, use causal version of model.",
     )
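The smaller defaults above (one layer per stack, 192-dim feedforward, 128-dim encoder, 320-dim decoder/joiner, causal attention on) describe a lightweight streaming model suited to keyword spotting. The stack-shaped options are plain comma-separated strings; below is a hypothetical helper showing how such a spec expands to per-stack integers (the recipe's own parsing utility is not part of this diff):

```python
def to_int_tuple(spec: str) -> tuple:
    """Expand a comma-separated per-stack spec such as "1,1,1,1,1,1" into a
    tuple of ints; a single value like "128" yields a 1-tuple that the model
    then applies to every encoder stack."""
    return tuple(int(x) for x in spec.split(","))


# Example: to_int_tuple("192,192,192,192,192,192") -> (192, 192, 192, 192, 192, 192)
```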
@@ -416,6 +416,17 @@ def get_parser():
         help="Accumulate stats on activations, print them and exit.",
     )

+    parser.add_argument(
+        "--scan-for-oom-batches",
+        type=str2bool,
+        default=False,
+        help="""
+        Whether to scan for oom batches before training, this is helpful for
+        finding the suitable max_duration, you only need to run it once.
+        Caution: a little time consuming.
+        """,
+    )
+
     parser.add_argument(
         "--inf-check",
         type=str2bool,
@@ -463,7 +474,7 @@ def get_parser():
     parser.add_argument(
         "--use-fp16",
         type=str2bool,
-        default=False,
+        default=True,
         help="Whether to use half precision training.",
     )
@@ -1197,14 +1208,14 @@ def run(rank, world_size, args):
     valid_cuts = valid_cuts.filter(remove_short_utt)
     valid_dl = gigaspeech.valid_dataloaders(valid_cuts)

-    # if not params.print_diagnostics:
-    #     scan_pessimistic_batches_for_oom(
-    #         model=model,
-    #         train_dl=train_dl,
-    #         optimizer=optimizer,
-    #         sp=sp,
-    #         params=params,
-    #     )
+    if not params.print_diagnostics and params.scan_for_oom_batches:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -966,7 +966,6 @@ def keywords_search(
     encoder_out_lens: torch.Tensor,
     context_graph: ContextGraph,
     beam: int = 4,
-    ac_threshold: float = 0.15,
     num_tailing_blanks: int = 8,
     blank_penalty: float = 0,
 ) -> List[List[KeywordResult]]:
@@ -1077,6 +1076,8 @@ def keywords_search(

         log_probs = probs.log()

+        probs = probs.reshape(-1)
+
         log_probs.add_(ys_log_probs)

         vocab_size = log_probs.size(-1)
@@ -1112,7 +1113,7 @@ def keywords_search(
             if new_token not in (blank_id, unk_id):
                 new_ys.append(new_token)
                 new_timestamp.append(t)
-                new_ac_probs.append(math.exp(hyp_probs[topk_indexes[k]]))
+                new_ac_probs.append(hyp_probs[topk_indexes[k]])
                 (
                     context_score,
                     new_context_state,
@@ -1140,10 +1141,13 @@ def keywords_search(
             ac_prob = (
                 sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
             )
+            # logging.info(
+            #     f"ac prob : {ac_prob}, threshold : {matched_state.ac_threshold}"
+            # )
             if (
                 matched
                 and top_hyp.num_tailing_blanks > num_tailing_blanks
-                and ac_prob >= ac_threshold
+                and ac_prob >= matched_state.ac_threshold
             ):
                 keyword = KeywordResult(
                     hyps=top_hyp.ys[-matched_state.level :],
@@ -1171,7 +1175,7 @@ def keywords_search(
         ac_prob = (
             sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
         )
-        if matched and ac_prob >= ac_threshold:
+        if matched and ac_prob >= matched_state.ac_threshold:
             keyword = KeywordResult(
                 hyps=top_hyp.ys[-matched_state.level :],
                 timestamps=top_hyp.timestamp[-matched_state.level :],
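Taken together, the beam-search changes move the acoustic-confidence threshold from a single function argument onto the matched graph state and keep plain per-token probabilities in each hypothesis instead of exponentiating log-probabilities a second time. A condensed sketch of the resulting gating decision, assuming a matched state that carries level (keyword length in tokens) and ac_threshold:

```python
from typing import Sequence


def keyword_fires(
    ac_probs: Sequence[float],    # per-token acoustic probabilities of the hypothesis
    level: int,                   # number of tokens in the matched keyword
    state_ac_threshold: float,    # threshold attached to the matched graph state
    num_tailing_blanks: int,      # blanks seen since the last non-blank token
    required_tailing_blanks: int,
) -> bool:
    """A keyword is emitted when the average acoustic probability over its tokens
    reaches the state's own threshold and enough trailing blanks have followed
    (the end-of-stream path in the diff skips the blank requirement)."""
    ac_prob = sum(ac_probs[-level:]) / level
    return ac_prob >= state_ac_threshold and num_tailing_blanks > required_tailing_blanks
```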