from local

This commit is contained in:
dohe0342 2023-02-02 19:10:22 +09:00
parent 24ae0d3a63
commit 64b4e778c9
2 changed files with 13 additions and 10 deletions

View File

@ -206,7 +206,8 @@ def decode_one_batch_greedy(
assert HLG is None assert HLG is None
decoding_graph = H decoding_graph = H
if params.method == 'greedy-search' or params.method == 'ctc-decoding': hyps = []
batch_size = nnet_output.size(0) batch_size = nnet_output.size(0)
for i in range(batch_size): for i in range(batch_size):
topk_log_probs, topk_indexes = nnet_output[i].topk(1) topk_log_probs, topk_indexes = nnet_output[i].topk(1)
@ -215,7 +216,9 @@ def decode_one_batch_greedy(
hyp = '' hyp = ''
for idx in topk_indexes: for idx in topk_indexes:
hyp += token_dict[idx.item()] hyp += token_dict[idx.item()]
print(hyp) hyps.append(hyp)
return hyps
def decode_one_batch( def decode_one_batch(