mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-06 07:34:21 +00:00
Update beam_search.py
This commit is contained in:
parent
8dcf9644cb
commit
1b854e5c44
@ -486,7 +486,9 @@ def modified_beam_search(
|
|||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||||
|
|
||||||
topk_hyp_indexes = torch.div(topk_indexes, vocab_size, rounding_mode="trunc")
|
topk_hyp_indexes = torch.div(
|
||||||
|
topk_indexes, vocab_size, rounding_mode="trunc"
|
||||||
|
)
|
||||||
topk_hyp_indexes = topk_hyp_indexes.tolist()
|
topk_hyp_indexes = topk_hyp_indexes.tolist()
|
||||||
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user