from local

This commit is contained in:
dohe0342 2023-02-25 16:10:40 +09:00
parent ca5e95929f
commit ee845f219c
2 changed files with 1 additions and 1 deletions

View File

@ -417,7 +417,7 @@ def decode_one_batch(
if params.method == "greedy-search": if params.method == "greedy-search":
unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
att_loss = model.decoder_forward( pred, att_loss = model.decoder_forward(
memory, memory,
memory_key_padding_mask, memory_key_padding_mask,
token_ids=unsorted_token_ids, token_ids=unsorted_token_ids,