Merge branch 'cr-ctc-aishell' of gitee.com:Mistmoon/icefall into cr-ctc-aishell

This commit is contained in:
hhzzff 2025-07-07 17:03:07 +08:00
commit 75e2daf6a9
2 changed files with 31 additions and 5 deletions

View File

@ -955,6 +955,7 @@ See <https://github.com/k2-fsa/icefall/pull/1980> for more details.
| decoding method | test | dev | comment |
|--------------------------------------|------------|------------|---------------------|
| ctc-greedy-search | 3.98 | 3.69 | --epoch 60 --avg 28 |
| ctc-prefix-beam-search | 3.98 | 3.70 | --epoch 60 --avg 21 |
The training command using 2 32G-V100 GPUs is:
```bash
@ -982,7 +983,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-greedy-search; do
for m in ctc-greedy-search ctc-prefix-beam-search; do
./zipformer/ctc_decode.py \
--epoch 60 \
--avg 28 \

View File

@ -32,6 +32,16 @@ Usage:
--use-transducer 0 \
--max-duration 600 \
--decoding-method ctc-greedy-search
(2) ctc-prefix-beam-search (with cr-ctc)
./zipformer/ctc_decode.py \
--epoch 60 \
--avg 21 \
--exp-dir zipformer/exp \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--max-duration 600 \
--decoding-method ctc-prefix-beam-search
"""
@ -156,6 +166,22 @@ def get_parser():
return parser
def get_decoding_params() -> AttributeDict:
    """Return the built-in decoding parameters.

    The values are fixed constants; the function takes no arguments and
    builds a fresh ``AttributeDict`` on every call.

    Returns:
      An ``AttributeDict`` holding the keys ``frame_shift_ms``,
      ``search_beam``, ``output_beam``, ``min_active_states``,
      ``max_active_states``, ``use_double_scores`` and ``beam``.
    """
    defaults = {
        "frame_shift_ms": 10,
        # The next two beams are for k2 fsa composition.
        "search_beam": 20,
        "output_beam": 8,
        "min_active_states": 30,
        "max_active_states": 10000,
        "use_double_scores": True,
        # Beam used for prefix-beam-search.
        "beam": 4,
    }
    return AttributeDict(defaults)
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
@ -236,11 +262,10 @@ def decode_one_batch(
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
key = f"blank_penalty_{params.blank_penalty}"
if params.decoding_method == "ctc-greedy-search":
return {"ctc-greedy-search_" + key: hyps}
return {"ctc-greedy-search" : hyps}
elif params.decoding_method == "ctc-prefix-beam-search":
return {"ctc-prefix-beam-search_" + key: hyps}
return {"ctc-prefix-beam-search" : hyps}
else:
assert False, f"Unsupported decoding method: {params.decoding_method}"
@ -361,7 +386,7 @@ def main():
params = get_params()
# add decoding params
# params.update(get_decoding_params())
params.update(get_decoding_params())
params.update(vars(args))
assert params.decoding_method in (