from local

This commit is contained in:
dohe0342 2023-02-25 16:06:14 +09:00
parent a66a1fcae0
commit 9374e5d07e
2 changed files with 3 additions and 2 deletions

View File

@ -415,9 +415,10 @@ def decode_one_batch(
return {key: hyps} return {key: hyps}
if params.method == "greedy-search": if params.method == "greedy-search":
memory, memory_key_padding_mask
att_loss = model.decoder_forward( att_loss = model.decoder_forward(
encoder_memory, memory,
memory_mask, momory_key_padding_mask,
token_ids=unsorted_token_ids, token_ids=unsorted_token_ids,
sos_id=graph_compiler.sos_id, sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id, eos_id=graph_compiler.eos_id,