update speechio whisper ft results

This commit is contained in:
Yuekai Zhang 2024-04-24 18:57:34 +08:00
parent df36f93bd8
commit b970ba569a
13 changed files with 1629 additions and 213 deletions

View File

@ -1,5 +1,48 @@
## Results
### Multi Chinese datasets (without datatang 200h) finetuning results on Whisper-large-v2
#### Whisper
[./whisper](./whisper)
Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search.
| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
| Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| Greedy Search | 23.45 | 25.42 | 0.78 | 0.83 | 2.75 | 2.93 | 17.11 | 2.68 | 2.33 | 4.97 | 2.02 | 6.34 | 5.06 | 8.38 | 6.94 |
Command for training is:
```bash
pip install -r whisper/requirements.txt
# We updated the label of wenetspeech to remove OCR deletion errors, see https://github.com/wenet-e2e/WenetSpeech/discussions/54
torchrun --nproc-per-node 8 ./whisper/train.py \
--max-duration 200 \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--deepspeed \
--deepspeed_config ./whisper/ds_config_zero1.json
```
Command for decoding using fine-tuned models:
```bash
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper
ln -s icefall_asr_multi-hans-zh_whisper/v1/epoch-2-avg4.pt whisper/exp_large_v2/epoch-999.pt
python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--epoch 999 --avg 1 \
--beam-size 10 --max-duration 50
```
Fine-tuned models, training logs, decoding logs, tensorboard and decoding results
are available at
<https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper>
### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model
This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall.

View File

@ -226,8 +226,8 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Prepare WenetSpeech"
if [ -e ../../wenetspeech/ASR/data/fbank/.preprocess_complete ]; then
cd data/fbank
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV_fixed.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L_fixed.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) .
@ -299,15 +299,6 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Compute KeSpeech fbank for test/dev"
./local/compute_fbank_kespeech_dev_test.py
if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz ]; then
pieces=$(find data/fbank/kespeech/train_phase1_split_${num_splits} -name "kespeech-asr_cuts_train_phase1.*.jsonl.gz")
lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
fi
if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz ]; then
pieces=$(find data/fbank/kespeech/train_phase2_split_${num_splits} -name "kespeech-asr_cuts_train_phase2.*.jsonl.gz")
lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
fi
touch data/fbank/.kespeech.done
fi
fi

48
egs/multi_zh-hans/ASR/whisper/decode.py Normal file → Executable file
View File

