Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-13 12:02:21 +00:00
Fix decoding.
This commit is contained in:
parent 4f3a53fc41
commit d4440b421c
@@ -62,10 +62,15 @@ def get_parser():
         type=str,
     )
 
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+    )
+
     parser.add_argument(
         "--lattice-score-scale",
         type=float,
-        default=1.0,
+        default=0.5,
         help="The scale to be applied to `lattice.scores`."
         "It's needed if you use any kinds of n-best based rescoring. "
         "Currently, it is used when the decoding method is: nbest, "
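Note on the default change: `lattice.scores` holds log-domain arc scores, and n-best sampling draws paths with probability proportional to the scaled scores, so a scale below 1.0 flattens the distribution and yields more diverse sampled paths. Below is a minimal sketch of how such a scale is typically applied before sampling; it assumes the k2 `random_paths` API, and the function and variable names are illustrative, not the recipe's actual code.

import k2

def sample_nbest_paths(lattice: k2.Fsa, scale: float, num_paths: int):
    # Keep the original scores; the scaling is only for path sampling.
    saved_scores = lattice.scores.clone()

    # A smaller scale (e.g. the new default 0.5) flattens the score
    # distribution, so k2.random_paths returns more diverse paths.
    lattice.scores = saved_scores * scale
    paths = k2.random_paths(
        lattice, use_double_scores=True, num_paths=num_paths
    )

    # Restore the unscaled scores for the actual rescoring step.
    lattice.scores = saved_scores
    return paths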
@@ -86,7 +91,7 @@ def get_params() -> AttributeDict:
             "nhead": 8,
             "attention_dim": 512,
             "subsampling_factor": 4,
-            "num_decoder_layers": 0,
+            "num_decoder_layers": 6,
             "vgg_frontend": False,
             "is_espnet_structure": True,
             "mmi_loss": False,
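Restoring `num_decoder_layers` from 0 to 6 matters because a model built with zero decoder layers has no attention decoder at all, so attention-decoder rescoring has nothing to score with. A rough sketch of the idea (illustrative only; `build_decoder` is not the project's actual constructor):

import torch.nn as nn

def build_decoder(num_decoder_layers: int, attention_dim: int, nhead: int):
    # With num_decoder_layers == 0 the model is effectively CTC-only:
    # there is no decoder whose scores could rescore n-best hypotheses.
    if num_decoder_layers == 0:
        return None
    layer = nn.TransformerDecoderLayer(d_model=attention_dim, nhead=nhead)
    return nn.TransformerDecoder(layer, num_layers=num_decoder_layers)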
@@ -110,7 +115,6 @@ def get_params() -> AttributeDict:
             # "method": "nbest-oracle",
             # num_paths is used when method is "nbest", "nbest-rescoring",
             # attention-decoder, and nbest-oracle
-            "num_paths": 100,
         }
     )
     return params
@@ -261,6 +265,8 @@ def decode_one_batch(
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
             scale=params.lattice_score_scale,
+            sos_id=params.sos_id,
+            eos_id=params.eos_id,
         )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
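The two new arguments exist because attention-decoder rescoring feeds each n-best hypothesis through the decoder, and the decoder was trained on sequences framed with SOS and EOS tokens. A hypothetical helper shows the framing; `add_sos_eos` is illustrative, not the recipe's actual function:

from typing import List

def add_sos_eos(
    token_ids: List[List[int]], sos_id: int, eos_id: int
) -> List[List[int]]:
    # The decoder scores sequences of the form [sos, t_1, ..., t_n, eos],
    # so every hypothesis taken from the lattice is framed the same way
    # before its decoder log-likelihood is computed.
    return [[sos_id] + utt + [eos_id] for utt in token_ids]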
@@ -474,10 +480,15 @@ def main():
         num_decoder_layers=params.num_decoder_layers,
         vgg_frontend=params.vgg_frontend,
         is_espnet_structure=params.is_espnet_structure,
-        mmi_loss=params.mmi_loss,
+        is_bpe=False,
         use_feat_batchnorm=params.use_feat_batchnorm,
     )
 
+    assert model.decoder_num_class == num_classes + 1
+
+    params.sos_id = num_classes
+    params.eos_id = num_classes
+
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
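The assignments `params.sos_id = params.eos_id = num_classes` follow a common convention: regular token ids occupy 0 .. num_classes - 1 and a single extra id is shared by SOS and EOS, which is exactly what the new assert on `decoder_num_class` checks. With made-up numbers for illustration:

# Hypothetical example: a vocabulary of 500 tokens.
num_classes = 500              # regular token ids are 0 .. 499
sos_id = eos_id = num_classes  # id 500 is shared by SOS and EOS

# The decoder's output layer must cover the regular tokens plus the
# shared SOS/EOS id, hence decoder_num_class == num_classes + 1.
decoder_num_class = num_classes + 1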