diff --git a/egs/aishell/ASR/transformer_ctc/.decode.py.swp b/egs/aishell/ASR/transformer_ctc/.decode.py.swp index e80ac848b..d53020f65 100644 Binary files a/egs/aishell/ASR/transformer_ctc/.decode.py.swp and b/egs/aishell/ASR/transformer_ctc/.decode.py.swp differ diff --git a/egs/aishell/ASR/transformer_ctc/decode.py b/egs/aishell/ASR/transformer_ctc/decode.py index 4fc8ecb26..f69128454 100755 --- a/egs/aishell/ASR/transformer_ctc/decode.py +++ b/egs/aishell/ASR/transformer_ctc/decode.py @@ -453,6 +453,12 @@ def decode_dataset( num_batches = "?" results = defaultdict(list) + + subs_all = 0 + dels_all = 0 + ins_all = 0 + char_num = 0 + for batch_idx, batch in enumerate(dl): #logging.info(f"decoding {batch_idx} th batch") texts = batch["supervisions"]["text"] @@ -483,9 +489,18 @@ def decode_dataset( ) for i, hyp in enumerate(hyps): - print('hyp = ', hyp) - print('ref = ', texts[i].replace(' ', '')) - print('') + #print('hyp = ', hyp) + #print('ref = ', texts[i].replace(' ', '')) + #print('') + ref = texts[i].replace(' ', '') + [cer, subs, dels, ins] = metrics.get_cer(hyp, ref) + subs_all += subs + dels_all += dels + ins_all += ins + char_num += len(ref) + + cer = (subs_all+dels_all+ins_all) / char_num + print(cer * 100) ''' for lm_scale, hyps in hyps_dict.items():