Fix decoding warnings.

This commit is contained in:
Fangjun Kuang 2022-04-19 11:41:00 +08:00
parent ec9bbf7352
commit fe787d6167
2 changed files with 10 additions and 5 deletions

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional from typing import Dict, List, Optional
@ -505,8 +506,10 @@ def modified_beam_search(
for i in range(batch_size): for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
topk_hyp_indexes = (topk_indexes // vocab_size).tolist() with warnings.catch_warnings():
topk_token_indexes = (topk_indexes % vocab_size).tolist() warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)): for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k] hyp_idx = topk_hyp_indexes[k]
@ -613,8 +616,10 @@ def _deprecated_modified_beam_search(
topk_hyp_indexes = topk_indexes // logits.size(-1) topk_hyp_indexes = topk_indexes // logits.size(-1)
topk_token_indexes = topk_indexes % logits.size(-1) topk_token_indexes = topk_indexes % logits.size(-1)
topk_hyp_indexes = topk_hyp_indexes.tolist() with warnings.catch_warnings():
topk_token_indexes = topk_token_indexes.tolist() warnings.simplefilter("ignore")
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = topk_token_indexes.tolist()
for i in range(len(topk_hyp_indexes)): for i in range(len(topk_hyp_indexes)):
hyp = A[topk_hyp_indexes[i]] hyp = A[topk_hyp_indexes[i]]

View File

@ -62,6 +62,6 @@ class Joiner(nn.Module):
# We reuse the beam_search.py from transducer_stateless, # We reuse the beam_search.py from transducer_stateless,
# which expects that the joiner network outputs # which expects that the joiner network outputs
# a 2-D tensor. # a 2-D tensor.
logits = logits.unsqueeze(2).unsqueeze(1) logits = logits.squeeze(2).squeeze(1)
return logits return logits