diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py b/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py index 0ae001d3f..3a08b100d 100644 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py @@ -486,7 +486,9 @@ def modified_beam_search( for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - topk_hyp_indexes = torch.div(topk_indexes, vocab_size, rounding_mode="trunc") + topk_hyp_indexes = torch.div( + topk_indexes, vocab_size, rounding_mode="trunc" + ) topk_hyp_indexes = topk_hyp_indexes.tolist() topk_token_indexes = (topk_indexes % vocab_size).tolist()