mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 14:44:18 +00:00
Fix warnings.
This commit is contained in:
parent
d79f5fecf7
commit
9105e3871e
@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
@ -503,8 +504,10 @@ def modified_beam_search(
|
||||
for i in range(batch_size):
|
||||
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||
|
||||
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||
|
||||
for k in range(len(topk_hyp_indexes)):
|
||||
hyp_idx = topk_hyp_indexes[k]
|
||||
@ -614,8 +617,10 @@ def _deprecated_modified_beam_search(
|
||||
topk_hyp_indexes = topk_indexes // logits.size(-1)
|
||||
topk_token_indexes = topk_indexes % logits.size(-1)
|
||||
|
||||
topk_hyp_indexes = topk_hyp_indexes.tolist()
|
||||
topk_token_indexes = topk_token_indexes.tolist()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
topk_hyp_indexes = topk_hyp_indexes.tolist()
|
||||
topk_token_indexes = topk_token_indexes.tolist()
|
||||
|
||||
for i in range(len(topk_hyp_indexes)):
|
||||
hyp = A[topk_hyp_indexes[i]]
|
||||
|
Loading…
x
Reference in New Issue
Block a user