mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Merge branch 'cr-ctc-aishell' of gitee.com:Mistmoon/icefall into cr-ctc-aishell
This commit is contained in:
commit
75e2daf6a9
@ -955,6 +955,7 @@ See <https://github.com/k2-fsa/icefall/pull/1980> for more details.
|
||||
| decoding method | test | dev | comment |
|
||||
|--------------------------------------|------------|------------|---------------------|
|
||||
| ctc-greedy-search | 3.98 | 3.69 | --epoch 60 --avg 28 |
|
||||
| ctc-prefix-beam-search | 3.98 | 3.70 | --epoch 60 --avg 21 |
|
||||
|
||||
The training command using 2 32G-V100 GPUs is:
|
||||
```bash
|
||||
@ -982,7 +983,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
|
||||
The decoding command is:
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
for m in ctc-greedy-search; do
|
||||
for m in ctc-greedy-search ctc-prefix-beam-search; do
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 60 \
|
||||
--avg 28 \
|
||||
|
@ -32,6 +32,16 @@ Usage:
|
||||
--use-transducer 0 \
|
||||
--max-duration 600 \
|
||||
--decoding-method ctc-greedy-search
|
||||
(2) ctc-prefix-beam-search (with cr-ctc)
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 60 \
|
||||
--avg 21 \
|
||||
--exp-dir zipformer/exp \
|
||||
--use-cr-ctc 1 \
|
||||
--use-ctc 1 \
|
||||
--use-transducer 0 \
|
||||
--max-duration 600 \
|
||||
--decoding-method ctc-prefix-beam-search
|
||||
"""
|
||||
|
||||
|
||||
@ -156,6 +166,22 @@ def get_parser():
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_decoding_params() -> AttributeDict:
|
||||
"""Parameters for decoding."""
|
||||
params = AttributeDict(
|
||||
{
|
||||
"frame_shift_ms": 10,
|
||||
"search_beam": 20, # for k2 fsa composition
|
||||
"output_beam": 8, # for k2 fsa composition
|
||||
"min_active_states": 30,
|
||||
"max_active_states": 10000,
|
||||
"use_double_scores": True,
|
||||
"beam": 4, # for prefix-beam-search
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
@ -236,11 +262,10 @@ def decode_one_batch(
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
|
||||
key = f"blank_penalty_{params.blank_penalty}"
|
||||
if params.decoding_method == "ctc-greedy-search":
|
||||
return {"ctc-greedy-search_" + key: hyps}
|
||||
return {"ctc-greedy-search" : hyps}
|
||||
elif params.decoding_method == "ctc-prefix-beam-search":
|
||||
return {"ctc-prefix-beam-search_" + key: hyps}
|
||||
return {"ctc-prefix-beam-search" : hyps}
|
||||
else:
|
||||
assert False, f"Unsupported decoding method: {params.decoding_method}"
|
||||
|
||||
@ -361,7 +386,7 @@ def main():
|
||||
|
||||
params = get_params()
|
||||
# add decoding params
|
||||
# params.update(get_decoding_params())
|
||||
params.update(get_decoding_params())
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
|
Loading…
x
Reference in New Issue
Block a user