add decoding

Desh Raj 2022-03-16 13:25:35 -04:00
parent 0c27ba45e7
commit b2a2b1d387
3 changed files with 36 additions and 33 deletions

egs/spgispeech/ASR/conformer_ctc/asr_datamodule.py

@@ -104,7 +104,7 @@ class SPGISpeechAsrDataModule:
group.add_argument(
"--max-duration",
type=int,
- default=140.0,
+ default=100.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
@@ -296,7 +296,7 @@ class SPGISpeechAsrDataModule:
return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
@lru_cache()
- def test_cuts(self) -> CutSet:
+ def val_cuts(self) -> CutSet:
logging.info("About to get SPGISpeech val cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_val.jsonl.gz")
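For reference, the renamed accessor pairs with the existing dev_cuts() when the decode script builds its dataloaders. A minimal sketch assembled from the decode.py changes below; the argument values here are placeholders and are not part of this commit:

import argparse
from asr_datamodule import SPGISpeechAsrDataModule  # recipe-local module

parser = argparse.ArgumentParser()
SPGISpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args([])  # placeholder: the recipe passes its real CLI args

spgispeech = SPGISpeechAsrDataModule(args)
dev_dl = spgispeech.test_dataloaders(spgispeech.dev_cuts())  # loads cuts_dev.jsonl.gz
val_dl = spgispeech.test_dataloaders(spgispeech.val_cuts())  # loads cuts_val.jsonl.gz
test_sets = ["dev", "val"]
test_dl = [dev_dl, val_dl]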

egs/spgispeech/ASR/conformer_ctc/decode.py

@@ -23,10 +23,11 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
+ from numpy import True_
import sentencepiece as spm
import torch
import torch.nn as nn
- from asr_datamodule import LibriSpeechAsrDataModule
+ from asr_datamodule import SPGISpeechAsrDataModule
from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
@@ -131,7 +132,7 @@ def get_parser():
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_500",
default="data/lang_bpe_5000",
help="The lang dir",
)
@@ -140,7 +141,7 @@ def get_parser():
type=str,
default="data/lm",
help="""The LM dir.
- It should contain either G_4_gram.pt or G_4_gram.fst.txt
+ It should contain either G_3_gram.pt or G_3_gram.fst.txt
""",
)
@@ -420,7 +421,7 @@ def decode_dataset(
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
- is a 3-gram LM, while this G is a 4-gram LM.
+ is a 3-gram LM.
Returns:
Return a dict, whose key may be "no-rescore" if no LM rescoring
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
@@ -462,9 +463,7 @@ def decode_dataset(
results[lm_scale].extend(this_batch)
else:
- assert (
- len(results) > 0
- ), "It should not decode to empty in the first batch!"
+ assert len(results) > 0, "It should not decode to empty in the first batch!"
this_batch = []
hyp_words = []
for ref_text in texts:
@@ -479,9 +478,7 @@ def decode_dataset(
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
- logging.info(
- f"batch {batch_str}, cuts processed until now is {num_cuts}"
- )
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
@@ -512,9 +509,7 @@ def save_results(
test_set_wers[key] = wer
if enable_log:
- logging.info(
- "Wrote detailed error stats to {}".format(errs_filename)
- )
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
@@ -534,7 +529,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
- LibriSpeechAsrDataModule.add_arguments(parser)
+ SPGISpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
@@ -570,7 +565,7 @@ def main():
HLG = None
H = k2.ctc_topo(
max_token=max_token_id,
- modified=False,
+ modified=True, # Use modified topology since vocab size is large
device=device,
)
bpe_model = spm.SentencePieceProcessor()
@@ -591,10 +586,10 @@ def main():
"whole-lattice-rescoring",
"attention-decoder",
):
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
if not (params.lm_dir / "G_3_gram.pt").is_file():
logging.info("Loading G_3_gram.fst.txt")
logging.warning("It may take 8 minutes.")
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
with open(params.lm_dir / "G_3_gram.fst.txt") as f:
first_word_disambig_id = lexicon.word_table["#0"]
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
@@ -615,10 +610,10 @@ def main():
# for why we need to do this.
G.dummy = 1
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
torch.save(G.as_dict(), params.lm_dir / "G_3_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
logging.info("Loading pre-compiled G_3_gram.pt")
d = torch.load(params.lm_dir / "G_3_gram.pt", map_location=device)
G = k2.Fsa.from_dict(d)
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
@@ -662,16 +657,16 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- librispeech = LibriSpeechAsrDataModule(args)
+ spgispeech = SPGISpeechAsrDataModule(args)
- test_clean_cuts = librispeech.test_clean_cuts()
- test_other_cuts = librispeech.test_other_cuts()
+ dev_cuts = spgispeech.dev_cuts()
+ val_cuts = spgispeech.val_cuts()
- test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
- test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+ dev_dl = spgispeech.test_dataloaders(dev_cuts)
+ val_dl = spgispeech.test_dataloaders(val_cuts)
- test_sets = ["test-clean", "test-other"]
- test_dl = [test_clean_dl, test_other_dl]
+ test_sets = ["dev", "val"]
+ test_dl = [dev_dl, val_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
@@ -687,9 +682,7 @@ def main():
eos_id=eos_id,
)
- save_results(
- params=params, test_set_name=test_set, results_dict=results_dict
- )
+ save_results(params=params, test_set_name=test_set, results_dict=results_dict)
logging.info("Done!")

egs/spgispeech/ASR/decode.sh (new executable file)

@@ -0,0 +1,10 @@
+ #!/usr/bin/env bash
+ set -eou pipefail
+ . ./path.sh
+ . parse_options.sh || exit 1
+ # Decode Conformer CTC model
+ utils/queue-freegpu.pl --gpu 1 --mem 10G -l "hostname=c*" -q g.q conformer_ctc/exp/decode.log \
+ python conformer_ctc/decode.py --epoch 12 --avg 3 --method ctc-decoding --max-duration 50 --num-paths 20