From b2a2b1d387b52b4538704cc8c3f8e0e0e2f69d1c Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Wed, 16 Mar 2022 13:25:35 -0400 Subject: [PATCH] add decoding --- .../ASR/conformer_ctc/asr_datamodule.py | 4 +- egs/spgispeech/ASR/conformer_ctc/decode.py | 54 ++++++++----------- egs/spgispeech/ASR/decode.sh | 10 ++++ 3 files changed, 35 insertions(+), 33 deletions(-) create mode 100755 egs/spgispeech/ASR/decode.sh diff --git a/egs/spgispeech/ASR/conformer_ctc/asr_datamodule.py b/egs/spgispeech/ASR/conformer_ctc/asr_datamodule.py index 8e46293f3..2fc67f4e2 100644 --- a/egs/spgispeech/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/spgispeech/ASR/conformer_ctc/asr_datamodule.py @@ -104,7 +104,7 @@ class SPGISpeechAsrDataModule: group.add_argument( "--max-duration", type=int, - default=140.0, + default=100.0, help="Maximum pooled recordings duration (seconds) in a " "single batch. You can reduce it if it causes CUDA OOM.", ) @@ -296,7 +296,7 @@ class SPGISpeechAsrDataModule: return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz") @lru_cache() - def test_cuts(self) -> CutSet: + def val_cuts(self) -> CutSet: logging.info("About to get SPGISpeech val cuts") return load_manifest_lazy(self.args.manifest_dir / "cuts_val.jsonl.gz") diff --git a/egs/spgispeech/ASR/conformer_ctc/decode.py b/egs/spgispeech/ASR/conformer_ctc/decode.py index 177e33a6e..eb4baf552 100755 --- a/egs/spgispeech/ASR/conformer_ctc/decode.py +++ b/egs/spgispeech/ASR/conformer_ctc/decode.py @@ -23,10 +23,10 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple import k2 import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import SPGISpeechAsrDataModule from conformer import Conformer from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler @@ -131,7 +132,7 @@ def get_parser(): parser.add_argument( "--lang-dir", type=str, - default="data/lang_bpe_500", + 
default="data/lang_bpe_5000", help="The lang dir", ) @@ -140,7 +141,7 @@ def get_parser(): type=str, default="data/lm", help="""The LM dir. - It should contain either G_4_gram.pt or G_4_gram.fst.txt + It should contain either G_3_gram.pt or G_3_gram.fst.txt """, ) @@ -420,7 +421,7 @@ def decode_dataset( G: An LM. It is not None when params.method is "nbest-rescoring" or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. + is a 3-gram LM. Returns: Return a dict, whose key may be "no-rescore" if no LM rescoring is used, or it may be "lm_scale_0.7" if LM rescoring is used. @@ -462,9 +463,7 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert ( - len(results) > 0 - ), "It should not decode to empty in the first batch!" + assert len(results) > 0, "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for ref_text in texts: @@ -479,9 +478,7 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -512,9 +509,7 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -534,7 +529,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + SPGISpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) args.lang_dir = Path(args.lang_dir) @@ -570,7 +565,7 @@ def main(): HLG = None H = k2.ctc_topo( max_token=max_token_id, - modified=False, + modified=True, # Use modified topology 
since vocab size is large device=device, ) bpe_model = spm.SentencePieceProcessor() @@ -591,10 +586,10 @@ def main(): "whole-lattice-rescoring", "attention-decoder", ): - if not (params.lm_dir / "G_4_gram.pt").is_file(): - logging.info("Loading G_4_gram.fst.txt") + if not (params.lm_dir / "G_3_gram.pt").is_file(): + logging.info("Loading G_3_gram.fst.txt") logging.warning("It may take 8 minutes.") - with open(params.lm_dir / "G_4_gram.fst.txt") as f: + with open(params.lm_dir / "G_3_gram.fst.txt") as f: first_word_disambig_id = lexicon.word_table["#0"] G = k2.Fsa.from_openfst(f.read(), acceptor=False) @@ -615,10 +610,10 @@ def main(): # for why we need to do this. G.dummy = 1 - torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + torch.save(G.as_dict(), params.lm_dir / "G_3_gram.pt") else: - logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + logging.info("Loading pre-compiled G_3_gram.pt") + d = torch.load(params.lm_dir / "G_3_gram.pt", map_location=device) G = k2.Fsa.from_dict(d) if params.method in ["whole-lattice-rescoring", "attention-decoder"]: @@ -662,16 +657,16 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - librispeech = LibriSpeechAsrDataModule(args) + spgispeech = SPGISpeechAsrDataModule(args) - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() + dev_cuts = spgispeech.dev_cuts() + val_cuts = spgispeech.val_cuts() - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) + dev_dl = spgispeech.test_dataloaders(dev_cuts) + val_dl = spgispeech.test_dataloaders(val_cuts) - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] + test_sets = ["dev", "val"] + test_dl = [dev_dl, val_dl] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( 
@@ -687,9 +682,7 @@ def main(): eos_id=eos_id, ) - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) + save_results(params=params, test_set_name=test_set, results_dict=results_dict) logging.info("Done!") diff --git a/egs/spgispeech/ASR/decode.sh b/egs/spgispeech/ASR/decode.sh new file mode 100755 index 000000000..cf4a5d79f --- /dev/null +++ b/egs/spgispeech/ASR/decode.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +set -eou pipefail + +. ./path.sh +. parse_options.sh || exit 1 + +# Decode with Conformer CTC model +utils/queue-freegpu.pl --gpu 1 --mem 10G -l "hostname=c*" -q g.q conformer_ctc/exp/decode.log + python conformer_ctc/decode.py --epoch 12 --avg 3 --method ctc-decoding --max-duration 50 --num-paths 20