mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Fix decoding warnings.
This commit is contained in:
parent
ec9bbf7352
commit
fe787d6167
@ -14,6 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
@ -505,8 +506,10 @@ def modified_beam_search(
|
|||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||||
|
|
||||||
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
with warnings.catch_warnings():
|
||||||
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
warnings.simplefilter("ignore")
|
||||||
|
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||||
|
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||||
|
|
||||||
for k in range(len(topk_hyp_indexes)):
|
for k in range(len(topk_hyp_indexes)):
|
||||||
hyp_idx = topk_hyp_indexes[k]
|
hyp_idx = topk_hyp_indexes[k]
|
||||||
@ -613,8 +616,10 @@ def _deprecated_modified_beam_search(
|
|||||||
topk_hyp_indexes = topk_indexes // logits.size(-1)
|
topk_hyp_indexes = topk_indexes // logits.size(-1)
|
||||||
topk_token_indexes = topk_indexes % logits.size(-1)
|
topk_token_indexes = topk_indexes % logits.size(-1)
|
||||||
|
|
||||||
topk_hyp_indexes = topk_hyp_indexes.tolist()
|
with warnings.catch_warnings():
|
||||||
topk_token_indexes = topk_token_indexes.tolist()
|
warnings.simplefilter("ignore")
|
||||||
|
topk_hyp_indexes = topk_hyp_indexes.tolist()
|
||||||
|
topk_token_indexes = topk_token_indexes.tolist()
|
||||||
|
|
||||||
for i in range(len(topk_hyp_indexes)):
|
for i in range(len(topk_hyp_indexes)):
|
||||||
hyp = A[topk_hyp_indexes[i]]
|
hyp = A[topk_hyp_indexes[i]]
|
||||||
|
@ -62,6 +62,6 @@ class Joiner(nn.Module):
|
|||||||
# We reuse the beam_search.py from transducer_stateless,
|
# We reuse the beam_search.py from transducer_stateless,
|
||||||
# which expects that the joiner network outputs
|
# which expects that the joiner network outputs
|
||||||
# a 2-D tensor.
|
# a 2-D tensor.
|
||||||
logits = logits.unsqueeze(2).unsqueeze(1)
|
logits = logits.squeeze(2).squeeze(1)
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
Loading…
x
Reference in New Issue
Block a user