add distill whisper results (#1648)

This commit is contained in:
Yuekai Zhang 2024-06-13 00:20:04 +08:00 committed by GitHub
parent 13f55d0735
commit d5be739639
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 24 additions and 9 deletions

View File

@ -6,10 +6,11 @@
Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search. Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search.
| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | |Model| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech |
|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|-------------------| |-|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|-------------------|
| Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | test meeting | | | Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | test meeting |
| Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 | |whisper-large-v2-ft |Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 |
|whisper-large-v2-ft-distill |Greedy Search | 24.91 | 26.73 | 0.91 | 0.94 | 2.71 | 2.98 | 17.65 | 2.81 | 2.47 | 5.16 | 2.10 | 6.27 | 8.34 |
Command for training is: Command for training is:
```bash ```bash

View File

@ -17,12 +17,15 @@
| 7 | aispeech_api_zh | 3.62% | 2023.12 | | 7 | aispeech_api_zh | 3.62% | 2023.12 |
| 8 | **whisper-large-ft-v1** | **4.32%** | 2024.04 | | 8 | **whisper-large-ft-v1** | **4.32%** | 2024.04 |
| 9 | **whisper-large-ft-v0.5** | **4.60%** | 2024.04 | | 9 | **whisper-large-ft-v0.5** | **4.60%** | 2024.04 |
| 10 | **zipformer (70Mb)** | **6.17%** | 2023.10 | | 10 | **whisper-large-ft-v1-distill** | **4.71%** | 2024.04 |
| 11 | **whisper-large-ft-v0** | **6.34%** | 2023.03 | | 11 | **zipformer (70Mb)** | **6.17%** | 2023.10 |
| 12 | baidu_pro_api_zh | 7.29% | 2023.12 | | 12 | **whisper-large-ft-v0** | **6.34%** | 2023.03 |
| 13 | baidu_pro_api_zh | 7.29% | 2023.12 |
Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leaderboard). All results used the default [normalize method.](https://github.com/SpeechColab/Leaderboard/blob/master/utils/benchmark.sh#L67) Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leaderboard). All results used the default [normalize method.](https://github.com/SpeechColab/Leaderboard/blob/master/utils/benchmark.sh#L67)
For **whisper-large-ft-v1-distill**, instead of actually using distillation loss for training, the model structure and parameter initialization method from the [distill-whisper](https://arxiv.org/abs/2311.00430) paper were adopted: only the first and last layers of the decoder were retained.
<details><summary> Detail all models </summary><p> <details><summary> Detail all models </summary><p>
| Model | Training Set | Note | | Model | Training Set | Note |
@ -31,7 +34,7 @@ Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leade
|[whisper-large-ft-v0](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper/tree/main/exp_large_v2)| wenetspeech | greedy_search, 3 epochs| |[whisper-large-ft-v0](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper/tree/main/exp_large_v2)| wenetspeech | greedy_search, 3 epochs|
|[whisper-large-ft-v0.5](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper/blob/main/epoch-2-avg-5.pt)| wenetspeech(updated) | [wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy_search, 2 epochs | |[whisper-large-ft-v0.5](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper/blob/main/epoch-2-avg-5.pt)| wenetspeech(updated) | [wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy_search, 2 epochs |
|[whisper-large-ft-v1](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper/tree/main/v1.1)|wenetspeech(updated), other multi-hans-zh exclude datatang 200h|[wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy search, 3 epochs| |[whisper-large-ft-v1](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper/tree/main/v1.1)|wenetspeech(updated), other multi-hans-zh exclude datatang 200h|[wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy search, 3 epochs|
|[whisper-large-ft-v1-distill](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper/tree/main/v1-distill)|wenetspeech(updated), other multi-hans-zh exclude datatang 200h|[wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy search, 6 epochs|
</details> </details>

View File

@ -58,6 +58,7 @@ from lhotse.cut import Cut
from multi_dataset import MultiDataset from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer from whisper.normalizers import BasicTextNormalizer
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert from zhconv import convert
@ -215,7 +216,7 @@ def get_parser():
"--model-name", "--model-name",
type=str, type=str,
default="large-v2", default="large-v2",
choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"],
help="""The model name to use. help="""The model name to use.
""", """,
) )
@ -227,6 +228,13 @@ def get_parser():
help="replace whisper encoder forward method to remove input length restriction", help="replace whisper encoder forward method to remove input length restriction",
) )
parser.add_argument(
"--use-distill-whisper",
type=str2bool,
default=False,
help="Whether to use architecture of distill whisper.",
)
return parser return parser
@ -431,6 +439,8 @@ def main():
if params.remove_whisper_encoder_input_length_restriction: if params.remove_whisper_encoder_input_length_restriction:
replace_whisper_encoder_forward() replace_whisper_encoder_forward()
if params.use_distill_whisper:
replace_whisper_decoder_forward()
model = whisper.load_model(params.model_name, "cpu") model = whisper.load_model(params.model_name, "cpu")
if params.epoch > 0: if params.epoch > 0:
if params.avg > 1: if params.avg > 1:

View File

@ -0,0 +1 @@
../../../multi_zh-hans/ASR/whisper/whisper_decoder_forward_monkey_patch.py