from local

This commit is contained in:
dohe0342 2023-02-02 19:10:22 +09:00
parent 24ae0d3a63
commit 64b4e778c9
2 changed files with 13 additions and 10 deletions

View File

@ -206,16 +206,19 @@ def decode_one_batch_greedy(
assert HLG is None
decoding_graph = H
if params.method == 'greedy-search' or params.method == 'ctc-decoding':
batch_size = nnet_output.size(0)
for i in range(batch_size):
topk_log_probs, topk_indexes = nnet_output[i].topk(1)
topk_indexes = topk_indexes.squeeze().unique_consecutive()
topk_indexes = topk_indexes[topk_indexes != 0]
hyp = ''
for idx in topk_indexes:
hyp += token_dict[idx.item()]
print(hyp)
hyps = []
batch_size = nnet_output.size(0)
for i in range(batch_size):
topk_log_probs, topk_indexes = nnet_output[i].topk(1)
topk_indexes = topk_indexes.squeeze().unique_consecutive()
topk_indexes = topk_indexes[topk_indexes != 0]
hyp = ''
for idx in topk_indexes:
hyp += token_dict[idx.item()]
hyps.append(hyp)
return hyps
def decode_one_batch(