Update decode.py (#392)

* Update decode.py

fix bug ```TypeError: greedy_search_batch() missing 1 required positional argument: 'encoder_out_lens'```

* fix modified_beam_search

Co-authored-by: fanlu3 <fanlu@jd.com>
This commit is contained in:
fanlu 2022-06-04 19:08:17 +08:00 committed by GitHub
parent 148f69d8d9
commit 8a3068ead8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -274,6 +274,7 @@ def decode_one_batch(
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@ -282,6 +283,7 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
encoder_out_lens=encoder_out_lens,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])