Commit more scripts for gigaspeech kws recipe

This commit is contained in:
pkufool 2024-02-01 16:27:16 +08:00
parent e257b44763
commit 2addc6cba6
12 changed files with 2773 additions and 107 deletions

View File

@ -30,15 +30,15 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_gigaspeech_dev_test():
def compute_fbank_gigaspeech():
in_out_dir = Path("data/fbank")
# number of workers in dataloader
num_workers = 20
# number of seconds in a batch
batch_duration = 600
batch_duration = 1000
subsets = ("DEV", "TEST")
subsets = ("L", "M", "S", "XS", "DEV", "TEST")
device = torch.device("cpu")
if torch.cuda.is_available():
@ -48,12 +48,12 @@ def compute_fbank_gigaspeech_dev_test():
logging.info(f"device: {device}")
for partition in subsets:
cuts_path = in_out_dir / f"cuts_{partition}.jsonl.gz"
cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
raw_cuts_path = in_out_dir / f"cuts_{partition}_raw.jsonl.gz"
raw_cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
@ -62,7 +62,7 @@ def compute_fbank_gigaspeech_dev_test():
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{in_out_dir}/feats_{partition}",
storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}",
num_workers=num_workers,
batch_duration=batch_duration,
overwrite=True,
@ -80,7 +80,7 @@ def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_gigaspeech_dev_test()
compute_fbank_gigaspeech()
if __name__ == "__main__":
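
For readers who have not used the lhotse feature pipeline these scripts build on, here is a minimal sketch of how the pieces in the hunks above typically fit together. The extractor construction (lhotse's KaldifeatFbank wrapper, 80-dim fbank by default) is an assumption and is not part of this diff; the manifest paths and the compute_and_store_features_batch() arguments mirror the hunks.

from pathlib import Path

import torch
from lhotse import CutSet
from lhotse.features.kaldifeat import KaldifeatFbank, KaldifeatFbankConfig


def extract_one_partition(partition: str, in_out_dir: Path = Path("data/fbank")) -> None:
    # The extractor setup is assumed; paths and batch arguments follow the diff.
    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))

    cut_set = CutSet.from_file(in_out_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz")
    cut_set = cut_set.compute_and_store_features_batch(
        extractor=extractor,
        storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}",
        num_workers=20,       # dataloader workers, as in the diff
        batch_duration=1000,  # seconds of audio per batch, as in the diff
        overwrite=True,
    )
    cut_set.to_file(in_out_dir / f"gigaspeech_cuts_{partition}.jsonl.gz")


# e.g. extract_one_partition("DEV")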

View File

@ -99,12 +99,12 @@ def compute_fbank_gigaspeech_splits(args):
idx = f"{i + 1}".zfill(num_digits)
logging.info(f"Processing {idx}/{num_splits}")
cuts_path = output_dir / f"cuts_XL.{idx}.jsonl.gz"
cuts_path = output_dir / f"gigaspeech_cuts_XL.{idx}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
raw_cuts_path = output_dir / f"cuts_XL_raw.{idx}.jsonl.gz"
raw_cuts_path = output_dir / f"gigaspeech_cuts_XL_raw.{idx}.jsonl.gz"
logging.info(f"Loading {raw_cuts_path}")
cut_set = CutSet.from_file(raw_cuts_path)
@ -113,7 +113,7 @@ def compute_fbank_gigaspeech_splits(args):
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/feats_XL_{idx}",
storage_path=f"{output_dir}/gigaspeech_feats_XL_{idx}",
num_workers=args.num_workers,
batch_duration=args.batch_duration,
overwrite=True,

View File

@ -16,17 +16,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import re
from pathlib import Path
from lhotse import CutSet, SupervisionSegment
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import str2bool
# Similar text filtering and normalization procedure as in:
# https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=False,
help="Whether to use speed perturbation.",
)
return parser.parse_args()
def normalize_text(
utt: str,
punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
@ -42,7 +56,7 @@ def has_no_oov(
return oov_pattern.search(sup.text) is None
def preprocess_giga_speech():
def preprocess_giga_speech(args):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
output_dir.mkdir(exist_ok=True)
@ -51,6 +65,10 @@ def preprocess_giga_speech():
"DEV",
"TEST",
"XL",
"L",
"M",
"S",
"XS",
)
logging.info("Loading manifest (may take 4 minutes)")
@ -71,7 +89,7 @@ def preprocess_giga_speech():
for partition, m in manifests.items():
logging.info(f"Processing {partition}")
raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
raw_cuts_path = output_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
if raw_cuts_path.is_file():
logging.info(f"{partition} already exists - skipping")
continue
@ -94,11 +112,14 @@ def preprocess_giga_speech():
# Run data augmentation that needs to be done in the
# time domain.
if partition not in ["DEV", "TEST"]:
logging.info(
f"Speed perturb for {partition} with factors 0.9 and 1.1 "
"(Perturbing may take 8 minutes and saving may take 20 minutes)"
)
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
if args.perturb_speed:
logging.info(
f"Speed perturb for {partition} with factors 0.9 and 1.1 "
"(Perturbing may take 8 minutes and saving may take 20 minutes)"
)
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
logging.info(f"Saving to {raw_cuts_path}")
cut_set.to_file(raw_cuts_path)
@ -107,7 +128,8 @@ def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
preprocess_giga_speech()
args = get_args()
preprocess_giga_speech(args)
if __name__ == "__main__":
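
The diff only shows the signature of normalize_text() and its GigaSpeech punctuation-tag pattern, not its body. The sketch below illustrates the kind of normalization such a pattern implies (tag removal plus whitespace collapsing); it is an assumption, not the recipe's exact implementation.

import re

PUNCT_PATTERN = re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>")
WHITESPACE_PATTERN = re.compile(r"\s\s+")


def normalize_text_sketch(utt: str) -> str:
    # Strip GigaSpeech punctuation tags, then collapse runs of whitespace.
    return WHITESPACE_PATTERN.sub(" ", PUNCT_PATTERN.sub("", utt)).strip()


print(normalize_text_sketch("AS SOON AS POSSIBLE <COMMA> PLEASE <PERIOD>"))
# -> "AS SOON AS POSSIBLE PLEASE"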

View File

@ -99,7 +99,14 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
exit 1;
fi
# Download XL, DEV and TEST sets by default.
lhotse download gigaspeech --subset auto --host tsinghua \
lhotse download gigaspeech --subset XL \
--subset L \
--subset M \
--subset S \
--subset XS \
--subset DEV \
--subset TEST \
--host tsinghua \
$dl_dir/password $dl_dir/GigaSpeech
fi
@ -118,7 +125,14 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
# We assume that you have downloaded the GigaSpeech corpus
# to $dl_dir/GigaSpeech
mkdir -p data/manifests
lhotse prepare gigaspeech --subset auto -j $nj \
lhotse prepare gigaspeech --subset XL \
--subset L \
--subset M \
--subset S \
--subset XS \
--subset DEV \
--subset TEST \
-j $nj \
$dl_dir/GigaSpeech data/manifests
fi
@ -139,8 +153,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)"
python3 ./local/compute_fbank_gigaspeech_dev_test.py
log "Stage 4: Compute features for L, M, S, XS, DEV and TEST subsets of GigaSpeech."
python3 ./local/compute_fbank_gigaspeech.py
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
@ -176,18 +190,9 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Prepare phone based lang"
log "Stage 9: Prepare transcript_words.txt and words.txt"
lang_dir=data/lang_phone
mkdir -p $lang_dir
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - $dl_dir/lm/lexicon.txt |
sort | uniq > $lang_dir/lexicon.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir
fi
if [ ! -f $lang_dir/transcript_words.txt ]; then
gunzip -c "data/manifests/gigaspeech_supervisions_XL.jsonl.gz" \
| jq '.text' \
@ -238,7 +243,21 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
fi
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Prepare BPE based lang"
log "Stage 10: Prepare phone based lang"
lang_dir=data/lang_phone
mkdir -p $lang_dir
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - $dl_dir/lm/lexicon.txt |
sort | uniq > $lang_dir/lexicon.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir
fi
fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Prepare BPE based lang"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
@ -260,8 +279,8 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
done
fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Prepare bigram P"
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 12: Prepare bigram P"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
@ -291,8 +310,8 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
done
fi
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 12: Prepare G"
if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
log "Stage 13: Prepare G"
# We assume you have installed kaldilm, if not, please install
# it using: pip install kaldilm
@ -317,8 +336,8 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
fi
fi
if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
log "Stage 13: Compile HLG"
if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
log "Stage 14: Compile HLG"
./local/compile_hlg.py --lang-dir data/lang_phone
for vocab_size in ${vocab_sizes[@]}; do

View File

@ -105,7 +105,7 @@ class GigaSpeechAsrDataModule:
group.add_argument(
"--num-buckets",
type=int,
default=30,
default=100,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
@ -312,8 +312,8 @@ class GigaSpeechAsrDataModule:
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
buffer_size=self.args.num_buckets * 1000,
shuffle_buffer_size=self.args.num_buckets * 3000,
)
else:
logging.info("Using SimpleCutSampler.")

View File

@ -447,3 +447,38 @@ class GigaSpeechAsrDataModule:
return load_manifest_lazy(
self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
)
@lru_cache()
def libri_100_cuts(self) -> CutSet:
logging.info("About to get libri100 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
)
@lru_cache()
def fsc_train_cuts(self) -> CutSet:
logging.info("About to get fluent speech commands train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "fluent_speech_commands_cuts_train.jsonl.gz"
)
@lru_cache()
def fsc_valid_cuts(self) -> CutSet:
logging.info("About to get fluent speech commands valid cuts")
return load_manifest_lazy(
self.args.manifest_dir / "fluent_speech_commands_cuts_valid.jsonl.gz"
)
@lru_cache()
def fsc_test_small_cuts(self) -> CutSet:
logging.info("About to get fluent speech commands small test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "fluent_speech_commands_cuts_small.jsonl.gz"
)
@lru_cache()
def fsc_test_large_cuts(self) -> CutSet:
logging.info("About to get fluent speech commands large test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "fluent_speech_commands_cuts_large.jsonl.gz"
)
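
Each of these helpers lazily opens a gzipped JSONL cuts manifest, and @lru_cache() lets repeated calls reuse the already-opened manifest. Purely as an illustration of what such a manifest holds (the path follows the diff; whether the file exists depends on your preparation stages):

from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/fbank/fluent_speech_commands_cuts_small.jsonl.gz")
for cut in cuts:
    # Each cut carries audio/feature references plus its supervision text,
    # which for these FSC test sets is treated as the keyword phrase to spot.
    print(cut.id, cut.duration, cut.supervisions[0].text)
    break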

File diff suppressed because it is too large.

View File

@ -24,11 +24,10 @@ Usage:
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--keywords-file keywords.txt \
--beam-size 4
"""
import argparse
import logging
import math
@ -163,10 +162,17 @@ def get_parser():
help="File contains keywords.",
)
parser.add_argument(
"--test-set",
type=str,
default="small",
help="small or large",
)
parser.add_argument(
"--keywords-score",
type=float,
default=3.0,
default=1.5,
help="""
The default boosting score (token level) for keywords. It will boost the
paths that match keywords to make them survive beam search.
@ -176,14 +182,21 @@ def get_parser():
parser.add_argument(
"--keywords-threshold",
type=float,
default=0.75,
default=0.35,
help="The default threshold (probability) to trigger the keyword.",
)
parser.add_argument(
"--keywords-version",
type=str,
default="",
help="The keywords configuration version, just to save results to different files.",
)
parser.add_argument(
"--num-tailing-blanks",
type=int,
default=8,
default=1,
help="The number of tailing blanks should have after hitting one keyword.",
)
@ -261,7 +274,7 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
kws_graph=kws_graph,
context_graph=kws_graph,
beam=params.beam,
num_tailing_blanks=params.num_tailing_blanks,
blank_penalty=params.blank_penalty,
@ -284,6 +297,7 @@ def decode_dataset(
sp: spm.SentencePieceProcessor,
kws_graph: ContextGraph,
keywords: Set[str],
test_only_keywords: bool,
) -> Tuple[List[Tuple[str, List[str], List[str]]], KwMetric]:
"""Decode dataset.
@ -337,34 +351,65 @@ def decode_dataset(
ref_text = ref_text.upper()
ref_words = ref_text.split()
hyp_words = [x[0] for x in hyp_words]
# for computing WER
this_batch.append((cut_id, ref_words, " ".join(hyp_words).split()))
hyp_set = set(hyp_words)
hyp_str = " | ".join(hyp_words)
hyp_set = set(hyp_words) # each item is a keyword phrase
if len(hyp_words) > 1:
logging.warning(
f"Cut {cut_id} triggers more than one keywords : {hyp_words},"
f"please check the transcript to see if it really has more "
f"than one keywords, if so consider splitting this audio and"
f"keep only one keyword for each audio."
)
hyp_str = " | ".join(
hyp_words
) # The triggered keywords for this utterance.
TP = False
FP = False
for x in hyp_set:
assert x in keywords, x
if x in ref_text and x in keywords:
metric["all"].TP += 1
assert x in keywords, x # can only trigger keywords
if (test_only_keywords and x == ref_text) or (
not test_only_keywords and x in ref_text
):
TP = True
metric[x].TP += 1
metric[x].TP_list.append(f"({ref_text} -> {x})")
if x not in ref_text and x in keywords:
metric["all"].FP += 1
if (test_only_keywords and x != ref_text) or (
not test_only_keywords and x not in ref_text
):
FP = True
metric[x].FP += 1
metric[x].FP_list.append(f"({ref_text} -> {x})")
if TP:
metric["all"].TP += 1
if FP:
metric["all"].FP += 1
TN = True  # The summary is a true negative only if all keywords are true negatives.
FN = False
for x in keywords:
if x not in ref_text and x not in hyp_set:
metric["all"].TN += 1
metric[x].TN += 1
continue
if x in ref_text:
TN = False
if (test_only_keywords and x == ref_text) or (
not test_only_keywords and x in ref_text
):
fn = True
for y in hyp_set:
if y in ref_text:
if (test_only_keywords and y == ref_text) or (
not test_only_keywords and y in ref_text
):
fn = False
break
if fn and ref_text.endswith(x):
metric["all"].FN += 1
if fn:
FN = True
metric[x].FN += 1
metric[x].FN_list.append(f"({ref_text} -> {hyp_str})")
if TN:
metric["all"].TN += 1
if FN:
metric["all"].FN += 1
results.extend(this_batch)
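
The TP/FP/FN bookkeeping above applies two different hit criteria depending on test_only_keywords. Condensed into a single helper, purely as a paraphrase of the conditions in this hunk (not a function that exists in the recipe):

def keyword_hits_reference(hyp: str, ref_text: str, test_only_keywords: bool) -> bool:
    # Keyword-only test sets (the FSC sets): the whole reference must equal the keyword.
    # Open-domain test sets (GigaSpeech TEST): the keyword only has to occur
    # somewhere in the reference text.
    return hyp == ref_text if test_only_keywords else hyp in ref_text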
@ -396,16 +441,17 @@ def save_results(
metric_filename = params.res_dir / f"metric-{test_set_name}-{params.suffix}.txt"
print_s = ""
with open(metric_filename, "w") as of:
width = 10
for key, item in sorted(
metric.items(), key=lambda x: (x[1].FP, x[1].FN), reverse=True
):
acc = (item.TP + item.TN) / (item.TP + item.TN + item.FP + item.FN)
precision = (item.TP + 1) / (item.TP + item.FP + 1)
recall = (item.TP + 1) / (item.TP + item.FN + 1)
fpr = (item.FP + 1) / (item.FP + item.TN + 1)
precision = (
0.0 if (item.TP + item.FP) == 0 else item.TP / (item.TP + item.FP)
)
recall = 0.0 if (item.TP + item.FN) == 0 else item.TP / (item.TP + item.FN)
fpr = 0.0 if (item.FP + item.TN) == 0 else item.FP / (item.FP + item.TN)
s = f"{key}:\n"
s += f"\t{'TP':{width}}{'FP':{width}}{'FN':{width}}{'TN':{width}}\n"
s += f"\t{str(item.TP):{width}}{str(item.FP):{width}}{str(item.FN):{width}}{str(item.TN):{width}}\n"
@ -414,12 +460,14 @@ def save_results(
s += f"\tRecall(PPR): {recall:.3f}\n"
s += f"\tFPR: {fpr:.3f}\n"
s += f"\tF1: {2 * precision * recall / (precision + recall):.3f}\n"
s += f"\tTP list: {' # '.join(item.TP_list)}\n"
s += f"\tFP list: {' # '.join(item.FP_list)}\n"
s += f"\tFN list: {' # '.join(item.FN_list)}\n"
if key != "all":
s += f"\tTP list: {' # '.join(item.TP_list)}\n"
s += f"\tFP list: {' # '.join(item.FP_list)}\n"
s += f"\tFN list: {' # '.join(item.FN_list)}\n"
of.write(s + "\n")
if key == "all":
logging.info(s)
of.write(f"\n\n{params.keywords_config}")
logging.info("Wrote metric stats to {}".format(metric_filename))
@ -436,10 +484,11 @@ def main():
params.res_dir = params.exp_dir / "kws"
params.suffix = params.test_set
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
params.suffix += f"-iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
params.suffix += f"-epoch-{params.epoch}-avg-{params.avg}"
if params.causal:
assert (
@ -456,6 +505,7 @@ def main():
params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}"
if params.blank_penalty != 0:
params.suffix += f"-blank-penalty-{params.blank_penalty}"
params.suffix += f"-version-{params.keywords_version}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@ -480,8 +530,10 @@ def main():
token_ids = []
keywords_scores = []
keywords_thresholds = []
keywords_config = []
with open(params.keywords_file, "r") as f:
for line in f.readlines():
keywords_config.append(line)
score = 0
threshold = 0
keyword = []
@ -501,6 +553,8 @@ def main():
keywords_scores.append(score)
keywords_thresholds.append(threshold)
params.keywords_config = "".join(keywords_config)
kws_graph = ContextGraph(
context_score=params.keywords_score, ac_threshold=params.keywords_threshold
)
@ -605,24 +659,17 @@ def main():
test_cuts = gigaspeech.test_cuts()
test_dl = gigaspeech.test_dataloaders(test_cuts)
def select_keyword_cuts(c: Cut):
text = c.supervisions[0].text
text = text.strip().upper()
return text in keywords
test_sc1_cuts = gigaspeech.test_speechcommands1_cuts()
test_sc2_cuts = gigaspeech.test_speechcommands2_cuts()
test_fsc_cuts = gigaspeech.test_fluent_speechcommands_cuts()
test_fsc_cuts = test_fsc_cuts.filter(select_keyword_cuts)
test_sc1_dl = gigaspeech.test_dataloaders(test_sc1_cuts)
test_sc2_dl = gigaspeech.test_dataloaders(test_sc2_cuts)
test_fsc_dl = speechcommand.test_dataloaders(test_fsc_cuts)
test_sets = ["test-fsc", "test", "test-sc1", "test-sc2"]
test_dls = [test_fsc_dl, test_dl, test_sc1_dl, test_sc2_dl]
if params.test_set == "small":
test_fsc_small_cuts = gigaspeech.fsc_test_small_cuts()
test_fsc_small_dl = gigaspeech.test_dataloaders(test_fsc_small_cuts)
test_sets = ["small-fsc", "test"]
test_dls = [test_fsc_small_dl, test_dl]
else:
assert params.test_set == "large", params.test_set
test_fsc_large_cuts = gigaspeech.fsc_test_large_cuts()
test_fsc_large_dl = gigaspeech.test_dataloaders(test_fsc_large_cuts)
test_sets = ["large-fsc", "test"]
test_dls = [test_fsc_large_dl, test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results, metric = decode_dataset(
@ -632,6 +679,7 @@ def main():
sp=sp,
kws_graph=kws_graph,
keywords=keywords,
test_only_keywords="fsc" in test_set,
)
save_results(

File diff suppressed because it is too large.

View File

@ -0,0 +1 @@
../../ASR/zipformer/gigaspeech_scoring.py

View File

@ -126,7 +126,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--num-encoder-layers",
type=str,
default="2,2,3,4,3,2",
default="1,1,1,1,1,1",
help="Number of zipformer encoder layers per stack, comma separated.",
)
@ -140,7 +140,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--feedforward-dim",
type=str,
default="512,768,1024,1536,1024,768",
default="192,192,192,192,192,192",
help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
)
@ -154,7 +154,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--encoder-dim",
type=str,
default="192,256,384,512,384,256",
default="128,128,128,128,128,128",
help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
)
@ -189,7 +189,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--encoder-unmasked-dim",
type=str,
default="192,192,256,256,256,192",
default="128,128,128,128,128,128",
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
"A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
)
@ -205,14 +205,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--decoder-dim",
type=int,
default=512,
default=320,
help="Embedding dimension in the decoder model.",
)
parser.add_argument(
"--joiner-dim",
type=int,
default=512,
default=320,
help="""Dimension used in the joiner model.
Outputs from the encoder and decoder model are projected
to this dimension before adding.
@ -222,7 +222,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--causal",
type=str2bool,
default=False,
default=True,
help="If True, use causal version of model.",
)
@ -416,6 +416,17 @@ def get_parser():
help="Accumulate stats on activations, print them and exit.",
)
parser.add_argument(
"--scan-for-oom-batches",
type=str2bool,
default=False,
help="""
Whether to scan for OOM batches before training. This is helpful for
finding a suitable max_duration; you only need to run it once.
Caution: it is a little time consuming.
""",
)
parser.add_argument(
"--inf-check",
type=str2bool,
@ -463,7 +474,7 @@ def get_parser():
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
default=True,
help="Whether to use half precision training.",
)
@ -1197,14 +1208,14 @@ def run(rank, world_size, args):
valid_cuts = valid_cuts.filter(remove_short_utt)
valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
# if not params.print_diagnostics:
# scan_pessimistic_batches_for_oom(
# model=model,
# train_dl=train_dl,
# optimizer=optimizer,
# sp=sp,
# params=params,
# )
if not params.print_diagnostics and params.scan_for_oom_batches:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
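
With --use-fp16 now defaulting to True, the GradScaler created above takes part in the usual mixed-precision loop. The sketch below is standard torch.cuda.amp usage, not code copied from the recipe; model, batch, loss_fn and optimizer are placeholders.

import torch
from torch.cuda.amp import GradScaler, autocast

use_fp16 = True
scaler = GradScaler(enabled=use_fp16, init_scale=1.0)


def train_step(model, batch, loss_fn, optimizer):
    optimizer.zero_grad()
    with autocast(enabled=use_fp16):
        loss = loss_fn(model(batch["inputs"]), batch["targets"])
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)         # unscale grads, step only if they are finite
    scaler.update()                # adapt the scale for the next iteration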

View File

@ -966,7 +966,6 @@ def keywords_search(
encoder_out_lens: torch.Tensor,
context_graph: ContextGraph,
beam: int = 4,
ac_threshold: float = 0.15,
num_tailing_blanks: int = 8,
blank_penalty: float = 0,
) -> List[List[KeywordResult]]:
@ -1077,6 +1076,8 @@ def keywords_search(
log_probs = probs.log()
probs = probs.reshape(-1)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
@ -1112,7 +1113,7 @@ def keywords_search(
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_timestamp.append(t)
new_ac_probs.append(math.exp(hyp_probs[topk_indexes[k]]))
new_ac_probs.append(hyp_probs[topk_indexes[k]])
(
context_score,
new_context_state,
@ -1140,10 +1141,13 @@ def keywords_search(
ac_prob = (
sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
)
# logging.info(
# f"ac prob : {ac_prob}, threshold : {matched_state.ac_threshold}"
# )
if (
matched
and top_hyp.num_tailing_blanks > num_tailing_blanks
and ac_prob >= ac_threshold
and ac_prob >= matched_state.ac_threshold
):
keyword = KeywordResult(
hyps=top_hyp.ys[-matched_state.level :],
@ -1171,7 +1175,7 @@ def keywords_search(
ac_prob = (
sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
)
if matched and ac_prob >= ac_threshold:
if matched and ac_prob >= matched_state.ac_threshold:
keyword = KeywordResult(
hyps=top_hyp.ys[-matched_state.level :],
timestamps=top_hyp.timestamp[-matched_state.level :],
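
After this change, keywords_search() no longer takes a global ac_threshold; each keyword's threshold is read from its ContextGraph state, and the call site in the decoding script passes the graph as context_graph. A sketch of the updated call, mirroring that call site; model, encoder_out, encoder_out_lens and kws_graph are placeholders built earlier in the decoding script.

results = keywords_search(
    model=model,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    context_graph=kws_graph,  # renamed from kws_graph=...; the per-keyword
                              # ac_threshold now comes from the graph states
    beam=4,
    num_tailing_blanks=1,
    blank_penalty=0,
)
# results is a List[List[KeywordResult]], one list of detections per utterance.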