From ce9b23327f3f90d8f66a27d8b25b3bb7e8052882 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 9 Sep 2021 15:15:35 +0800 Subject: [PATCH] Fix decode.py --- .../ASR/conformer_mmi_phone/decode.py | 38 ++++++------------- pyproject.toml | 11 ++++-- 2 files changed, 18 insertions(+), 31 deletions(-) diff --git a/egs/librispeech/ASR/conformer_mmi_phone/decode.py b/egs/librispeech/ASR/conformer_mmi_phone/decode.py index 7e83f736d..e8b9537a4 100755 --- a/egs/librispeech/ASR/conformer_mmi_phone/decode.py +++ b/egs/librispeech/ASR/conformer_mmi_phone/decode.py @@ -16,7 +16,6 @@ import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.decode import ( get_lattice, @@ -28,6 +27,7 @@ from icefall.decode import ( rescore_with_whole_lattice, ) from icefall.lexicon import Lexicon +from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler from icefall.utils import ( AttributeDict, get_texts, @@ -58,6 +58,11 @@ def get_parser(): "'--epoch'. ", ) + parser.add_argument( + "--method", + type=str, + ) + parser.add_argument( "--lattice-score-scale", type=float, @@ -82,7 +87,7 @@ def get_params() -> AttributeDict: "nhead": 8, "attention_dim": 512, "subsampling_factor": 4, - "num_decoder_layers": 6, + "num_decoder_layers": 0, "vgg_frontend": False, "is_espnet_structure": True, "mmi_loss": False, @@ -102,7 +107,7 @@ def get_params() -> AttributeDict: # "method": "nbest", # "method": "nbest-rescoring", # "method": "whole-lattice-rescoring", - "method": "attention-decoder", + # "method": "attention-decoder", # "method": "nbest-oracle", # num_paths is used when method is "nbest", "nbest-rescoring", # attention-decoder, and nbest-oracle @@ -118,8 +123,6 @@ def decode_one_batch( HLG: k2.Fsa, batch: dict, lexicon: Lexicon, - sos_id: int, - eos_id: int, G: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[int]]]: """Decode one batch and return the result in a dict. The dict has the @@ -153,10 +156,6 @@ def decode_one_batch( for the format of the `batch`. lexicon: It contains word symbol table. - sos_id: - The token ID of the SOS. - eos_id: - The token ID of the EOS. G: An LM. It is not None when params.method is "nbest-rescoring" or "whole-lattice-rescoring". In general, the G in HLG @@ -234,7 +233,8 @@ def decode_one_batch( "attention-decoder", ] - lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] if params.method == "nbest-rescoring": @@ -261,8 +261,6 @@ def decode_one_batch( model=model, memory=memory, memory_key_padding_mask=memory_key_padding_mask, - sos_id=sos_id, - eos_id=eos_id, scale=params.lattice_score_scale, ) else: @@ -282,8 +280,6 @@ def decode_dataset( model: nn.Module, HLG: k2.Fsa, lexicon: Lexicon, - sos_id: int, - eos_id: int, G: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[int], List[int]]]]: """Decode dataset. @@ -299,10 +295,6 @@ def decode_dataset( The decoding graph. lexicon: It contains word symbol table. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. G: An LM. It is not None when params.method is "nbest-rescoring" or "whole-lattice-rescoring". In general, the G in HLG @@ -334,8 +326,6 @@ def decode_dataset( batch=batch, lexicon=lexicon, G=G, - sos_id=sos_id, - eos_id=eos_id, ) for lm_scale, hyps in hyps_dict.items(): @@ -427,14 +417,10 @@ def main(): logging.info(f"device: {device}") - graph_compiler = BpeCtcTrainingGraphCompiler( + graph_compiler = MmiTrainingGraphCompiler( params.lang_dir, device=device, - sos_token="", - eos_token="", ) - sos_id = graph_compiler.sos_id - eos_id = graph_compiler.eos_id HLG = k2.Fsa.from_dict( torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") @@ -530,8 +516,6 @@ def main(): HLG=HLG, lexicon=lexicon, G=G, - sos_id=sos_id, - eos_id=eos_id, ) save_results( diff --git a/pyproject.toml b/pyproject.toml index 0d80ed4d2..b4bd3a798 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,8 +4,11 @@ profile = "black" [tool.black] line-length = 80 exclude = ''' -/( - \.git - | \.github -)/ +( + /( + \.git + | \.github + | icefall/shared/* + )/ +) '''