Fix decoding.

This commit is contained in:
Fangjun Kuang 2021-09-13 08:06:28 +08:00
parent 4f3a53fc41
commit d4440b421c

View File

@ -62,10 +62,15 @@ def get_parser():
type=str, type=str,
) )
parser.add_argument(
"--num-paths",
type=int,
)
parser.add_argument( parser.add_argument(
"--lattice-score-scale", "--lattice-score-scale",
type=float, type=float,
default=1.0, default=0.5,
help="The scale to be applied to `lattice.scores`." help="The scale to be applied to `lattice.scores`."
"It's needed if you use any kinds of n-best based rescoring. " "It's needed if you use any kinds of n-best based rescoring. "
"Currently, it is used when the decoding method is: nbest, " "Currently, it is used when the decoding method is: nbest, "
@ -86,7 +91,7 @@ def get_params() -> AttributeDict:
"nhead": 8, "nhead": 8,
"attention_dim": 512, "attention_dim": 512,
"subsampling_factor": 4, "subsampling_factor": 4,
"num_decoder_layers": 0, "num_decoder_layers": 6,
"vgg_frontend": False, "vgg_frontend": False,
"is_espnet_structure": True, "is_espnet_structure": True,
"mmi_loss": False, "mmi_loss": False,
@ -110,7 +115,6 @@ def get_params() -> AttributeDict:
# "method": "nbest-oracle", # "method": "nbest-oracle",
# num_paths is used when method is "nbest", "nbest-rescoring", # num_paths is used when method is "nbest", "nbest-rescoring",
# attention-decoder, and nbest-oracle # attention-decoder, and nbest-oracle
"num_paths": 100,
} }
) )
return params return params
@ -261,6 +265,8 @@ def decode_one_batch(
memory=memory, memory=memory,
memory_key_padding_mask=memory_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask,
scale=params.lattice_score_scale, scale=params.lattice_score_scale,
sos_id=params.sos_id,
eos_id=params.eos_id,
) )
else: else:
assert False, f"Unsupported decoding method: {params.method}" assert False, f"Unsupported decoding method: {params.method}"
@ -474,10 +480,15 @@ def main():
num_decoder_layers=params.num_decoder_layers, num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend, vgg_frontend=params.vgg_frontend,
is_espnet_structure=params.is_espnet_structure, is_espnet_structure=params.is_espnet_structure,
mmi_loss=params.mmi_loss, is_bpe=False,
use_feat_batchnorm=params.use_feat_batchnorm, use_feat_batchnorm=params.use_feat_batchnorm,
) )
assert model.decoder_num_class == num_classes + 1
params.sos_id = num_classes
params.eos_id = num_classes
if params.avg == 1: if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else: else: