Update beam_search.py

This commit is contained in:
Mingshuang Luo 2022-04-11 21:14:02 +08:00 committed by GitHub
parent 8dcf9644cb
commit 1b854e5c44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -486,7 +486,9 @@ def modified_beam_search(
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
topk_hyp_indexes = torch.div(topk_indexes, vocab_size, rounding_mode="trunc")
topk_hyp_indexes = torch.div(
topk_indexes, vocab_size, rounding_mode="trunc"
)
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()