From 1e0a6edd28f2a488f4380f1d5c5f4145e50c9594 Mon Sep 17 00:00:00 2001 From: hhzzff <2070620600@qq.com> Date: Mon, 7 Jul 2025 17:02:36 +0800 Subject: [PATCH] update experiments related to ctc-prefix-beam-search --- egs/aishell/ASR/RESULTS.md | 3 ++- egs/aishell/ASR/zipformer/ctc_decode.py | 33 ++++++++++++++++++++++--- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index 24c424b47..64017cc2b 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -955,6 +955,7 @@ See for more details. | decoding method | test | dev | comment | |--------------------------------------|------------|------------|---------------------| | ctc-greedy-search | 3.98 | 3.69 | --epoch 60 --avg 28 | +| ctc-prefix-beam-search | 3.98 | 3.70 | --epoch 60 --avg 21 | The training command using 2 32G-V100 GPUs is: ```bash @@ -982,7 +983,7 @@ export CUDA_VISIBLE_DEVICES="0,1" The decoding command is: ```bash export CUDA_VISIBLE_DEVICES="0" -for m in ctc-greedy-search; do +for m in ctc-greedy-search ctc-prefix-beam-search; do ./zipformer/ctc_decode.py \ --epoch 60 \ --avg 28 \ diff --git a/egs/aishell/ASR/zipformer/ctc_decode.py b/egs/aishell/ASR/zipformer/ctc_decode.py index 8810bac9a..c05982a03 100755 --- a/egs/aishell/ASR/zipformer/ctc_decode.py +++ b/egs/aishell/ASR/zipformer/ctc_decode.py @@ -32,6 +32,16 @@ Usage: --use-transducer 0 \ --max-duration 600 \ --decoding-method ctc-greedy-search +(2) ctc-prefix-beam-search (with cr-ctc) +./zipformer/ctc_decode.py \ + --epoch 60 \ + --avg 21 \ + --exp-dir zipformer/exp \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --max-duration 600 \ + --decoding-method ctc-prefix-beam-search """ @@ -156,6 +166,22 @@ def get_parser(): return parser + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, # for k2 fsa composition + "output_beam": 8, # for k2 fsa composition + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + "beam": 4, # for prefix-beam-search + } + ) + return params + def decode_one_batch( params: AttributeDict, model: nn.Module, @@ -236,11 +262,10 @@ def decode_one_batch( for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - key = f"blank_penalty_{params.blank_penalty}" if params.decoding_method == "ctc-greedy-search": - return {"ctc-greedy-search_" + key: hyps} + return {"ctc-greedy-search" : hyps} elif params.decoding_method == "ctc-prefix-beam-search": - return {"ctc-prefix-beam-search_" + key: hyps} + return {"ctc-prefix-beam-search" : hyps} else: assert False, f"Unsupported decoding method: {params.decoding_method}" @@ -361,7 +386,7 @@ def main(): params = get_params() # add decoding params - # params.update(get_decoding_params()) + params.update(get_decoding_params()) params.update(vars(args)) assert params.decoding_method in (