from local

This commit is contained in:
dohe0342 2023-02-02 19:06:25 +09:00
parent 1d3ae3a0e6
commit ab5a9f5ca2
2 changed files with 24 additions and 0 deletions

View File

@ -163,6 +163,30 @@ def get_params() -> AttributeDict:
return params return params
def decode_one_batch_greedy(
params: AttributeDict,
model: nn.Module,
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
batch: dict,
lexicon: Lexicon,
sos_id: int,
eos_id: int,
token_dict: dict,
) -> Dict[str, List[List[int]]]:
if params.method == 'greedy-search' or params.method == 'ctc-decoding':
batch_size = nnet_output.size(0)
for i in range(batch_size):
topk_log_probs, topk_indexes = nnet_output[i].topk(1)
topk_indexes = topk_indexes.squeeze().unique_consecutive()
topk_indexes = topk_indexes[topk_indexes != 0]
hyp = ''
for idx in topk_indexes:
hyp += token_dict[idx.item()]
print(hyp)
def decode_one_batch( def decode_one_batch(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,