From fe787d616712d193612644b2be069dad602137cd Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 19 Apr 2022 11:41:00 +0800 Subject: [PATCH] Fix decoding warnings. --- .../ASR/transducer_stateless/beam_search.py | 13 +++++++++---- egs/librispeech/ASR/transducer_stateless2/joiner.py | 2 +- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index 7b4fac31d..388a8d67a 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass from typing import Dict, List, Optional @@ -505,8 +506,10 @@ def modified_beam_search( for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() for k in range(len(topk_hyp_indexes)): hyp_idx = topk_hyp_indexes[k] @@ -613,8 +616,10 @@ def _deprecated_modified_beam_search( topk_hyp_indexes = topk_indexes // logits.size(-1) topk_token_indexes = topk_indexes % logits.size(-1) - topk_hyp_indexes = topk_hyp_indexes.tolist() - topk_token_indexes = topk_token_indexes.tolist() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() for i in range(len(topk_hyp_indexes)): hyp = A[topk_hyp_indexes[i]] diff --git a/egs/librispeech/ASR/transducer_stateless2/joiner.py b/egs/librispeech/ASR/transducer_stateless2/joiner.py index b30b6895c..765f0be8b 100644 --- a/egs/librispeech/ASR/transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/transducer_stateless2/joiner.py @@ -62,6 +62,6 @@ class Joiner(nn.Module): # We reuse the beam_search.py from transducer_stateless, # which expects that the joiner network outputs # a 2-D tensor. - logits = logits.unsqueeze(2).unsqueeze(1) + logits = logits.squeeze(2).squeeze(1) return logits