diff --git a/egs/aishell/ASR/transformer_ctc/.decode.py.swp b/egs/aishell/ASR/transformer_ctc/.decode.py.swp index 612c1316e..eb3f1c45c 100644 Binary files a/egs/aishell/ASR/transformer_ctc/.decode.py.swp and b/egs/aishell/ASR/transformer_ctc/.decode.py.swp differ diff --git a/egs/aishell/ASR/transformer_ctc/decode.py b/egs/aishell/ASR/transformer_ctc/decode.py index bc0fc279f..ce76c2fec 100755 --- a/egs/aishell/ASR/transformer_ctc/decode.py +++ b/egs/aishell/ASR/transformer_ctc/decode.py @@ -163,6 +163,30 @@ def get_params() -> AttributeDict: return params +def decode_one_batch_greedy( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + batch: dict, + lexicon: Lexicon, + sos_id: int, + eos_id: int, + token_dict: dict, + ) -> Dict[str, List[List[int]]]: + + if params.method == 'greedy-search' or params.method == 'ctc-decoding': + batch_size = nnet_output.size(0) + for i in range(batch_size): + topk_log_probs, topk_indexes = nnet_output[i].topk(1) + topk_indexes = topk_indexes.squeeze().unique_consecutive() + topk_indexes = topk_indexes[topk_indexes != 0] + hyp = '' + for idx in topk_indexes: + hyp += token_dict[idx.item()] + print(hyp) + + def decode_one_batch( params: AttributeDict, model: nn.Module,