Commit more scripts for gigaspeech kws recipe

Author: pkufool
Date:   2024-02-01 16:27:16 +08:00
Parent: e257b44763
Commit: 2addc6cba6

12 changed files with 2773 additions and 107 deletions

View File

@@ -30,15 +30,15 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
 
-def compute_fbank_gigaspeech_dev_test():
+def compute_fbank_gigaspeech():
     in_out_dir = Path("data/fbank")
     # number of workers in dataloader
     num_workers = 20
 
     # number of seconds in a batch
-    batch_duration = 600
+    batch_duration = 1000
 
-    subsets = ("DEV", "TEST")
+    subsets = ("L", "M", "S", "XS", "DEV", "TEST")
 
     device = torch.device("cpu")
     if torch.cuda.is_available():
@@ -48,12 +48,12 @@ def compute_fbank_gigaspeech_dev_test():
     logging.info(f"device: {device}")
 
     for partition in subsets:
-        cuts_path = in_out_dir / f"cuts_{partition}.jsonl.gz"
+        cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}.jsonl.gz"
         if cuts_path.is_file():
            logging.info(f"{cuts_path} exists - skipping")
            continue
 
-        raw_cuts_path = in_out_dir / f"cuts_{partition}_raw.jsonl.gz"
+        raw_cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
         logging.info(f"Loading {raw_cuts_path}")
         cut_set = CutSet.from_file(raw_cuts_path)
@@ -62,7 +62,7 @@ def compute_fbank_gigaspeech_dev_test():
         cut_set = cut_set.compute_and_store_features_batch(
             extractor=extractor,
-            storage_path=f"{in_out_dir}/feats_{partition}",
+            storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}",
             num_workers=num_workers,
             batch_duration=batch_duration,
             overwrite=True,
@@ -80,7 +80,7 @@ def main():
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
-    compute_fbank_gigaspeech_dev_test()
+    compute_fbank_gigaspeech()
 
 
 if __name__ == "__main__":

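Editor's note: a minimal sketch (not part of the diff) of the feature-extraction flow the renamed script implements for a single subset, assuming lhotse with kaldifeat installed; the helper name compute_one is illustrative only.

    # Sketch: compute fbank for one renamed GigaSpeech manifest and write the
    # features next to it, mirroring the settings shown in the hunks above.
    import torch
    from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig


    def compute_one(partition: str = "DEV", in_out_dir: str = "data/fbank") -> None:
        device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
        extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))

        cut_set = CutSet.from_file(f"{in_out_dir}/gigaspeech_cuts_{partition}_raw.jsonl.gz")
        cut_set = cut_set.compute_and_store_features_batch(
            extractor=extractor,
            storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}",
            num_workers=20,
            batch_duration=1000,
            overwrite=True,
        )
        cut_set.to_file(f"{in_out_dir}/gigaspeech_cuts_{partition}.jsonl.gz")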
View File

@@ -99,12 +99,12 @@ def compute_fbank_gigaspeech_splits(args):
         idx = f"{i + 1}".zfill(num_digits)
         logging.info(f"Processing {idx}/{num_splits}")
 
-        cuts_path = output_dir / f"cuts_XL.{idx}.jsonl.gz"
+        cuts_path = output_dir / f"gigaspeech_cuts_XL.{idx}.jsonl.gz"
         if cuts_path.is_file():
             logging.info(f"{cuts_path} exists - skipping")
             continue
 
-        raw_cuts_path = output_dir / f"cuts_XL_raw.{idx}.jsonl.gz"
+        raw_cuts_path = output_dir / f"gigaspeech_cuts_XL_raw.{idx}.jsonl.gz"
         logging.info(f"Loading {raw_cuts_path}")
         cut_set = CutSet.from_file(raw_cuts_path)
@@ -113,7 +113,7 @@ def compute_fbank_gigaspeech_splits(args):
         cut_set = cut_set.compute_and_store_features_batch(
             extractor=extractor,
-            storage_path=f"{output_dir}/feats_XL_{idx}",
+            storage_path=f"{output_dir}/gigaspeech_feats_XL_{idx}",
             num_workers=args.num_workers,
             batch_duration=args.batch_duration,
             overwrite=True,

View File

@@ -16,17 +16,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import argparse
 import logging
 import re
 from pathlib import Path
 
 from lhotse import CutSet, SupervisionSegment
 from lhotse.recipes.utils import read_manifests_if_cached
 
+from icefall.utils import str2bool
+
 # Similar text filtering and normalization procedure as in:
 # https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh
 
 
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--perturb-speed",
+        type=str2bool,
+        default=False,
+        help="Whether to use speed perturbation.",
+    )
+
+    return parser.parse_args()
+
+
 def normalize_text(
     utt: str,
     punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
@@ -42,7 +56,7 @@ def has_no_oov(
     return oov_pattern.search(sup.text) is None
 
 
-def preprocess_giga_speech():
+def preprocess_giga_speech(args):
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
     output_dir.mkdir(exist_ok=True)
@@ -51,6 +65,10 @@ def preprocess_giga_speech():
         "DEV",
         "TEST",
         "XL",
+        "L",
+        "M",
+        "S",
+        "XS",
     )
 
     logging.info("Loading manifest (may take 4 minutes)")
@@ -71,7 +89,7 @@ def preprocess_giga_speech():
     for partition, m in manifests.items():
         logging.info(f"Processing {partition}")
-        raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
+        raw_cuts_path = output_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
         if raw_cuts_path.is_file():
             logging.info(f"{partition} already exists - skipping")
             continue
@@ -94,11 +112,14 @@ def preprocess_giga_speech():
         # Run data augmentation that needs to be done in the
         # time domain.
         if partition not in ["DEV", "TEST"]:
-            logging.info(
-                f"Speed perturb for {partition} with factors 0.9 and 1.1 "
-                "(Perturbing may take 8 minutes and saving may take 20 minutes)"
-            )
-            cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+            if args.perturb_speed:
+                logging.info(
+                    f"Speed perturb for {partition} with factors 0.9 and 1.1 "
+                    "(Perturbing may take 8 minutes and saving may take 20 minutes)"
+                )
+                cut_set = (
+                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+                )
 
         logging.info(f"Saving to {raw_cuts_path}")
         cut_set.to_file(raw_cuts_path)
@@ -107,7 +128,8 @@ def main():
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
 
-    preprocess_giga_speech()
+    args = get_args()
+    preprocess_giga_speech(args)
 
 
 if __name__ == "__main__":

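Editor's note: the preprocessing script now imports str2bool from icefall.utils for the new --perturb-speed flag. For readers without icefall on hand, a functionally equivalent sketch (an assumption, not icefall's actual source):

    import argparse


    def str2bool(v) -> bool:
        # Interpret common command-line spellings of booleans.
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")

Used as type=str2bool, it lets the recipe accept "--perturb-speed True" or "--perturb-speed false" interchangeably.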
View File

@@ -99,7 +99,14 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
     exit 1;
   fi
 
   # Download XL, DEV and TEST sets by default.
-  lhotse download gigaspeech --subset auto --host tsinghua \
+  lhotse download gigaspeech --subset XL \
+    --subset L \
+    --subset M \
+    --subset S \
+    --subset XS \
+    --subset DEV \
+    --subset TEST \
+    --host tsinghua \
     $dl_dir/password $dl_dir/GigaSpeech
 fi
@@ -118,7 +125,14 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   # We assume that you have downloaded the GigaSpeech corpus
   # to $dl_dir/GigaSpeech
   mkdir -p data/manifests
-  lhotse prepare gigaspeech --subset auto -j $nj \
+  lhotse prepare gigaspeech --subset XL \
+    --subset L \
+    --subset M \
+    --subset S \
+    --subset XS \
+    --subset DEV \
+    --subset TEST \
+    -j $nj \
     $dl_dir/GigaSpeech data/manifests
 fi
@@ -139,8 +153,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
 fi
 
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 4: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)"
-  python3 ./local/compute_fbank_gigaspeech_dev_test.py
+  log "Stage 4: Compute features for L, M, S, XS, DEV and TEST subsets of GigaSpeech."
+  python3 ./local/compute_fbank_gigaspeech.py
 fi
 
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
@@ -176,18 +190,9 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
 fi
 
 if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
-  log "Stage 9: Prepare phone based lang"
+  log "Stage 9: Prepare transcript_words.txt and words.txt"
   lang_dir=data/lang_phone
   mkdir -p $lang_dir
-
-  (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
-    cat - $dl_dir/lm/lexicon.txt |
-    sort | uniq > $lang_dir/lexicon.txt
-
-  if [ ! -f $lang_dir/L_disambig.pt ]; then
-    ./local/prepare_lang.py --lang-dir $lang_dir
-  fi
-
   if [ ! -f $lang_dir/transcript_words.txt ]; then
     gunzip -c "data/manifests/gigaspeech_supervisions_XL.jsonl.gz" \
       | jq '.text' \
@@ -238,7 +243,21 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
 fi
 
 if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
-  log "Stage 10: Prepare BPE based lang"
+  log "Stage 10: Prepare phone based lang"
+  lang_dir=data/lang_phone
+  mkdir -p $lang_dir
+
+  (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
+    cat - $dl_dir/lm/lexicon.txt |
+    sort | uniq > $lang_dir/lexicon.txt
+
+  if [ ! -f $lang_dir/L_disambig.pt ]; then
+    ./local/prepare_lang.py --lang-dir $lang_dir
+  fi
+fi
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+  log "Stage 11: Prepare BPE based lang"
 
   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}
@@ -260,8 +279,8 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
   done
 fi
 
-if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
-  log "Stage 11: Prepare bigram P"
+if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
+  log "Stage 12: Prepare bigram P"
 
   for vocab_size in ${vocab_sizes[@]}; do
     lang_dir=data/lang_bpe_${vocab_size}
@@ -291,8 +310,8 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
   done
 fi
 
-if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
-  log "Stage 12: Prepare G"
+if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
+  log "Stage 13: Prepare G"
 
   # We assume you have installed kaldilm, if not, please install
   # it using: pip install kaldilm
@@ -317,8 +336,8 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
   fi
 fi
 
-if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
-  log "Stage 13: Compile HLG"
+if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
+  log "Stage 14: Compile HLG"
   ./local/compile_hlg.py --lang-dir data/lang_phone
 
   for vocab_size in ${vocab_sizes[@]}; do

View File

@@ -105,7 +105,7 @@ class GigaSpeechAsrDataModule:
         group.add_argument(
             "--num-buckets",
             type=int,
-            default=30,
+            default=100,
             help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
@@ -312,8 +312,8 @@ class GigaSpeechAsrDataModule:
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
                 drop_last=self.args.drop_last,
-                buffer_size=self.args.num_buckets * 2000,
-                shuffle_buffer_size=self.args.num_buckets * 5000,
+                buffer_size=self.args.num_buckets * 1000,
+                shuffle_buffer_size=self.args.num_buckets * 3000,
             )
         else:
             logging.info("Using SimpleCutSampler.")

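Editor's note: a sketch (not from the commit) of how the changed sampler defaults fit together, assuming lhotse's DynamicBucketingSampler; make_train_sampler is an illustrative name.

    from lhotse import CutSet
    from lhotse.dataset import DynamicBucketingSampler


    def make_train_sampler(cuts: CutSet, num_buckets: int = 100, max_duration: float = 600.0):
        # More buckets give tighter duration grouping (less padding per batch);
        # both buffers are scaled with the bucket count so every bucket still
        # receives enough cuts to shuffle and fill batches from.
        return DynamicBucketingSampler(
            cuts,
            max_duration=max_duration,
            shuffle=True,
            num_buckets=num_buckets,
            drop_last=True,
            buffer_size=num_buckets * 1000,
            shuffle_buffer_size=num_buckets * 3000,
        )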
View File

@@ -447,3 +447,38 @@ class GigaSpeechAsrDataModule:
         return load_manifest_lazy(
             self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
         )
+
+    @lru_cache()
+    def libri_100_cuts(self) -> CutSet:
+        logging.info("About to get libri100 cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
+        )
+
+    @lru_cache()
+    def fsc_train_cuts(self) -> CutSet:
+        logging.info("About to get fluent speech commands train cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "fluent_speech_commands_cuts_train.jsonl.gz"
+        )
+
+    @lru_cache()
+    def fsc_valid_cuts(self) -> CutSet:
+        logging.info("About to get fluent speech commands valid cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "fluent_speech_commands_cuts_valid.jsonl.gz"
+        )
+
+    @lru_cache()
+    def fsc_test_small_cuts(self) -> CutSet:
+        logging.info("About to get fluent speech commands small test cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "fluent_speech_commands_cuts_small.jsonl.gz"
+        )
+
+    @lru_cache()
+    def fsc_test_large_cuts(self) -> CutSet:
+        logging.info("About to get fluent speech commands large test cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "fluent_speech_commands_cuts_large.jsonl.gz"
+        )

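Editor's note: the new getters all follow the same cached, lazily-opened-manifest pattern. A condensed sketch of that pattern (class and constructor here are illustrative, not the recipe's):

    from functools import lru_cache
    from pathlib import Path

    from lhotse import CutSet, load_manifest_lazy


    class KwsTestCuts:
        def __init__(self, manifest_dir: Path) -> None:
            self.manifest_dir = manifest_dir

        @lru_cache()
        def fsc_test_small_cuts(self) -> CutSet:
            # Opened lazily; repeated calls return the same cached CutSet.
            return load_manifest_lazy(
                self.manifest_dir / "fluent_speech_commands_cuts_small.jsonl.gz"
            )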
File diff suppressed because it is too large.

View File

@@ -24,11 +24,10 @@ Usage:
     --avg 15 \
     --exp-dir ./zipformer/exp \
     --max-duration 600 \
-    --decoding-method modified_beam_search \
+    --keywords-file keywords.txt \
     --beam-size 4
 """
 
 import argparse
 import logging
 import math
@@ -163,10 +162,17 @@ def get_parser():
         help="File contains keywords.",
     )
 
+    parser.add_argument(
+        "--test-set",
+        type=str,
+        default="small",
+        help="small or large",
+    )
+
     parser.add_argument(
         "--keywords-score",
         type=float,
-        default=3.0,
+        default=1.5,
         help="""
         The default boosting score (token level) for keywords. it will boost the
         paths that match keywords to make them survive beam search.
@@ -176,14 +182,21 @@
     parser.add_argument(
         "--keywords-threshold",
         type=float,
-        default=0.75,
+        default=0.35,
         help="The default threshold (probability) to trigger the keyword.",
     )
 
+    parser.add_argument(
+        "--keywords-version",
+        type=str,
+        default="",
+        help="The keywords configuration version, just to save results to different files.",
+    )
+
     parser.add_argument(
         "--num-tailing-blanks",
         type=int,
-        default=8,
+        default=1,
         help="The number of tailing blanks should have after hitting one keyword.",
     )
@@ -261,7 +274,7 @@
         model=model,
         encoder_out=encoder_out,
         encoder_out_lens=encoder_out_lens,
-        kws_graph=kws_graph,
+        context_graph=kws_graph,
         beam=params.beam,
         num_tailing_blanks=params.num_tailing_blanks,
         blank_penalty=params.blank_penalty,
@@ -284,6 +297,7 @@
     sp: spm.SentencePieceProcessor,
     kws_graph: ContextGraph,
     keywords: Set[str],
+    test_only_keywords: bool,
 ) -> Tuple[List[Tuple[str, List[str], List[str]]], KwMetric]:
     """Decode dataset.
@@ -337,34 +351,65 @@
             ref_text = ref_text.upper()
             ref_words = ref_text.split()
             hyp_words = [x[0] for x in hyp_words]
+            # for computing WER
             this_batch.append((cut_id, ref_words, " ".join(hyp_words).split()))
-            hyp_set = set(hyp_words)
-            hyp_str = " | ".join(hyp_words)
+            hyp_set = set(hyp_words)  # each item is a keyword phrase
+            if len(hyp_words) > 1:
+                logging.warning(
+                    f"Cut {cut_id} triggers more than one keywords : {hyp_words},"
+                    f"please check the transcript to see if it really has more "
+                    f"than one keywords, if so consider splitting this audio and"
+                    f"keep only one keyword for each audio."
+                )
+            hyp_str = " | ".join(
+                hyp_words
+            )  # The triggered keywords for this utterance.
+            TP = False
+            FP = False
             for x in hyp_set:
-                assert x in keywords, x
-                if x in ref_text and x in keywords:
-                    metric["all"].TP += 1
+                assert x in keywords, x  # can only trigger keywords
+                if (test_only_keywords and x == ref_text) or (
+                    not test_only_keywords and x in ref_text
+                ):
+                    TP = True
                     metric[x].TP += 1
                     metric[x].TP_list.append(f"({ref_text} -> {x})")
-                if x not in ref_text and x in keywords:
-                    metric["all"].FP += 1
+                if (test_only_keywords and x != ref_text) or (
+                    not test_only_keywords and x not in ref_text
+                ):
+                    FP = True
                     metric[x].FP += 1
                     metric[x].FP_list.append(f"({ref_text} -> {x})")
+            if TP:
+                metric["all"].TP += 1
+            if FP:
+                metric["all"].FP += 1
+            TN = True  # all keywords are true negative then the summery is true negative.
+            FN = False
             for x in keywords:
                 if x not in ref_text and x not in hyp_set:
-                    metric["all"].TN += 1
                     metric[x].TN += 1
-                if x in ref_text:
+                    continue
+
+                TN = False
+                if (test_only_keywords and x == ref_text) or (
+                    not test_only_keywords and x in ref_text
+                ):
                     fn = True
                     for y in hyp_set:
-                        if y in ref_text:
+                        if (test_only_keywords and y == ref_text) or (
+                            not test_only_keywords and y in ref_text
+                        ):
                             fn = False
                             break
-                    if fn and ref_text.endswith(x):
-                        metric["all"].FN += 1
+                    if fn:
+                        FN = True
                         metric[x].FN += 1
                         metric[x].FN_list.append(f"({ref_text} -> {hyp_str})")
+            if TN:
+                metric["all"].TN += 1
+            if FN:
+                metric["all"].FN += 1
 
         results.extend(this_batch)
@@ -396,16 +441,17 @@
     metric_filename = params.res_dir / f"metric-{test_set_name}-{params.suffix}.txt"
 
-    print_s = ""
     with open(metric_filename, "w") as of:
         width = 10
         for key, item in sorted(
             metric.items(), key=lambda x: (x[1].FP, x[1].FN), reverse=True
         ):
             acc = (item.TP + item.TN) / (item.TP + item.TN + item.FP + item.FN)
-            precision = (item.TP + 1) / (item.TP + item.FP + 1)
-            recall = (item.TP + 1) / (item.TP + item.FN + 1)
-            fpr = (item.FP + 1) / (item.FP + item.TN + 1)
+            precision = (
+                0.0 if (item.TP + item.FP) == 0 else item.TP / (item.TP + item.FP)
+            )
+            recall = 0.0 if (item.TP + item.FN) == 0 else item.TP / (item.TP + item.FN)
+            fpr = 0.0 if (item.FP + item.TN) == 0 else item.FP / (item.FP + item.TN)
             s = f"{key}:\n"
             s += f"\t{'TP':{width}}{'FP':{width}}{'FN':{width}}{'TN':{width}}\n"
             s += f"\t{str(item.TP):{width}}{str(item.FP):{width}}{str(item.FN):{width}}{str(item.TN):{width}}\n"
@@ -414,12 +460,14 @@
             s += f"\tRecall(PPR): {recall:.3f}\n"
             s += f"\tFPR: {fpr:.3f}\n"
             s += f"\tF1: {2 * precision * recall / (precision + recall):.3f}\n"
-            s += f"\tTP list: {' # '.join(item.TP_list)}\n"
-            s += f"\tFP list: {' # '.join(item.FP_list)}\n"
-            s += f"\tFN list: {' # '.join(item.FN_list)}\n"
+            if key != "all":
+                s += f"\tTP list: {' # '.join(item.TP_list)}\n"
+                s += f"\tFP list: {' # '.join(item.FP_list)}\n"
+                s += f"\tFN list: {' # '.join(item.FN_list)}\n"
             of.write(s + "\n")
             if key == "all":
                 logging.info(s)
+        of.write(f"\n\n{params.keywords_config}")
 
     logging.info("Wrote metric stats to {}".format(metric_filename))
@@ -436,10 +484,11 @@
     params.res_dir = params.exp_dir / "kws"
 
+    params.suffix = params.test_set
     if params.iter > 0:
-        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+        params.suffix += f"-iter-{params.iter}-avg-{params.avg}"
     else:
-        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+        params.suffix += f"-epoch-{params.epoch}-avg-{params.avg}"
 
     if params.causal:
         assert (
@@ -456,6 +505,7 @@
         params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}"
     if params.blank_penalty != 0:
         params.suffix += f"-blank-penalty-{params.blank_penalty}"
+    params.suffix += f"-version-{params.keywords_version}"
 
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
@@ -480,8 +530,10 @@
     token_ids = []
     keywords_scores = []
     keywords_thresholds = []
+    keywords_config = []
     with open(params.keywords_file, "r") as f:
         for line in f.readlines():
+            keywords_config.append(line)
             score = 0
             threshold = 0
             keyword = []
@@ -501,6 +553,8 @@
             keywords_scores.append(score)
             keywords_thresholds.append(threshold)
 
+    params.keywords_config = "".join(keywords_config)
+
     kws_graph = ContextGraph(
         context_score=params.keywords_score, ac_threshold=params.keywords_threshold
     )
@@ -605,24 +659,17 @@
     test_cuts = gigaspeech.test_cuts()
     test_dl = gigaspeech.test_dataloaders(test_cuts)
 
-    def select_keyword_cuts(c: Cut):
-        text = c.supervisions[0].text
-        text = text.strip().upper()
-        return text in keywords
-
-    test_sc1_cuts = gigaspeech.test_speechcommands1_cuts()
-    test_sc2_cuts = gigaspeech.test_speechcommands2_cuts()
-
-    test_fsc_cuts = gigaspeech.test_fluent_speechcommands_cuts()
-    test_fsc_cuts = test_fsc_cuts.filter(select_keyword_cuts)
-
-    test_sc1_dl = gigaspeech.test_dataloaders(test_sc1_cuts)
-    test_sc2_dl = gigaspeech.test_dataloaders(test_sc2_cuts)
-    test_fsc_dl = speechcommand.test_dataloaders(test_fsc_cuts)
-
-    test_sets = ["test-fsc", "test", "test-sc1", "test-sc2"]
-    test_dls = [test_fsc_dl, test_dl, test_sc1_dl, test_sc2_dl]
+    if params.test_set == "small":
+        test_fsc_small_cuts = gigaspeech.fsc_test_small_cuts()
+        test_fsc_small_dl = gigaspeech.test_dataloaders(test_fsc_small_cuts)
+        test_sets = ["small-fsc", "test"]
+        test_dls = [test_fsc_small_dl, test_dl]
+    else:
+        assert params.test_set == "large", params.test_set
+        test_fsc_large_cuts = gigaspeech.fsc_test_large_cuts()
+        test_fsc_large_dl = gigaspeech.test_dataloaders(test_fsc_large_cuts)
+        test_sets = ["large-fsc", "test"]
+        test_dls = [test_fsc_large_dl, test_dl]
 
     for test_set, test_dl in zip(test_sets, test_dls):
         results, metric = decode_dataset(
@@ -632,6 +679,7 @@
             sp=sp,
             kws_graph=kws_graph,
             keywords=keywords,
+            test_only_keywords="fsc" in test_set,
         )
 
         save_results(

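Editor's note: a self-contained sketch (not from the commit) of the revised scoring logic above: per-utterance TP/FP/TN/FN flags feed the "all" row, per-keyword counters feed the rest, and the rates are guarded against division by zero. Helper names are illustrative and the keyword matching is slightly condensed.

    from collections import defaultdict
    from dataclasses import dataclass
    from typing import Dict, Set


    @dataclass
    class KwMetric:
        TP: int = 0
        FP: int = 0
        TN: int = 0
        FN: int = 0


    def score_utterance(
        metric: Dict[str, KwMetric],
        keywords: Set[str],
        ref_text: str,
        triggered: Set[str],
        test_only_keywords: bool,
    ) -> None:
        def hit(x: str) -> bool:
            # Keyword-only test sets require the whole reference to equal the
            # keyword; otherwise a substring match counts.
            return x == ref_text if test_only_keywords else x in ref_text

        tp = fp = fn = False
        tn = True
        for x in triggered:
            if hit(x):
                tp = True
                metric[x].TP += 1
            else:
                fp = True
                metric[x].FP += 1
        for x in keywords:
            if not hit(x) and x not in triggered:
                metric[x].TN += 1
                continue
            tn = False
            if hit(x) and not any(hit(y) for y in triggered):
                fn = True
                metric[x].FN += 1
        metric["all"].TP += tp
        metric["all"].FP += fp
        metric["all"].TN += tn
        metric["all"].FN += fn


    def rates(m: KwMetric):
        precision = 0.0 if (m.TP + m.FP) == 0 else m.TP / (m.TP + m.FP)
        recall = 0.0 if (m.TP + m.FN) == 0 else m.TP / (m.TP + m.FN)
        fpr = 0.0 if (m.FP + m.TN) == 0 else m.FP / (m.FP + m.TN)
        return precision, recall, fpr


    if __name__ == "__main__":
        metric = defaultdict(KwMetric)
        score_utterance(metric, {"HEY SIRI"}, "HEY SIRI", {"HEY SIRI"}, test_only_keywords=True)
        print(metric["HEY SIRI"], rates(metric["all"]))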
File diff suppressed because it is too large.

View File

@@ -0,0 +1 @@
+../../ASR/zipformer/gigaspeech_scoring.py

View File

@@ -126,7 +126,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
         type=str,
-        default="2,2,3,4,3,2",
+        default="1,1,1,1,1,1",
         help="Number of zipformer encoder layers per stack, comma separated.",
     )
@@ -140,7 +140,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--feedforward-dim",
         type=str,
-        default="512,768,1024,1536,1024,768",
+        default="192,192,192,192,192,192",
         help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
     )
@@ -154,7 +154,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--encoder-dim",
         type=str,
-        default="192,256,384,512,384,256",
+        default="128,128,128,128,128,128",
         help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
     )
@@ -189,7 +189,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--encoder-unmasked-dim",
         type=str,
-        default="192,192,256,256,256,192",
+        default="128,128,128,128,128,128",
         help="Unmasked dimensions in the encoders, relates to augmentation during training. "
         "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
     )
@@ -205,14 +205,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--decoder-dim",
         type=int,
-        default=512,
+        default=320,
         help="Embedding dimension in the decoder model.",
     )
 
     parser.add_argument(
         "--joiner-dim",
         type=int,
-        default=512,
+        default=320,
         help="""Dimension used in the joiner model.
         Outputs from the encoder and decoder model are projected
         to this dimension before adding.
@@ -222,7 +222,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--causal",
         type=str2bool,
-        default=False,
+        default=True,
         help="If True, use causal version of model.",
     )
@@ -416,6 +416,17 @@ def get_parser():
         help="Accumulate stats on activations, print them and exit.",
     )
 
+    parser.add_argument(
+        "--scan-for-oom-batches",
+        type=str2bool,
+        default=False,
+        help="""
+        Whether to scan for oom batches before training, this is helpful for
+        finding the suitable max_duration, you only need to run it once.
+        Caution: a little time consuming.
+        """,
+    )
+
     parser.add_argument(
         "--inf-check",
         type=str2bool,
@@ -463,7 +474,7 @@ def get_parser():
     parser.add_argument(
         "--use-fp16",
         type=str2bool,
-        default=False,
+        default=True,
         help="Whether to use half precision training.",
     )
@@ -1197,14 +1208,14 @@ def run(rank, world_size, args):
     valid_cuts = valid_cuts.filter(remove_short_utt)
     valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
 
-    # if not params.print_diagnostics:
-    #     scan_pessimistic_batches_for_oom(
-    #         model=model,
-    #         train_dl=train_dl,
-    #         optimizer=optimizer,
-    #         sp=sp,
-    #         params=params,
-    #     )
+    if not params.print_diagnostics and params.scan_for_oom_batches:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:

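Editor's note: taken together, the changed training defaults describe a much smaller, streaming-friendly zipformer for KWS (one layer per stack, 128-dim encoders, 320-dim decoder/joiner, causal, fp16). A hedged sketch that merely collects those values for a launch helper; the dictionary and function are illustrative, not part of the recipe.

    # Values copied from the defaults changed in the hunks above.
    KWS_ZIPFORMER_DEFAULTS = {
        "num-encoder-layers": "1,1,1,1,1,1",
        "feedforward-dim": "192,192,192,192,192,192",
        "encoder-dim": "128,128,128,128,128,128",
        "encoder-unmasked-dim": "128,128,128,128,128,128",
        "decoder-dim": 320,
        "joiner-dim": 320,
        "causal": True,
        "use-fp16": True,
        "scan-for-oom-batches": False,
    }


    def to_cli_args(defaults: dict) -> list:
        # Flatten the mapping into "--key value" pairs for the training script.
        args = []
        for key, value in defaults.items():
            args += [f"--{key}", str(value)]
        return args


    print(" ".join(to_cli_args(KWS_ZIPFORMER_DEFAULTS)))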
View File

@@ -966,7 +966,6 @@ def keywords_search(
     encoder_out_lens: torch.Tensor,
     context_graph: ContextGraph,
     beam: int = 4,
-    ac_threshold: float = 0.15,
     num_tailing_blanks: int = 8,
     blank_penalty: float = 0,
 ) -> List[List[KeywordResult]]:
@@ -1077,6 +1076,8 @@ def keywords_search(
 
         log_probs = probs.log()
 
+        probs = probs.reshape(-1)
+
         log_probs.add_(ys_log_probs)
 
         vocab_size = log_probs.size(-1)
@@ -1112,7 +1113,7 @@
                 if new_token not in (blank_id, unk_id):
                     new_ys.append(new_token)
                     new_timestamp.append(t)
-                    new_ac_probs.append(math.exp(hyp_probs[topk_indexes[k]]))
+                    new_ac_probs.append(hyp_probs[topk_indexes[k]])
                     (
                         context_score,
                         new_context_state,
@@ -1140,10 +1141,13 @@
             ac_prob = (
                 sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
             )
+            # logging.info(
+            #     f"ac prob : {ac_prob}, threshold : {matched_state.ac_threshold}"
+            # )
             if (
                 matched
                 and top_hyp.num_tailing_blanks > num_tailing_blanks
-                and ac_prob >= ac_threshold
+                and ac_prob >= matched_state.ac_threshold
             ):
                 keyword = KeywordResult(
                     hyps=top_hyp.ys[-matched_state.level :],
@@ -1171,7 +1175,7 @@
         ac_prob = (
             sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
         )
-        if matched and ac_prob >= ac_threshold:
+        if matched and ac_prob >= matched_state.ac_threshold:
             keyword = KeywordResult(
                 hyps=top_hyp.ys[-matched_state.level :],
                 timestamps=top_hyp.timestamp[-matched_state.level :],
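
Editor's note: the last hunks move the acoustic threshold from a single function argument onto the matched ContextGraph state, so each keyword can carry its own threshold. A minimal sketch of the resulting trigger test (names and the per-token probability bookkeeping are assumptions, not the library's API):

    from typing import Sequence


    def keyword_triggered(
        ac_probs: Sequence[float],
        level: int,
        state_ac_threshold: float,
        num_tailing_blanks_seen: int,
        num_tailing_blanks_required: int = 0,
    ) -> bool:
        # Average the raw per-token acoustic probabilities over the matched span
        # and require enough trailing blanks before declaring a hit.
        ac_prob = sum(ac_probs[-level:]) / level
        return (
            num_tailing_blanks_seen > num_tailing_blanks_required
            and ac_prob >= state_ac_threshold
        )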