@ -58,6 +58,7 @@ from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from zhconv import convert
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
@ -214,7 +215,7 @@ def get_parser():
"--model-name",
type=str,
default="large-v2",
choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
choices=["large-v2", "large-v3", "medium", "small", "tiny"],
help="""The model name to use.
""",
)
@ -226,6 +227,13 @@ def get_parser():
help="replace whisper encoder forward method to remove input length restriction",
)
parser.add_argument(
"--use-distill-whisper",
type=str2bool,
default=False,
help="Whether to use architecture of distill whisper.",
)
return parser
@ -289,7 +297,6 @@ def decode_one_batch(
print(hyps)
return {"beam-search": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
@ -307,6 +314,40 @@ def decode_dataset(
Returns:
Return a dict, whose key may be "beam-search".
"""
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
if normalize == "none":
return text
elif normalize == "m2met":
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("", "A")
text = text.replace("", "A")
text = text.replace("", "B")
text = text.replace("", "C")
text = text.replace("", "K")
text = text.replace("", "T")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
return text
results = []
num_cuts = 0
@ -331,6 +372,7 @@ def decode_dataset(
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_text = normalize_text_alimeeting(ref_text)
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
@ -430,6 +472,8 @@ def main():
if params.remove_whisper_encoder_input_length_restriction:
replace_whisper_encoder_forward()
if params.use_distill_whisper:
replace_whisper_decoder_forward()
model = whisper.load_model(params.model_name, "cpu")
if params.epoch > 0:
if params.avg > 1:

93
egs/multi_zh-hans/ASR/whisper/multi_dataset.py Normal file → Executable file
View File

@ -43,7 +43,7 @@ class MultiDataset:
- thchs_30_cuts_train.jsonl.gz
- kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
- kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
- wenetspeech/cuts_L.jsonl.gz
- wenetspeech/cuts_L_fixed.jsonl.gz
"""
self.fbank_dir = Path(fbank_dir)
@ -105,7 +105,7 @@ class MultiDataset:
# WeNetSpeech
logging.info("Loading WeNetSpeech in lazy mode")
wenetspeech_L_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_L.jsonl.gz"
self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz"
)
# KeSpeech
@ -124,10 +124,10 @@ class MultiDataset:
aishell_4_L_cuts,
aishell_4_M_cuts,
aishell_4_S_cuts,
alimeeting_cuts,
stcmds_cuts,
primewords_cuts,
magicdata_cuts,
alimeeting_cuts,
wenetspeech_L_cuts,
kespeech_1_cuts,
kespeech_2_cuts,
@ -138,10 +138,10 @@ class MultiDataset:
len(aishell_4_L_cuts),
len(aishell_4_M_cuts),
len(aishell_4_S_cuts),
len(alimeeting_cuts),
len(stcmds_cuts),
len(primewords_cuts),
len(magicdata_cuts),
len(alimeeting_cuts),
len(wenetspeech_L_cuts),
len(kespeech_1_cuts),
len(kespeech_2_cuts),
@ -151,55 +151,13 @@ class MultiDataset:
def dev_cuts(self) -> CutSet:
logging.info("About to get multidataset dev cuts")
# AISHELL
logging.info("Loading Aishell DEV set in lazy mode")
aishell_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
)
# AISHELL-2
logging.info("Loading Aishell-2 DEV set in lazy mode")
aishell2_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
)
# Ali-Meeting
logging.info("Loading Ali-Meeting DEV set in lazy mode")
alimeeting_dev_cuts = load_manifest_lazy(
self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
)
# MagicData
logging.info("Loading MagicData DEV set in lazy mode")
magicdata_dev_cuts = load_manifest_lazy(
self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
)
# KeSpeech
logging.info("Loading KeSpeech DEV set in lazy mode")
kespeech_dev_phase1_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
)
kespeech_dev_phase2_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
)
# WeNetSpeech
logging.info("Loading WeNetSpeech DEV set in lazy mode")
wenetspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
)
return wenetspeech_dev_cuts
# return [
# aishell_dev_cuts,
# aishell2_dev_cuts,
# alimeeting_dev_cuts,
# magicdata_dev_cuts,
# kespeech_dev_phase1_cuts,
# kespeech_dev_phase2_cuts,
# wenetspeech_dev_cuts,
# ]
def test_cuts(self) -> Dict[str, CutSet]:
logging.info("About to get multidataset test cuts")
@ -267,30 +225,23 @@ class MultiDataset:
self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz"
)
wenetspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
)
return {
"aishell-2_test": aishell2_test_cuts,
"aishell-4": aishell4_test_cuts,
"magicdata_test": magicdata_test_cuts,
"kespeech-asr_test": kespeech_test_cuts,
}
# return {
# "alimeeting_test": alimeeting_test_cuts,
# "alimeeting_eval": alimeeting_eval_cuts,
# "aishell_test": aishell_test_cuts,
# "aishell_dev": aishell_dev_cuts,
# "aishell-2_test": aishell2_test_cuts,
# "aishell-2_dev": aishell2_dev_cuts,
# "aishell-4": aishell4_test_cuts,
# "magicdata_test": magicdata_test_cuts,
# "magicdata_dev": magicdata_dev_cuts,
# "kespeech-asr_test": kespeech_test_cuts,
# "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
# "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
# "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
# "wenetspeech-net_test": wenetspeech_test_net_cuts,
# "wenetspeech_dev": wenetspeech_dev_cuts,
# }
"wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
# "aishell_test": aishell_test_cuts,
# "aishell_dev": aishell_dev_cuts,
# "ali-meeting_test": alimeeting_test_cuts,
# "ali-meeting_eval": alimeeting_eval_cuts,
# "aishell-4_test": aishell4_test_cuts,
# "aishell-2_test": aishell2_test_cuts,
# "aishell-2_dev": aishell2_dev_cuts,
# "magicdata_test": magicdata_test_cuts,
# "magicdata_dev": magicdata_dev_cuts,
# "kespeech-asr_test": kespeech_test_cuts,
# "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
# "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
# "wenetspeech-net_test": wenetspeech_test_net_cuts,
# "wenetspeech_dev": wenetspeech_dev_cuts,
}

