do a change for decode.py (#400)

This commit is contained in:
Mingshuang Luo 2022-06-06 15:44:04 +08:00 committed by GitHub
parent f1abce72f8
commit 0a21eaae7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -63,7 +63,7 @@ import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule from asr_datamodule import WenetSpeechAsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search, fast_beam_search_one_best,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
@ -256,7 +256,7 @@ def decode_one_batch(
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search( hyp_tokens = fast_beam_search_one_best(
model=model, model=model,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
encoder_out=encoder_out, encoder_out=encoder_out,