From ab5a9f5ca256fb0ef3cb50ad0b3c4993d14d7a26 Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Thu, 2 Feb 2023 19:06:25 +0900 Subject: [PATCH] from local --- .../ASR/transformer_ctc/.decode.py.swp | Bin 49152 -> 49152 bytes egs/aishell/ASR/transformer_ctc/decode.py | 24 ++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/egs/aishell/ASR/transformer_ctc/.decode.py.swp b/egs/aishell/ASR/transformer_ctc/.decode.py.swp index 612c1316e7910b86d667ac97357779c869a07b3c..eb3f1c45c01652f7514c3bf9f1ccb3d1b368f02f 100644 GIT binary patch delta 234 zcmZo@U~Xt&7ELk;^Ym4))H7fJ0s#hwOI^2 AttributeDict: return params +def decode_one_batch_greedy( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + batch: dict, + lexicon: Lexicon, + sos_id: int, + eos_id: int, + token_dict: dict, + ) -> Dict[str, List[List[int]]]: + + if params.method == 'greedy-search' or params.method == 'ctc-decoding': + batch_size = nnet_output.size(0) + for i in range(batch_size): + topk_log_probs, topk_indexes = nnet_output[i].topk(1) + topk_indexes = topk_indexes.squeeze().unique_consecutive() + topk_indexes = topk_indexes[topk_indexes != 0] + hyp = '' + for idx in topk_indexes: + hyp += token_dict[idx.item()] + print(hyp) + + def decode_one_batch( params: AttributeDict, model: nn.Module,