diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 2e9bf3e0b..86e34be61 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass from typing import Dict, List, Optional @@ -503,8 +504,10 @@ def modified_beam_search( for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() for k in range(len(topk_hyp_indexes)): hyp_idx = topk_hyp_indexes[k] @@ -614,8 +617,10 @@ def _deprecated_modified_beam_search( topk_hyp_indexes = topk_indexes // logits.size(-1) topk_token_indexes = topk_indexes % logits.size(-1) - topk_hyp_indexes = topk_hyp_indexes.tolist() - topk_token_indexes = topk_token_indexes.tolist() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() for i in range(len(topk_hyp_indexes)): hyp = A[topk_hyp_indexes[i]]