51
egs/multi_zh-hans/ASR/whisper/train.py Normal file → Executable file
View File

@ -66,6 +66,7 @@ from torch.nn.functional import pad as pad_tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
@ -146,7 +147,7 @@ def get_parser():
"--model-name",
type=str,
default="large-v2",
choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
choices=["large-v2", "large-v3", "medium", "small", "tiny"],
help="""The model name to use.
""",
)
@ -232,6 +233,13 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--use-distill-whisper",
type=str2bool,
default=False,
help="Whether to use architecture of distill whisper.",
)
parser = deepspeed.add_config_arguments(parser)
return parser
@ -441,6 +449,41 @@ def compute_loss(
padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
return torch.stack([tensor for tensor in padded_tensors], dim=0)
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
if normalize == "none":
return text
elif normalize == "m2met":
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("", "A")
text = text.replace("", "A")
text = text.replace("", "B")
text = text.replace("", "C")
text = text.replace("", "K")
text = text.replace("", "T")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
return text
max_frames = params.max_duration * 1000 // params.frame_shift_ms
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
@ -459,7 +502,7 @@ def compute_loss(
texts = batch["supervisions"]["text"]
# remove spaces in texts
texts = [text.replace(" ", "") for text in texts]
texts = [normalize_text_alimeeting(text) for text in texts]
text_tokens_list = [
list(tokenizer.sot_sequence_including_notimestamps)
@ -759,6 +802,8 @@ def run(rank, world_size, args):
logging.info("About to create model")
replace_whisper_encoder_forward()
if params.use_distill_whisper:
replace_whisper_decoder_forward()
model = whisper.load_model(params.model_name, "cpu")
del model.alignment_heads
@ -824,7 +869,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
512
2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -0,0 +1,47 @@
import torch
import torch.nn.functional as F
import whisper
from torch import Tensor
from torch import nn
from typing import Dict, Iterable, Optional
from whisper.model import ResidualAttentionBlock, LayerNorm
import numpy as np
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
"""
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
the text tokens
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
the encoded audio features to be attended on
"""
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = (
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
x = (
x
+ self.positional_embedding[offset : offset + x.shape[1]]
)
x = x.to(xa.dtype)
# for block in self.blocks:
# x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
# use architecture from the distill whisper model
# see https://github.com/huggingface/distil-whisper
x = self.blocks[0](x, xa, mask=self.mask, kv_cache=kv_cache)
x = self.blocks[-1](x, xa, mask=self.mask, kv_cache=kv_cache)
x = self.ln(x)
logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
return logits
def replace_whisper_decoder_forward():
"""
This function monkey patches the forward method of the whisper encoder.
To be called before the model is loaded, it changes whisper to process audio with any length < 30s.
"""
whisper.model.TextDecoder.forward = forward

View File

@ -2,41 +2,70 @@
### SpeechIO Test Set Decoding Results
##### Decoding results using pretrained [multi-hans-zh zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24), [whipser-large-v2](https://github.com/openai/whisper/blob/main/whisper/__init__.py#L27), [whisper-large-v2-wenetspeech-ft](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper).
| | zipformer_transducer | zipformer_transducer_blank_penalty_2 | whisper_large_v2 | whisper_large_v2_wenetspeech | whisper_large_v2_wenetspeech_zipformer_fusion |
|------------------------|----------------------|--------------------------------------|------------------|------------------------------|-----------------------------------------------|
| SPEECHIO_ASR_ZH00000 | 10.04 | 8.04 | 11.4 | 9.88 | 7.78 |
| SPEECHIO_ASR_ZH00001 | 1.67 | 1.51 | 2.49 | 1.57 | 1.38 |
| SPEECHIO_ASR_ZH00002 | 5.89 | 5.27 | 7.89 | 5.65 | 4.99 |
| SPEECHIO_ASR_ZH00003 | 2.66 | 2.79 | 5.94 | 2.27 | 2.33 |
| SPEECHIO_ASR_ZH00004 | 3.6 | 3.34 | 4.57 | 3.62 | 3.26 |
| SPEECHIO_ASR_ZH00005 | 7.54 | 5.81 | 8.39 | 7.26 | 5.43 |
| SPEECHIO_ASR_ZH00006 | 15.59 | 13.34 | 19.07 | 13.64 | 11.96 |
| SPEECHIO_ASR_ZH00007 | 15.9 | 15.05 | 16.7 | 14.06 | 13.73 |
| SPEECHIO_ASR_ZH00008 | 11.07 | 9.68 | 14.69 | 10.34 | 8.87 |
| SPEECHIO_ASR_ZH00009 | 7.38 | 6.23 | 8.32 | 6.74 | 5.96 |
| SPEECHIO_ASR_ZH00010 | 9.19 | 7.33 | 11.2 | 8.85 | 6.97 |
| SPEECHIO_ASR_ZH00011 | 4.16 | 3.84 | 54.56 | 4.09 | 3.72 |
| SPEECHIO_ASR_ZH00012 | 7.61 | 6.58 | 10.53 | 8.35 | 6.27 |
| SPEECHIO_ASR_ZH00013 | 8.72 | 7.66 | 9.32 | 7.26 | 6.7 |
| SPEECHIO_ASR_ZH00014 | 9.69 | 8.71 | 9.03 | 7.03 | 6.59 |
| SPEECHIO_ASR_ZH00015 | 11.94 | 11.37 | 16.58 | 12.02 | 11.11 |
| SPEECHIO_ASR_ZH00016 | 9.79 | 8.79 | 14.1 | 10.19 | 8.15 |
| SPEECHIO_ASR_ZH00017 | 8 | 6.72 | 9.04 | 8.9 | 6.44 |
| SPEECHIO_ASR_ZH00018 | 5.42 | 5.02 | 6.06 | 4.86 | 4.4 |
| SPEECHIO_ASR_ZH00019 | 11.26 | 9.06 | 14.8 | 9.83 | 8.22 |
| SPEECHIO_ASR_ZH00020 | 4.37 | 4.23 | 5.97 | 4.23 | 4.13 |
| SPEECHIO_ASR_ZH00021 | 7.81 | 6.34 | 8.53 | 7.08 | 5.88 |
| SPEECHIO_ASR_ZH00022 | 9.11 | 8.54 | 9.7 | 8.97 | 8.02 |
| SPEECHIO_ASR_ZH00023 | 9.98 | 8.98 | 6.31 | 9.44 | 8.57 |
| SPEECHIO_ASR_ZH00024 | 16.15 | 12.95 | 20.54 | 15.92 | 12.28 |
| SPEECHIO_ASR_ZH00025 | 10.38 | 9.82 | 11.4 | 10.26 | 9.27 |
| SPEECHIO_ASR_ZH00026 | 5.69 | 5.63 | 9.09 | 5.95 | 5.51 |
| Average WER (001-026) | 8.48 | 7.48 | 12.11 | 8.01 | 6.93 |
#### **Unlocked** SpeechIO test sets (ZH00001 ~ ZH00026)
| Rank 排名 | Model 模型 | CER 字错误率 | Date 时间 |
| --- | --- | --- | --- |
| 1 | ximalaya_api_zh | 1.72% | 2023.12 |
| 2 | aliyun_ftasr_api_zh | 1.85% | 2023.12 |
| 3 | microsoft_batch_zh | 2.40% | 2023.12 |
| 4 | bilibili_api_zh | 2.90% | 2023.09 |
| 5 | tencent_api_zh | 3.18% | 2023.12 |
| 6 | iflytek_lfasr_api_zh | 3.32% | 2023.12 |
| 7 | aispeech_api_zh | 3.62% | 2023.12 |
| 8 | **whisper-large-ft-v1** | **4.45%** | 2024.04 |
| 9 | **zipformer (70Mb)** | **6.17%** | 2023.10 |
| 10 | **whisper-large-ft-v0** | **6.34%** | 2023.03 |
| 11 | baidu_pro_api_zh | 7.29% | 2023.12 |
Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leaderboard). All results used the default [normalize method.](https://github.com/SpeechColab/Leaderboard/blob/master/utils/benchmark.sh#L67)
<details><summary> Detail all models </summary><p>
| Model | Training Set | Note |
|----------------------------------------------------------------------------------------------------------|---------------|-----------------------------------------------------|
|[zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24)| multi-hans-zh | decoding with transducer head and blank penalty 2.0 |
|[whisper-large-ft-v0](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper)| wenetspeech | |
|[whisper-large-ft-v1](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper)|wenetspeech(updated), other multi-hans-zh exclude datatang 200h|[wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py)|
</details>
<details><summary> Detail all results (字错误率 CER %) </summary><p>
| Test Set ID | 测试场景&内容领域 | bilibili_api_zh (2023.09) | whisper-large-ft-v0 | whisper-large-ft-v1 | zipformer |
|----------------------|-------------------------------|-----------------|---------|-----------|-----------|
| Avg (01-26) | | 2.9 | 6.34 | 4.45 | 6.17 |
| SPEECHIO_ASR_ZH00001 | 新闻联播 | 0.54 | 1.42 | 1.11 | 1.37 |
| SPEECHIO_ASR_ZH00002 | 访谈 鲁豫有约 | 2.78 | 4.76 | 3.25 | 4.67 |
| SPEECHIO_ASR_ZH00003 | 电视节目 天下足球 | 0.81 | 2.17 | 1.64 | 2.71 |
| SPEECHIO_ASR_ZH00004 | 场馆演讲 罗振宇跨年 | 1.48 | 2.53 | 1.91 | 2.54 |
| SPEECHIO_ASR_ZH00005 | 在线教育 李永乐 科普 | 1.47 | 4.27 | 1.98 | 3.12 |
| SPEECHIO_ASR_ZH00006 | 直播 王者荣耀 张大仙&骚白 | 5.85 | 12.55 | 9.28 | 12.86 |
| SPEECHIO_ASR_ZH00007 | 直播 带货 李佳琪&薇娅 | 6.19 | 13.38 | 11.14 | 14.58 |
| SPEECHIO_ASR_ZH00008 | 线下培训 老罗语录 | 3.68 | 9.56 | 7.47 | 9.05 |
| SPEECHIO_ASR_ZH00009 | 播客 故事FM | 3.18 | 5.66 | 3.78 | 5.4 |
| SPEECHIO_ASR_ZH00010 | 播客 创业内幕 | 3.51 | 7.84 | 4.32 | 6.4 |
| SPEECHIO_ASR_ZH00011 | 在线教育 罗翔 刑法法考 | 1.77 | 3.22 | 2.49 | 3.12 |
| SPEECHIO_ASR_ZH00012 | 在线教育 张雪峰 考研 | 2.11 | 5.98 | 3.09 | 4.41 |
| SPEECHIO_ASR_ZH00013 | 短视频 影剪 谷阿莫&牛叔说电影 | 2.97 | 5.91 | 3.78 | 6.56 |
| SPEECHIO_ASR_ZH00014 | 短视频 美式&烹饪 | 3.56 | 6.03 | 5.28 | 8.14 |
| SPEECHIO_ASR_ZH00015 | 评书 单田芳 白眉大侠 | 4.72 | 8.77 | 7.97 | 9.1 |
| SPEECHIO_ASR_ZH00016 | 相声 德云社专场 | 3.01 | 5.24 | 4.41 | 5.59 |
| SPEECHIO_ASR_ZH00017 | 脱口秀 吐槽大会 | 2.93 | 7.05 | 3.27 | 5.17 |
| SPEECHIO_ASR_ZH00018 | 少儿卡通 小猪佩奇&熊出没 | 1.98 | 3.53 | 3.21 | 4.15 |
| SPEECHIO_ASR_ZH00019 | 体育赛事解说 NBA比赛 | 2.32 | 6.89 | 4.24 | 6.66 |
| SPEECHIO_ASR_ZH00020 | 纪录片 篮球人物 | 1.51 | 4.16 | 2.96 | 4.2 |
| SPEECHIO_ASR_ZH00021 | 短视频 汽车之家 汽车评测 | 1.75 | 4.77 | 2.77 | 4.17 |
| SPEECHIO_ASR_ZH00022 | 短视频 小艾大叔 豪宅带看 | 3.29 | 6.35 | 5.66 | 6.72 |
| SPEECHIO_ASR_ZH00023 | 短视频 开箱视频 Zeal&无聊开箱 | 2.18 | 8.99 | 4.45 | 7.94 |
| SPEECHIO_ASR_ZH00024 | 短视频 付老师 农业种植 | 4.80 | 10.81 | 6.25 | 8.64 |
| SPEECHIO_ASR_ZH00025 | 线下课堂 石国鹏 古希腊哲学 | 3.32 | 8.41 | 5.8 | 8.54 |
| SPEECHIO_ASR_ZH00026 | 广播电台节目 张震鬼故事 | 3.70 | 4.52 | 4.11 | 4.67 |
</details>
Command for decoding using fine-tuned whisper:
```bash
@ -76,16 +105,6 @@ mv words.txt ./data/lang_bpe_2000/
--manifest-dir data/fbank_kaldi \
--decoding-method greedy_search
```
Command for fusion the above decoding results from whisper and zipformer:
```bash
python local/whisper_zipformer_fusion.py \
--whisper-log-dir ./whisper/exp_large_v2_wenetspeech \
--zipformer-log-dir ./zipformer/exp_pretrain/greedy_search \
--output-log-dir ./results_fusion
```
See why the fusion helps [here](./local/whisper_zipformer_fusion.py).
SpeechIO fbank features, decoding scripts, logs, and decoding results
are available at

View File

@ -21,7 +21,7 @@ Since whisper model is more likely to make deletion errors and zipformer model i
we trust whisper model when it makes substitution and insertion errors and trust zipformer model when it makes deletion errors.
Usage:
python whisper_zipformer_fusion.py --whisper-log-dir ./whisper_decoding_log_dir --zipformer-log-dir ./zipformer_decoding_log_dir --output-log-dir ./results_fusion
python whisper_zipformer_fusion.py --model-log-dir ./whisper_decoding_log_dir --output-log-dir ./results_norm
"""
import argparse
@ -31,38 +31,43 @@ from typing import Dict, List, Optional, Tuple
import kaldialign
from icefall.utils import store_transcripts, write_error_stats
from speechio_norm import TextNorm
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--whisper-log-dir",
"--model-log-dir",
type=str,
default="./recogs_whisper",
help="The directory to store the whisper logs: e.g. recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch--1-avg-1.txt",
)
parser.add_argument(
"--zipformer-log-dir",
type=str,
default="./recogs_zipformer",
help="The directory to store the zipformer logs",
)
parser.add_argument(
"--output-log-dir",
type=str,
default="./results_fusion",
help="The directory to store the fusion logs",
default="./results_whisper_norm",
help="The directory to store the normalized whisper logs",
)
return parser
def save_results(
def save_results_with_speechio_text_norm(
res_dir: Path,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
normalizer = TextNorm()
# normlize items in results_dict
for key, results in results_dict.items():
results_norm = []
for item in results:
wav_name, ref, hyp = item
ref = normalizer(ref)
hyp = normalizer(hyp)
results_norm.append((wav_name, ref, hyp))
results_dict[key] = results_norm
test_set_wers = dict()
suffix = "epoch-999-avg-1"
@ -120,11 +125,9 @@ def extract_hyp_ref_wavname(filename):
return hyps, refs, wav_name
def get_pair_filenames(
def get_filenames(
whisper_log_dir,
zipformer_log_dir,
whisper_suffix="beam-search-epoch-999-avg-1",
zipformer_suffix="greedy_search_blank_penalty_2.0-epoch-999-avg-1-context-2-max-sym-per-frame-1-blank-penalty-2.0",
):
results = []
start_index, end_index = 0, 26
@ -134,80 +137,23 @@ def get_pair_filenames(
dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")
for partition in dataset_parts:
whisper_filename = f"{whisper_log_dir}/recogs-{partition}-{whisper_suffix}.txt"
zipformer_filename = (
f"{zipformer_log_dir}/recogs-{partition}-{zipformer_suffix}.txt"
)
results.append((whisper_filename, zipformer_filename))
results.append(whisper_filename)
return results
def fusion_hyps_trust_substituion_insertion(
hyps_whisper, hyps_zipformer, refs, ERR="*"
):
"""
alignment example:
[('', ''), ('', '*'), ('', ''), ('', ''), ('', ''), ('', ''), ('*', '')]
left is whisper, right is zipformer
for whisper substitution, use left
for whisper insertion, use left
for whisper deletion, use right
"""
hyps_fusion = []
for hyp_w, hyp_z, ref in zip(hyps_whisper, hyps_zipformer, refs):
ali = kaldialign.align(hyp_w, hyp_z, ERR)
hyp_f = ""
for a in ali:
if a[0] == ERR:
hyp_f += a[1]
else:
hyp_f += a[0]
hyps_fusion.append(hyp_f)
return hyps_fusion
def fusion_hyps_trust_substituion(hyps_whisper, hyps_zipformer, refs, ERR="*"):
"""
alignment example:
[('', ''), ('', '*'), ('', ''), ('', ''), ('', ''), ('', ''), ('*', '')]
left is whisper, right is zipformer
for whisper substitution, use left
for whisper insertion, use right
for whisper deletion, use right
"""
hyps_fusion = []
for hyp_w, hyp_z, ref in zip(hyps_whisper, hyps_zipformer, refs):
ali = kaldialign.align(hyp_w, hyp_z, ERR)
hyp_f = ""
for a in ali:
if a[0] == ERR:
hyp_f += a[1]
elif a[1] == ERR:
pass
else:
hyp_f += a[0]
hyps_fusion.append(hyp_f)
return hyps_fusion
def main():
parser = get_parser()
args = parser.parse_args()
# mkdir output_log_dir
Path(args.output_log_dir).mkdir(parents=True, exist_ok=True)
pair_logs = get_pair_filenames(args.whisper_log_dir, args.zipformer_log_dir)
for pair in pair_logs:
hyps_whisper, refs, wav_name = extract_hyp_ref_wavname(pair[0])
hyps_zipformer, _, _ = extract_hyp_ref_wavname(pair[1])
filenames = get_filenames(args.model_log_dir)
for filename in filenames:
hyps, refs, wav_name = extract_hyp_ref_wavname(filename)
partition_name = filename.split("/")[-1].split("-")[1]
hyps_fusion = fusion_hyps_trust_substituion_insertion(
hyps_whisper, hyps_zipformer, refs
)
partition_name = pair[0].split("/")[-1].split("-")[1]
save_results(
save_results_with_speechio_text_norm(
Path(args.output_log_dir),
partition_name,
{"fusion": list(zip(wav_name, refs, hyps_fusion))},
{"norm": list(zip(wav_name, refs, hyps))},
)
print(f"Processed {partition_name}")

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,114 @@
#!/usr/bin/env python3
# Copyright 2024 author: Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import argparse
from lhotse import CutSet, load_manifest_lazy
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--fixed-transcript-path",
type=str,
default="data/fbank/text.fix",
help="""
See https://github.com/wenet-e2e/WenetSpeech/discussions/54
wget -nc https://huggingface.co/datasets/yuekai/wenetspeech_paraformer_fixed_transcript/resolve/main/text.fix
""",
)
parser.add_argument(
"--manifest-dir",
type=str,
default="data/fbank/",
help="Directory to store the manifest files",
)
parser.add_argument(
"--training-subset",
type=str,
default="L",
help="The training subset for wenetspeech.",
)
return parser
def load_fixed_text(fixed_text_path):
"""
fixed text format
X0000016287_92761015_S00001 我是徐涛
X0000016287_92761015_S00002 狄更斯的PICK WEEK PAPERS斯
load into a dict
"""
fixed_text_dict = {}
with open(fixed_text_path, 'r') as f:
for line in f:
cut_id, text = line.strip().split(' ', 1)
fixed_text_dict[cut_id] = text
return fixed_text_dict
def fix_manifest(manifest, fixed_text_dict, fixed_manifest_path):
with CutSet.open_writer(fixed_manifest_path) as manifest_writer:
fixed_item = 0
for i, cut in enumerate(manifest):
if i % 10000 == 0:
logging.info(f'Processing cut {i}, fixed {fixed_item}')
cut_id_orgin = cut.id
if cut_id_orgin.endswith('_sp0.9'):
cut_id = cut_id_orgin[:-6]
elif cut_id_orgin.endswith('_sp1.1'):
cut_id = cut_id_orgin[:-6]
else:
cut_id = cut_id_orgin
if cut_id in fixed_text_dict:
assert len(cut.supervisions) == 1, f'cut {cut_id} has {len(cut.supervisions)} supervisions'
if cut.supervisions[0].text != fixed_text_dict[cut_id]:
logging.info(f'Fixed text for cut {cut_id_orgin} from {cut.supervisions[0].text} to {fixed_text_dict[cut_id]}')
cut.supervisions[0].text = fixed_text_dict[cut_id]
fixed_item += 1
manifest_writer.write(cut)
def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
fixed_text_path = args.manifest_dir + 'text.fix'
fixed_text_dict = load_fixed_text(fixed_text_path)
logging.info(f'Loaded {len(fixed_text_dict)} fixed texts')
dev_manifest_path = args.manifest_dir + 'cuts_DEV.jsonl.gz'
fixed_dev_manifest_path = args.manifest_dir + 'cuts_DEV_fixed.jsonl.gz'
logging.info(f'Loading dev manifest from {dev_manifest_path}')
cuts_dev_manifest = load_manifest_lazy(dev_manifest_path)
fix_manifest(cuts_dev_manifest, fixed_text_dict, fixed_dev_manifest_path)
logging.info(f'Fixed dev manifest saved to {fixed_dev_manifest_path}')
manifest_path = args.manifest_dir + f'cuts_{args.training_subset}.jsonl.gz'
fixed_manifest_path = args.manifest_dir + f'cuts_{args.training_subset}_fixed.jsonl.gz'
logging.info(f'Loading manifest from {manifest_path}')
cuts_manifest = load_manifest_lazy(manifest_path)
fix_manifest(cuts_manifest, fixed_text_dict, fixed_manifest_path)
logging.info(f'Fixed training manifest saved to {fixed_manifest_path}')
if __name__ == "__main__":
main()

View File

@ -416,3 +416,12 @@ if [ $stage -le 22 ] && [ $stop_stage -ge 22 ]; then
python ./local/compile_lg.py --lang-dir $lang_dir
done
fi
if [ $stage -le 23 ] && [ $stop_stage -ge 23 ]; then
log "Stage 23: Modify transcript according to fixed results"
# See https://github.com/wenet-e2e/WenetSpeech/discussions/54
wget -nc https://huggingface.co/datasets/yuekai/wenetspeech_paraformer_fixed_transcript/resolve/main/text.fix -O data/fbank/text.fix
python local/fix_manifest.py \
--fixed-transcript-path data/fbank/text.fix \
--training-subset L
fi

View File

@ -390,14 +390,14 @@ class WenetSpeechAsrDataModule:
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest_lazy(
self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz"
self.args.manifest_dir / f"cuts_{self.args.training_subset}_fixed.jsonl.gz"
)
return cuts_train
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV_fixed.jsonl.gz")
@lru_cache()
def test_net_cuts(self) -> List[CutSet]:

View File

@ -44,6 +44,7 @@ from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import os
import deepspeed
import k2
import optim
@ -145,7 +146,7 @@ def get_parser():
"--model-name",
type=str,
default="large-v2",
choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
choices=["large-v2", "large-v3", "medium", "small", "tiny"],
help="""The model name to use.
""",
)
@ -616,7 +617,9 @@ def train_one_epoch(
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
)
os.system(
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
)
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
@ -803,7 +806,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
512
2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
@ -893,6 +896,7 @@ def run(rank, world_size, args):
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
tag=f"epoch-{params.cur_epoch}",
)
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
else:
save_checkpoint(
params=params,