from local

This commit is contained in:
dohe0342 2023-02-25 16:05:21 +09:00
parent ef9c392c5b
commit 3d1ff70da4
3 changed files with 8 additions and 0 deletions

View File

@ -415,6 +415,14 @@ def decode_one_batch(
return {key: hyps}
if params.method == "greedy-search":
att_loss = mmodel.decoder_forward(
encoder_memory,
memory_mask,
token_ids=unsorted_token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
warmup=warmup,
)
hyps = greedy_search(nnet_output, memory_key_padding_mask)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]