mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Update Zipformer-xl 700M Results on multi-hans-zh (#1694)
* add blank penalty * update zipformer-xl results * fix typo
This commit is contained in:
parent
11151415f3
commit
4af81af5a6
@ -43,6 +43,66 @@ Fine-tuned models, training logs, decoding logs, tensorboard and decoding result
|
|||||||
are available at
|
are available at
|
||||||
<https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper>
|
<https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper>
|
||||||
|
|
||||||
|
### Multi Chinese datasets char-based training results (streaming) on zipformer-xl model
|
||||||
|
|
||||||
|
#### Streaming (with CTC head)
|
||||||
|
|
||||||
|
The training command for extra-large model (num of params : ~700M):
|
||||||
|
|
||||||
|
Please use the [script](https://github.com/k2-fsa/icefall/blob/master/egs/speech_llm/ASR_LLM/prepare.sh) to prepare fbank features.
|
||||||
|
|
||||||
|
```
|
||||||
|
./zipformer/train.py \
|
||||||
|
--world-size 8 \
|
||||||
|
--num-epochs 20 \
|
||||||
|
--use-fp16 1 \
|
||||||
|
--max-duration 1200 \
|
||||||
|
--num-workers 8 \
|
||||||
|
--use-ctc 1 \
|
||||||
|
--exp-dir zipformer/exp-xl \
|
||||||
|
--causal 1 \
|
||||||
|
--num-encoder-layers 2,3,5,6,5,3 \
|
||||||
|
--feedforward-dim 1536,2048,3072,4096,3072,1536 \
|
||||||
|
--encoder-dim 512,768,1024,1536,1024,512 \
|
||||||
|
--encoder-unmasked-dim 192,192,256,320,256,192 \
|
||||||
|
--decoder-dim 768 --joiner-dim 768 \
|
||||||
|
--value-head-dim 18 \
|
||||||
|
--query-head-dim 48 \
|
||||||
|
--num-heads 4,4,4,8,4,4
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
The decoding command for transducer greedy search:
|
||||||
|
|
||||||
|
```
|
||||||
|
./zipformer/decode.py \
|
||||||
|
--epoch 999 \
|
||||||
|
--avg 1 \
|
||||||
|
--causal 1 \
|
||||||
|
--use-averaged-model False \
|
||||||
|
--chunk_size -1
|
||||||
|
--left-context-frames -1 \
|
||||||
|
--use-ctc 1 \
|
||||||
|
--exp-dir zipformer/exp-xl \
|
||||||
|
--max-duration 1200 \
|
||||||
|
--num-encoder-layers 2,3,5,6,5,3 \
|
||||||
|
--feedforward-dim 1536,2048,3072,4096,3072,1536 \
|
||||||
|
--encoder-dim 512,768,1024,1536,1024,512 \
|
||||||
|
--encoder-unmasked-dim 192,192,256,320,256,192 \
|
||||||
|
--decoder-dim 768 --joiner-dim 768 \
|
||||||
|
--value-head-dim 18 \
|
||||||
|
--query-head-dim 48 \
|
||||||
|
--num-heads 4,4,4,8,4,4
|
||||||
|
```
|
||||||
|
|
||||||
|
Character Error Rates (CERs) listed below are produced by the checkpoint of the 18th epoch using BPE model ( # tokens is 2000, byte fallback enabled).
|
||||||
|
|
||||||
|
| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|
||||||
|
|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
|
||||||
|
| Zipformer CER (%) | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
|
||||||
|
| Transducer Greedy Offline | 21.67 | 23.43 | 1.22 | 1.31 | 3.17 | 3.27 | 14.64 | 2.42 | 1.99 | 5.00 | 2.29 | 5.98 | 5.15 | 5.85 | 6.89 |
|
||||||
|
|
||||||
|
Pre-trained model can be found here : https://huggingface.co/yuekai/icefall-asr-multi-zh-hans-zipformer-xl
|
||||||
### Multi Chinese datasets char-based training results (streaming) on zipformer large model
|
### Multi Chinese datasets char-based training results (streaming) on zipformer large model
|
||||||
|
|
||||||
#### Streaming (with CTC head)
|
#### Streaming (with CTC head)
|
||||||
@ -64,6 +124,7 @@ Please use the [script](https://github.com/k2-fsa/icefall/blob/master/egs/speech
|
|||||||
--num-encoder-layers 2,2,4,5,4,2 \
|
--num-encoder-layers 2,2,4,5,4,2 \
|
||||||
--feedforward-dim 768,1024,1536,2048,1536,768 \
|
--feedforward-dim 768,1024,1536,2048,1536,768 \
|
||||||
--encoder-dim 256,384,512,768,512,256 \
|
--encoder-dim 256,384,512,768,512,256 \
|
||||||
|
--blank-penalty 0.7 \
|
||||||
--encoder-unmasked-dim 192,192,256,320,256,192
|
--encoder-unmasked-dim 192,192,256,320,256,192
|
||||||
|
|
||||||
```
|
```
|
||||||
|
@ -303,6 +303,17 @@ def get_parser():
|
|||||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--blank-penalty",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="""
|
||||||
|
The penalty applied on blank symbol during decoding.
|
||||||
|
Note: It is a positive value that would be applied to logits like
|
||||||
|
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||||
|
[batch_size, vocab] and blank id is 0).
|
||||||
|
""",
|
||||||
|
)
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -431,6 +442,7 @@ def decode_one_batch(
|
|||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -455,6 +467,7 @@ def decode_one_batch(
|
|||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out_i,
|
encoder_out=encoder_out_i,
|
||||||
max_sym_per_frame=params.max_sym_per_frame,
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
|
blank_penalty=params.blank_penalty,
|
||||||
)
|
)
|
||||||
elif params.decoding_method == "beam_search":
|
elif params.decoding_method == "beam_search":
|
||||||
hyp = beam_search(
|
hyp = beam_search(
|
||||||
@ -468,8 +481,9 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(sp.decode(hyp).split())
|
||||||
|
|
||||||
|
key = f"blank_penalty_{params.blank_penalty}"
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
return {"greedy_search": hyps}
|
return {"greedy_search_" + key: hyps}
|
||||||
elif "fast_beam_search" in params.decoding_method:
|
elif "fast_beam_search" in params.decoding_method:
|
||||||
key = f"beam_{params.beam}_"
|
key = f"beam_{params.beam}_"
|
||||||
key += f"max_contexts_{params.max_contexts}_"
|
key += f"max_contexts_{params.max_contexts}_"
|
||||||
@ -657,6 +671,7 @@ def main():
|
|||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
|
||||||
|
params.suffix += f"-blank-penalty-{params.blank_penalty}"
|
||||||
if params.use_averaged_model:
|
if params.use_averaged_model:
|
||||||
params.suffix += "-use-averaged-model"
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user