diff --git a/egs/librispeech/ASR/conformer_mmi_phone/decode.py b/egs/librispeech/ASR/conformer_mmi_phone/decode.py index 7e9c8f78e..6f485ba57 100755 --- a/egs/librispeech/ASR/conformer_mmi_phone/decode.py +++ b/egs/librispeech/ASR/conformer_mmi_phone/decode.py @@ -62,10 +62,15 @@ def get_parser(): type=str, ) + parser.add_argument( + "--num-paths", + type=int, + ) + parser.add_argument( "--lattice-score-scale", type=float, - default=1.0, + default=0.5, help="The scale to be applied to `lattice.scores`." "It's needed if you use any kinds of n-best based rescoring. " "Currently, it is used when the decoding method is: nbest, " @@ -86,7 +91,7 @@ def get_params() -> AttributeDict: "nhead": 8, "attention_dim": 512, "subsampling_factor": 4, - "num_decoder_layers": 0, + "num_decoder_layers": 6, "vgg_frontend": False, "is_espnet_structure": True, "mmi_loss": False, @@ -110,7 +115,6 @@ def get_params() -> AttributeDict: # "method": "nbest-oracle", # num_paths is used when method is "nbest", "nbest-rescoring", # attention-decoder, and nbest-oracle - "num_paths": 100, } ) return params @@ -261,6 +265,8 @@ def decode_one_batch( memory=memory, memory_key_padding_mask=memory_key_padding_mask, scale=params.lattice_score_scale, + sos_id=params.sos_id, + eos_id=params.eos_id, ) else: assert False, f"Unsupported decoding method: {params.method}" @@ -474,10 +480,15 @@ def main(): num_decoder_layers=params.num_decoder_layers, vgg_frontend=params.vgg_frontend, is_espnet_structure=params.is_espnet_structure, - mmi_loss=params.mmi_loss, + is_bpe=False, use_feat_batchnorm=params.use_feat_batchnorm, ) + assert model.decoder_num_class == num_classes + 1 + + params.sos_id = num_classes + params.eos_id = num_classes + if params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: