update speechio whisper ft results
This commit is contained in: parent df36f93bd8 → commit b970ba569a
@@ -1,5 +1,48 @@
 ## Results
 
+### Multi Chinese datasets (without datatang 200h) finetuning results on Whisper-large-v2
+
+#### Whisper
+
+[./whisper](./whisper)
+
+Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search.
+
+| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
+|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
+| Split | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
+| Greedy Search | 23.45 | 25.42 | 0.78 | 0.83 | 2.75 | 2.93 | 17.11 | 2.68 | 2.33 | 4.97 | 2.02 | 6.34 | 5.06 | 8.38 | 6.94 |
+
+Command for training is:
+```bash
+pip install -r whisper/requirements.txt
+
+# We updated the label of wenetspeech to remove OCR deletion errors, see https://github.com/wenet-e2e/WenetSpeech/discussions/54
+
+torchrun --nproc-per-node 8 ./whisper/train.py \
+  --max-duration 200 \
+  --exp-dir whisper/exp_large_v2 \
+  --model-name large-v2 \
+  --deepspeed \
+  --deepspeed_config ./whisper/ds_config_zero1.json
+```
+
+Command for decoding using fine-tuned models:
+```bash
+git lfs install
+git clone https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper
+ln -s icefall_asr_multi-hans-zh_whisper/v1/epoch-2-avg4.pt whisper/exp_large_v2/epoch-999.pt
+
+python3 ./whisper/decode.py \
+  --exp-dir whisper/exp_large_v2 \
+  --model-name large-v2 \
+  --epoch 999 --avg 1 \
+  --beam-size 10 --max-duration 50
+```
+
+Fine-tuned models, training logs, decoding logs, tensorboard and decoding results
+are available at
+<https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper>
+
 ### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model
 
 This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall.
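For readers new to the metric: CER is the character-level edit-distance rate between reference and hypothesis. A toy illustration using kaldialign, the alignment library this repo already uses for scoring — a sketch only, not the recipe's actual scoring code (which lives in icefall.utils):

```python
import kaldialign

ref = list("今天天气不错")
hyp = list("今天天汽不错")  # one substitution: 气 -> 汽
# "*" marks insertions/deletions in the alignment; count all mismatched pairs
errs = sum(1 for r, h in kaldialign.align(ref, hyp, "*") if r != h)
print(errs / len(ref))  # 1/6 ≈ 0.167, i.e. a CER of 16.7% on this toy pair
```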
@@ -226,8 +226,8 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
   log "Stage 11: Prepare WenetSpeech"
   if [ -e ../../wenetspeech/ASR/data/fbank/.preprocess_complete ]; then
     cd data/fbank
-    ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV.jsonl.gz) .
-    ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L.jsonl.gz) .
+    ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV_fixed.jsonl.gz) .
+    ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L_fixed.jsonl.gz) .
     ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) .
     ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) .
@@ -299,15 +299,6 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
     log "Compute KeSpeech fbank for test/dev"
     ./local/compute_fbank_kespeech_dev_test.py
-
-    if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz ]; then
-      pieces=$(find data/fbank/kespeech/train_phase1_split_${num_splits} -name "kespeech-asr_cuts_train_phase1.*.jsonl.gz")
-      lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
-    fi
-    if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz ]; then
-      pieces=$(find data/fbank/kespeech/train_phase2_split_${num_splits} -name "kespeech-asr_cuts_train_phase2.*.jsonl.gz")
-      lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
-    fi
-
     touch data/fbank/.kespeech.done
   fi
 fi
egs/multi_zh-hans/ASR/whisper/decode.py (48 changed lines; normal file → executable file)
@@ -58,6 +58,7 @@ from multi_dataset import MultiDataset
 from tn.chinese.normalizer import Normalizer
 from whisper.normalizers import BasicTextNormalizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
 from zhconv import convert
 
 from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
@@ -214,7 +215,7 @@ def get_parser():
         "--model-name",
         type=str,
         default="large-v2",
-        choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
+        choices=["large-v2", "large-v3", "medium", "small", "tiny"],
         help="""The model name to use.
         """,
     )
@@ -226,6 +227,13 @@ def get_parser():
         help="replace whisper encoder forward method to remove input length restriction",
     )
 
+    parser.add_argument(
+        "--use-distill-whisper",
+        type=str2bool,
+        default=False,
+        help="Whether to use architecture of distill whisper.",
+    )
+
     return parser
 
 
@@ -289,7 +297,6 @@ def decode_one_batch(
     print(hyps)
     return {"beam-search": hyps}
 
-
 def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
@@ -307,6 +314,40 @@ def decode_dataset(
     Returns:
       Return a dict, whose key may be "beam-search".
     """
+
+    def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
+        """
+        Text normalization similar to M2MeT challenge baseline.
+        See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
+        """
+        if normalize == "none":
+            return text
+        elif normalize == "m2met":
+            import re
+
+            text = text.replace(" ", "")
+            text = text.replace("<sil>", "")
+            text = text.replace("<%>", "")
+            text = text.replace("<->", "")
+            text = text.replace("<$>", "")
+            text = text.replace("<#>", "")
+            text = text.replace("<_>", "")
+            text = text.replace("<space>", "")
+            text = text.replace("`", "")
+            text = text.replace("&", "")
+            text = text.replace(",", "")
+            if re.search("[a-zA-Z]", text):
+                text = text.upper()
+            text = text.replace("Ａ", "A")
+            text = text.replace("ａ", "A")
+            text = text.replace("ｂ", "B")
+            text = text.replace("ｃ", "C")
+            text = text.replace("ｋ", "K")
+            text = text.replace("ｔ", "T")
+            text = text.replace("，", "")
+            text = text.replace("丶", "")
+            text = text.replace("。", "")
+            text = text.replace("、", "")
+            text = text.replace("？", "")
+            return text
+
     results = []
 
     num_cuts = 0
@@ -331,6 +372,7 @@ def decode_dataset(
         this_batch = []
         assert len(hyps) == len(texts)
         for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+            ref_text = normalize_text_alimeeting(ref_text)
             ref_words = ref_text.split()
             this_batch.append((cut_id, ref_words, hyp_words))
 
@@ -430,6 +472,8 @@ def main():
 
     if params.remove_whisper_encoder_input_length_restriction:
         replace_whisper_encoder_forward()
+    if params.use_distill_whisper:
+        replace_whisper_decoder_forward()
     model = whisper.load_model(params.model_name, "cpu")
     if params.epoch > 0:
         if params.avg > 1:
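To make the effect of the new `normalize_text_alimeeting` helper concrete, here is a condensed, self-contained sketch (it folds the replace chain into a loop and omits the full-width-letter mapping; the real helper is nested inside `decode_dataset` above):

```python
import re

def normalize_text_alimeeting_sketch(text: str) -> str:
    # strip spaces, AliMeeting noise tags, and punctuation, then upper-case Latin letters
    for junk in (" ", "<sil>", "<%>", "<->", "<$>", "<#>", "<_>", "<space>",
                 "`", "&", ",", "，", "丶", "。", "、", "？"):
        text = text.replace(junk, "")
    if re.search("[a-zA-Z]", text):
        text = text.upper()
    return text

assert normalize_text_alimeeting_sketch("今天 天气 <sil> 不错，ok") == "今天天气不错OK"
```

Applying the same normalization to both references (above) and training labels (in train.py below) keeps the CER numbers consistent across the mixed datasets.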
egs/multi_zh-hans/ASR/whisper/multi_dataset.py (91 changed lines; normal file → executable file)
@@ -43,7 +43,7 @@ class MultiDataset:
         - thchs_30_cuts_train.jsonl.gz
         - kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
         - kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
-        - wenetspeech/cuts_L.jsonl.gz
+        - wenetspeech/cuts_L_fixed.jsonl.gz
         """
         self.fbank_dir = Path(fbank_dir)
 
@@ -105,7 +105,7 @@ class MultiDataset:
         # WeNetSpeech
         logging.info("Loading WeNetSpeech in lazy mode")
         wenetspeech_L_cuts = load_manifest_lazy(
-            self.fbank_dir / "wenetspeech" / "cuts_L.jsonl.gz"
+            self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz"
         )
 
         # KeSpeech
@@ -124,10 +124,10 @@ class MultiDataset:
             aishell_4_L_cuts,
             aishell_4_M_cuts,
             aishell_4_S_cuts,
+            alimeeting_cuts,
             stcmds_cuts,
             primewords_cuts,
             magicdata_cuts,
-            alimeeting_cuts,
             wenetspeech_L_cuts,
             kespeech_1_cuts,
             kespeech_2_cuts,
@@ -138,10 +138,10 @@ class MultiDataset:
             len(aishell_4_L_cuts),
             len(aishell_4_M_cuts),
             len(aishell_4_S_cuts),
+            len(alimeeting_cuts),
             len(stcmds_cuts),
             len(primewords_cuts),
             len(magicdata_cuts),
-            len(alimeeting_cuts),
             len(wenetspeech_L_cuts),
             len(kespeech_1_cuts),
             len(kespeech_2_cuts),
@@ -151,55 +151,13 @@ class MultiDataset:
     def dev_cuts(self) -> CutSet:
         logging.info("About to get multidataset dev cuts")
 
-        # AISHELL
-        logging.info("Loading Aishell DEV set in lazy mode")
-        aishell_dev_cuts = load_manifest_lazy(
-            self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
-        )
-
-        # AISHELL-2
-        logging.info("Loading Aishell-2 DEV set in lazy mode")
-        aishell2_dev_cuts = load_manifest_lazy(
-            self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
-        )
-
-        # Ali-Meeting
-        logging.info("Loading Ali-Meeting DEV set in lazy mode")
-        alimeeting_dev_cuts = load_manifest_lazy(
-            self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
-        )
-
-        # MagicData
-        logging.info("Loading MagicData DEV set in lazy mode")
-        magicdata_dev_cuts = load_manifest_lazy(
-            self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
-        )
-
-        # KeSpeech
-        logging.info("Loading KeSpeech DEV set in lazy mode")
-        kespeech_dev_phase1_cuts = load_manifest_lazy(
-            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
-        )
-        kespeech_dev_phase2_cuts = load_manifest_lazy(
-            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
-        )
-
         # WeNetSpeech
         logging.info("Loading WeNetSpeech DEV set in lazy mode")
         wenetspeech_dev_cuts = load_manifest_lazy(
-            self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
+            self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
         )
 
         return wenetspeech_dev_cuts
-        # return [
-        #     aishell_dev_cuts,
-        #     aishell2_dev_cuts,
-        #     alimeeting_dev_cuts,
-        #     magicdata_dev_cuts,
-        #     kespeech_dev_phase1_cuts,
-        #     kespeech_dev_phase2_cuts,
-        #     wenetspeech_dev_cuts,
-        # ]
 
     def test_cuts(self) -> Dict[str, CutSet]:
         logging.info("About to get multidataset test cuts")
@@ -267,30 +225,23 @@ class MultiDataset:
             self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz"
         )
         wenetspeech_dev_cuts = load_manifest_lazy(
-            self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
+            self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
         )
 
         return {
-            "aishell-2_test": aishell2_test_cuts,
-            "aishell-4": aishell4_test_cuts,
-            "magicdata_test": magicdata_test_cuts,
-            "kespeech-asr_test": kespeech_test_cuts,
+            "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
+            # "aishell_test": aishell_test_cuts,
+            # "aishell_dev": aishell_dev_cuts,
+            # "ali-meeting_test": alimeeting_test_cuts,
+            # "ali-meeting_eval": alimeeting_eval_cuts,
+            # "aishell-4_test": aishell4_test_cuts,
+            # "aishell-2_test": aishell2_test_cuts,
+            # "aishell-2_dev": aishell2_dev_cuts,
+            # "magicdata_test": magicdata_test_cuts,
+            # "magicdata_dev": magicdata_dev_cuts,
+            # "kespeech-asr_test": kespeech_test_cuts,
+            # "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
+            # "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
+            # "wenetspeech-net_test": wenetspeech_test_net_cuts,
+            # "wenetspeech_dev": wenetspeech_dev_cuts,
         }
 
-        # return {
-        #     "alimeeting_test": alimeeting_test_cuts,
-        #     "alimeeting_eval": alimeeting_eval_cuts,
-        #     "aishell_test": aishell_test_cuts,
-        #     "aishell_dev": aishell_dev_cuts,
-        #     "aishell-2_test": aishell2_test_cuts,
-        #     "aishell-2_dev": aishell2_dev_cuts,
-        #     "aishell-4": aishell4_test_cuts,
-        #     "magicdata_test": magicdata_test_cuts,
-        #     "magicdata_dev": magicdata_dev_cuts,
-        #     "kespeech-asr_test": kespeech_test_cuts,
-        #     "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
-        #     "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
-        #     "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
-        #     "wenetspeech-net_test": wenetspeech_test_net_cuts,
-        #     "wenetspeech_dev": wenetspeech_dev_cuts,
-        # }
egs/multi_zh-hans/ASR/whisper/train.py (51 changed lines; normal file → executable file)
@@ -66,6 +66,7 @@ from torch.nn.functional import pad as pad_tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
 
 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
@@ -146,7 +147,7 @@ def get_parser():
         "--model-name",
         type=str,
         default="large-v2",
-        choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
+        choices=["large-v2", "large-v3", "medium", "small", "tiny"],
         help="""The model name to use.
         """,
     )
@@ -232,6 +233,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--use-distill-whisper",
+        type=str2bool,
+        default=False,
+        help="Whether to use architecture of distill whisper.",
+    )
+
     parser = deepspeed.add_config_arguments(parser)
 
     return parser
@@ -441,6 +449,41 @@ def compute_loss(
             padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
         return torch.stack([tensor for tensor in padded_tensors], dim=0)
 
+    def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
+        """
+        Text normalization similar to M2MeT challenge baseline.
+        See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
+        """
+        if normalize == "none":
+            return text
+        elif normalize == "m2met":
+            import re
+
+            text = text.replace(" ", "")
+            text = text.replace("<sil>", "")
+            text = text.replace("<%>", "")
+            text = text.replace("<->", "")
+            text = text.replace("<$>", "")
+            text = text.replace("<#>", "")
+            text = text.replace("<_>", "")
+            text = text.replace("<space>", "")
+            text = text.replace("`", "")
+            text = text.replace("&", "")
+            text = text.replace(",", "")
+            if re.search("[a-zA-Z]", text):
+                text = text.upper()
+            text = text.replace("Ａ", "A")
+            text = text.replace("ａ", "A")
+            text = text.replace("ｂ", "B")
+            text = text.replace("ｃ", "C")
+            text = text.replace("ｋ", "K")
+            text = text.replace("ｔ", "T")
+            text = text.replace("，", "")
+            text = text.replace("丶", "")
+            text = text.replace("。", "")
+            text = text.replace("、", "")
+            text = text.replace("？", "")
+            return text
+
     max_frames = params.max_duration * 1000 // params.frame_shift_ms
     allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
     batch = filter_uneven_sized_batch(batch, allowed_max_frames)
@@ -459,7 +502,7 @@ def compute_loss(
 
     texts = batch["supervisions"]["text"]
     # remove spaces in texts
-    texts = [text.replace(" ", "") for text in texts]
+    texts = [normalize_text_alimeeting(text) for text in texts]
 
     text_tokens_list = [
         list(tokenizer.sot_sequence_including_notimestamps)
@@ -759,6 +802,8 @@ def run(rank, world_size, args):
     logging.info("About to create model")
 
     replace_whisper_encoder_forward()
+    if params.use_distill_whisper:
+        replace_whisper_decoder_forward()
     model = whisper.load_model(params.model_name, "cpu")
     del model.alignment_heads
 
@@ -824,7 +869,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            512
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
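A quick sanity check on the replacement value, assuming the option is a per-submodule byte budget as the inline comment suggests (in which case the old literal 512 contradicted the comment):

```python
assert 2**22 == 4 * 1024 * 1024  # 2**22 bytes is exactly 4 MiB, matching "4 megabytes"
```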
egs/multi_zh-hans/ASR/whisper/whisper_decoder_forward_monkey_patch.py (new file; path inferred from the import statements above)
@@ -0,0 +1,47 @@
+import torch
+import torch.nn.functional as F
+import whisper
+from torch import Tensor
+from torch import nn
+from typing import Dict, Iterable, Optional
+from whisper.model import ResidualAttentionBlock, LayerNorm
+import numpy as np
+
+
+def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
+    """
+    x : torch.LongTensor, shape = (batch_size, <= n_ctx)
+        the text tokens
+    xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
+        the encoded audio features to be attended on
+    """
+    offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+    x = (
+        self.token_embedding(x)
+        + self.positional_embedding[offset : offset + x.shape[-1]]
+    )
+    x = x.to(xa.dtype)
+
+    # for block in self.blocks:
+    #     x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+    # use architecture from the distill whisper model
+    # see https://github.com/huggingface/distil-whisper
+    x = self.blocks[0](x, xa, mask=self.mask, kv_cache=kv_cache)
+    x = self.blocks[-1](x, xa, mask=self.mask, kv_cache=kv_cache)
+
+    x = self.ln(x)
+    logits = (
+        x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+    ).float()
+
+    return logits
+
+
+def replace_whisper_decoder_forward():
+    """
+    This function monkey patches the forward method of the whisper decoder.
+    Call it before the model is loaded; the patched forward runs only the
+    first and last decoder blocks, following the distil-whisper architecture.
+    """
+    whisper.model.TextDecoder.forward = forward
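For reference, this is how the patch is applied in train.py and decode.py above: call it before whisper.load_model so the replaced TextDecoder.forward is already in place when the checkpoint weights are loaded. A minimal sketch:

```python
import whisper
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward

replace_whisper_decoder_forward()  # patch TextDecoder.forward first
model = whisper.load_model("large-v2", "cpu")  # decoding now uses blocks[0] and blocks[-1] only
```

Note that all decoder weights are still loaded; only the forward pass skips the middle blocks, which mirrors distil-whisper's shallow-decoder layout without re-exporting the checkpoint.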
@@ -2,41 +2,70 @@
 
 ### SpeechIO Test Set Decoding Results
 
-##### Decoding results using pretrained [multi-hans-zh zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24), [whipser-large-v2](https://github.com/openai/whisper/blob/main/whisper/__init__.py#L27), [whisper-large-v2-wenetspeech-ft](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper).
-
-| | zipformer_transducer | zipformer_transducer_blank_penalty_2 | whisper_large_v2 | whisper_large_v2_wenetspeech | whisper_large_v2_wenetspeech_zipformer_fusion |
-|------------------------|----------------------|--------------------------------------|------------------|------------------------------|-----------------------------------------------|
-| SPEECHIO_ASR_ZH00000 | 10.04 | 8.04 | 11.4 | 9.88 | 7.78 |
-| SPEECHIO_ASR_ZH00001 | 1.67 | 1.51 | 2.49 | 1.57 | 1.38 |
-| SPEECHIO_ASR_ZH00002 | 5.89 | 5.27 | 7.89 | 5.65 | 4.99 |
-| SPEECHIO_ASR_ZH00003 | 2.66 | 2.79 | 5.94 | 2.27 | 2.33 |
-| SPEECHIO_ASR_ZH00004 | 3.6 | 3.34 | 4.57 | 3.62 | 3.26 |
-| SPEECHIO_ASR_ZH00005 | 7.54 | 5.81 | 8.39 | 7.26 | 5.43 |
-| SPEECHIO_ASR_ZH00006 | 15.59 | 13.34 | 19.07 | 13.64 | 11.96 |
-| SPEECHIO_ASR_ZH00007 | 15.9 | 15.05 | 16.7 | 14.06 | 13.73 |
-| SPEECHIO_ASR_ZH00008 | 11.07 | 9.68 | 14.69 | 10.34 | 8.87 |
-| SPEECHIO_ASR_ZH00009 | 7.38 | 6.23 | 8.32 | 6.74 | 5.96 |
-| SPEECHIO_ASR_ZH00010 | 9.19 | 7.33 | 11.2 | 8.85 | 6.97 |
-| SPEECHIO_ASR_ZH00011 | 4.16 | 3.84 | 54.56 | 4.09 | 3.72 |
-| SPEECHIO_ASR_ZH00012 | 7.61 | 6.58 | 10.53 | 8.35 | 6.27 |
-| SPEECHIO_ASR_ZH00013 | 8.72 | 7.66 | 9.32 | 7.26 | 6.7 |
-| SPEECHIO_ASR_ZH00014 | 9.69 | 8.71 | 9.03 | 7.03 | 6.59 |
-| SPEECHIO_ASR_ZH00015 | 11.94 | 11.37 | 16.58 | 12.02 | 11.11 |
-| SPEECHIO_ASR_ZH00016 | 9.79 | 8.79 | 14.1 | 10.19 | 8.15 |
-| SPEECHIO_ASR_ZH00017 | 8 | 6.72 | 9.04 | 8.9 | 6.44 |
-| SPEECHIO_ASR_ZH00018 | 5.42 | 5.02 | 6.06 | 4.86 | 4.4 |
-| SPEECHIO_ASR_ZH00019 | 11.26 | 9.06 | 14.8 | 9.83 | 8.22 |
-| SPEECHIO_ASR_ZH00020 | 4.37 | 4.23 | 5.97 | 4.23 | 4.13 |
-| SPEECHIO_ASR_ZH00021 | 7.81 | 6.34 | 8.53 | 7.08 | 5.88 |
-| SPEECHIO_ASR_ZH00022 | 9.11 | 8.54 | 9.7 | 8.97 | 8.02 |
-| SPEECHIO_ASR_ZH00023 | 9.98 | 8.98 | 6.31 | 9.44 | 8.57 |
-| SPEECHIO_ASR_ZH00024 | 16.15 | 12.95 | 20.54 | 15.92 | 12.28 |
-| SPEECHIO_ASR_ZH00025 | 10.38 | 9.82 | 11.4 | 10.26 | 9.27 |
-| SPEECHIO_ASR_ZH00026 | 5.69 | 5.63 | 9.09 | 5.95 | 5.51 |
-| Average WER (001-026) | 8.48 | 7.48 | 12.11 | 8.01 | 6.93 |
-
+#### **Unlocked** SpeechIO test sets (ZH00001 ~ ZH00026)
+
+| Rank 排名 | Model 模型 | CER 字错误率 | Date 时间 |
+| --- | --- | --- | --- |
+| 1 | ximalaya_api_zh | 1.72% | 2023.12 |
+| 2 | aliyun_ftasr_api_zh | 1.85% | 2023.12 |
+| 3 | microsoft_batch_zh | 2.40% | 2023.12 |
+| 4 | bilibili_api_zh | 2.90% | 2023.09 |
+| 5 | tencent_api_zh | 3.18% | 2023.12 |
+| 6 | iflytek_lfasr_api_zh | 3.32% | 2023.12 |
+| 7 | aispeech_api_zh | 3.62% | 2023.12 |
+| 8 | **whisper-large-ft-v1** | **4.45%** | 2024.04 |
+| 9 | **zipformer (70Mb)** | **6.17%** | 2023.10 |
+| 10 | **whisper-large-ft-v0** | **6.34%** | 2023.03 |
+| 11 | baidu_pro_api_zh | 7.29% | 2023.12 |
+
+Note: The API results above are from [SPEECHIO](https://github.com/SpeechColab/Leaderboard). All results used the default [normalize method](https://github.com/SpeechColab/Leaderboard/blob/master/utils/benchmark.sh#L67).
+
+<details><summary> Detail all models </summary><p>
+
+| Model | Training Set | Note |
+|----------------------------------------------------------------------------------------------------------|---------------|-----------------------------------------------------|
+| [zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24) | multi-hans-zh | decoding with transducer head and blank penalty 2.0 |
+| [whisper-large-ft-v0](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper) | wenetspeech | |
+| [whisper-large-ft-v1](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper) | wenetspeech (updated), other multi-hans-zh excluding datatang 200h | [wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py) |
+
+</details>
+
+<details><summary> Detail all results (字错误率 CER %) </summary><p>
+
+| Test Set ID | Scenario & domain (测试场景&内容领域) | bilibili_api_zh (2023.09) | whisper-large-ft-v0 | whisper-large-ft-v1 | zipformer |
+|----------------------|-------------------------------|-----------------|---------|-----------|-----------|
+| Avg (01-26) | | 2.9 | 6.34 | 4.45 | 6.17 |
+| SPEECHIO_ASR_ZH00001 | TV news (新闻联播) | 0.54 | 1.42 | 1.11 | 1.37 |
+| SPEECHIO_ASR_ZH00002 | Interview: 鲁豫有约 | 2.78 | 4.76 | 3.25 | 4.67 |
+| SPEECHIO_ASR_ZH00003 | TV program: 天下足球 | 0.81 | 2.17 | 1.64 | 2.71 |
+| SPEECHIO_ASR_ZH00004 | Live speech: 罗振宇跨年 | 1.48 | 2.53 | 1.91 | 2.54 |
+| SPEECHIO_ASR_ZH00005 | Online education: 李永乐 popular science | 1.47 | 4.27 | 1.98 | 3.12 |
+| SPEECHIO_ASR_ZH00006 | Game live stream: 王者荣耀 张大仙&骚白 | 5.85 | 12.55 | 9.28 | 12.86 |
+| SPEECHIO_ASR_ZH00007 | Live-stream selling: 李佳琪&薇娅 | 6.19 | 13.38 | 11.14 | 14.58 |
+| SPEECHIO_ASR_ZH00008 | Offline training: 老罗语录 | 3.68 | 9.56 | 7.47 | 9.05 |
+| SPEECHIO_ASR_ZH00009 | Podcast: 故事FM | 3.18 | 5.66 | 3.78 | 5.4 |
+| SPEECHIO_ASR_ZH00010 | Podcast: 创业内幕 | 3.51 | 7.84 | 4.32 | 6.4 |
+| SPEECHIO_ASR_ZH00011 | Online education: 罗翔 criminal-law bar exam | 1.77 | 3.22 | 2.49 | 3.12 |
+| SPEECHIO_ASR_ZH00012 | Online education: 张雪峰 postgraduate admissions | 2.11 | 5.98 | 3.09 | 4.41 |
+| SPEECHIO_ASR_ZH00013 | Short video, film recaps: 谷阿莫&牛叔说电影 | 2.97 | 5.91 | 3.78 | 6.56 |
+| SPEECHIO_ASR_ZH00014 | Short video: American-style & cooking | 3.56 | 6.03 | 5.28 | 8.14 |
+| SPEECHIO_ASR_ZH00015 | Pingshu storytelling: 单田芳 白眉大侠 | 4.72 | 8.77 | 7.97 | 9.1 |
+| SPEECHIO_ASR_ZH00016 | Crosstalk: 德云社 special | 3.01 | 5.24 | 4.41 | 5.59 |
+| SPEECHIO_ASR_ZH00017 | Stand-up comedy: 吐槽大会 | 2.93 | 7.05 | 3.27 | 5.17 |
+| SPEECHIO_ASR_ZH00018 | Children's cartoons: 小猪佩奇&熊出没 | 1.98 | 3.53 | 3.21 | 4.15 |
+| SPEECHIO_ASR_ZH00019 | Sports commentary: NBA games | 2.32 | 6.89 | 4.24 | 6.66 |
+| SPEECHIO_ASR_ZH00020 | Documentary: basketball figures | 1.51 | 4.16 | 2.96 | 4.2 |
+| SPEECHIO_ASR_ZH00021 | Short video: 汽车之家 car reviews | 1.75 | 4.77 | 2.77 | 4.17 |
+| SPEECHIO_ASR_ZH00022 | Short video: 小艾大叔 mansion tours | 3.29 | 6.35 | 5.66 | 6.72 |
+| SPEECHIO_ASR_ZH00023 | Short video, unboxing: Zeal&无聊开箱 | 2.18 | 8.99 | 4.45 | 7.94 |
+| SPEECHIO_ASR_ZH00024 | Short video: 付老师 farming | 4.80 | 10.81 | 6.25 | 8.64 |
+| SPEECHIO_ASR_ZH00025 | Offline lecture: 石国鹏 ancient Greek philosophy | 3.32 | 8.41 | 5.8 | 8.54 |
+| SPEECHIO_ASR_ZH00026 | Radio program: 张震鬼故事 | 3.70 | 4.52 | 4.11 | 4.67 |
+
+</details>
+
 Command for decoding using fine-tuned whisper:
 ```bash
@@ -76,16 +105,6 @@ mv words.txt ./data/lang_bpe_2000/
     --manifest-dir data/fbank_kaldi \
     --decoding-method greedy_search
 ```
-Command for fusion the above decoding results from whisper and zipformer:
-```bash
-python local/whisper_zipformer_fusion.py \
-    --whisper-log-dir ./whisper/exp_large_v2_wenetspeech \
-    --zipformer-log-dir ./zipformer/exp_pretrain/greedy_search \
-    --output-log-dir ./results_fusion
-```
-
-See why the fusion helps [here](./local/whisper_zipformer_fusion.py).
-
 SpeechIO fbank features, decoding scripts, logs, and decoding results
 are available at
egs/speechio/ASR/local/whisper_zipformer_fusion.py → egs/speechio/ASR/local/normalize_results.py (104 changed lines; normal file → executable file)
@@ -21,7 +21,7 @@ Since whisper model is more likely to make deletion errors and zipformer model i
 we trust whisper model when it makes substitution and insertion errors and trust zipformer model when it makes deletion errors.
 
 Usage:
-    python whisper_zipformer_fusion.py --whisper-log-dir ./whisper_decoding_log_dir --zipformer-log-dir ./zipformer_decoding_log_dir --output-log-dir ./results_fusion
+    python whisper_zipformer_fusion.py --model-log-dir ./whisper_decoding_log_dir --output-log-dir ./results_norm
 """
 
 import argparse
@@ -31,38 +31,43 @@ from typing import Dict, List, Optional, Tuple
 import kaldialign
 
 from icefall.utils import store_transcripts, write_error_stats
+from speechio_norm import TextNorm
 
 
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
     parser.add_argument(
-        "--whisper-log-dir",
+        "--model-log-dir",
         type=str,
         default="./recogs_whisper",
         help="The directory to store the whisper logs: e.g. recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch--1-avg-1.txt",
     )
-    parser.add_argument(
-        "--zipformer-log-dir",
-        type=str,
-        default="./recogs_zipformer",
-        help="The directory to store the zipformer logs",
-    )
     parser.add_argument(
         "--output-log-dir",
         type=str,
-        default="./results_fusion",
-        help="The directory to store the fusion logs",
+        default="./results_whisper_norm",
+        help="The directory to store the normalized whisper logs",
     )
     return parser
 
 
-def save_results(
+def save_results_with_speechio_text_norm(
     res_dir: Path,
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
+    normalizer = TextNorm()
+    # normalize items in results_dict
+    for key, results in results_dict.items():
+        results_norm = []
+        for item in results:
+            wav_name, ref, hyp = item
+            ref = normalizer(ref)
+            hyp = normalizer(hyp)
+            results_norm.append((wav_name, ref, hyp))
+        results_dict[key] = results_norm
+
     test_set_wers = dict()
 
     suffix = "epoch-999-avg-1"
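Both the reference and the hypothesis go through the same SpeechIO TextNorm before scoring, which is what makes these CERs comparable with the leaderboard numbers above. A minimal sketch of that pattern (TextNorm's exact rules live in the large speechio_norm.py whose diff is suppressed below, so the behavior noted in the comment is an assumption):

```python
from speechio_norm import TextNorm

normalizer = TextNorm()
# normalize both sides with the same rules before computing the error rate
ref = normalizer("这个模型二零二四年发布")
hyp = normalizer("这个模型2024年发布")
# depending on TextNorm's rules, numeral variants like these may be mapped to one
# canonical written form, so purely orthographic differences stop counting as errors
```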
@@ -120,11 +125,9 @@ def extract_hyp_ref_wavname(filename):
     return hyps, refs, wav_name
 
 
-def get_pair_filenames(
+def get_filenames(
     whisper_log_dir,
-    zipformer_log_dir,
     whisper_suffix="beam-search-epoch-999-avg-1",
-    zipformer_suffix="greedy_search_blank_penalty_2.0-epoch-999-avg-1-context-2-max-sym-per-frame-1-blank-penalty-2.0",
 ):
     results = []
     start_index, end_index = 0, 26
@@ -134,80 +137,23 @@
         dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")
     for partition in dataset_parts:
         whisper_filename = f"{whisper_log_dir}/recogs-{partition}-{whisper_suffix}.txt"
-        zipformer_filename = (
-            f"{zipformer_log_dir}/recogs-{partition}-{zipformer_suffix}.txt"
-        )
-        results.append((whisper_filename, zipformer_filename))
+        results.append(whisper_filename)
     return results
 
 
-def fusion_hyps_trust_substituion_insertion(
-    hyps_whisper, hyps_zipformer, refs, ERR="*"
-):
-    """
-    alignment example:
-    [('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')]
-    left is whisper, right is zipformer
-    for whisper substitution, use left
-    for whisper insertion, use left
-    for whisper deletion, use right
-    """
-    hyps_fusion = []
-    for hyp_w, hyp_z, ref in zip(hyps_whisper, hyps_zipformer, refs):
-        ali = kaldialign.align(hyp_w, hyp_z, ERR)
-        hyp_f = ""
-        for a in ali:
-            if a[0] == ERR:
-                hyp_f += a[1]
-            else:
-                hyp_f += a[0]
-        hyps_fusion.append(hyp_f)
-    return hyps_fusion
-
-
-def fusion_hyps_trust_substituion(hyps_whisper, hyps_zipformer, refs, ERR="*"):
-    """
-    alignment example:
-    [('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')]
-    left is whisper, right is zipformer
-    for whisper substitution, use left
-    for whisper insertion, use right
-    for whisper deletion, use right
-    """
-    hyps_fusion = []
-    for hyp_w, hyp_z, ref in zip(hyps_whisper, hyps_zipformer, refs):
-        ali = kaldialign.align(hyp_w, hyp_z, ERR)
-        hyp_f = ""
-        for a in ali:
-            if a[0] == ERR:
-                hyp_f += a[1]
-            elif a[1] == ERR:
-                pass
-            else:
-                hyp_f += a[0]
-        hyps_fusion.append(hyp_f)
-    return hyps_fusion
-
-
 def main():
     parser = get_parser()
     args = parser.parse_args()
     # mkdir output_log_dir
     Path(args.output_log_dir).mkdir(parents=True, exist_ok=True)
-    pair_logs = get_pair_filenames(args.whisper_log_dir, args.zipformer_log_dir)
-    for pair in pair_logs:
-        hyps_whisper, refs, wav_name = extract_hyp_ref_wavname(pair[0])
-        hyps_zipformer, _, _ = extract_hyp_ref_wavname(pair[1])
-        hyps_fusion = fusion_hyps_trust_substituion_insertion(
-            hyps_whisper, hyps_zipformer, refs
-        )
-        partition_name = pair[0].split("/")[-1].split("-")[1]
-        save_results(
+    filenames = get_filenames(args.model_log_dir)
+    for filename in filenames:
+        hyps, refs, wav_name = extract_hyp_ref_wavname(filename)
+        partition_name = filename.split("/")[-1].split("-")[1]
+        save_results_with_speechio_text_norm(
             Path(args.output_log_dir),
             partition_name,
-            {"fusion": list(zip(wav_name, refs, hyps_fusion))},
+            {"norm": list(zip(wav_name, refs, hyps))},
         )
 
         print(f"Processed {partition_name}")
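Although the fusion helpers are removed in this commit, the rule they implemented is instructive. A self-contained rerun of the docstring's own alignment example, showing "trust whisper except where it deleted" (requires kaldialign):

```python
import kaldialign

hyp_whisper = list("我在任的时候")    # whisper tends to make deletion errors
hyp_zipformer = list("你任的时候呢")  # zipformer tends to make insertion/substitution errors
ali = kaldialign.align(hyp_whisper, hyp_zipformer, "*")
# [('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')]
# take whisper's character unless whisper's side is the gap symbol '*'
fused = "".join(z if w == "*" else w for w, z in ali)
assert fused == "我在任的时候呢"
```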
egs/speechio/ASR/local/speechio_norm.py (new executable file, 1203 lines; diff suppressed because it is too large)
egs/wenetspeech/ASR/local/fix_manifest.py (new file, 114 lines)
@@ -0,0 +1,114 @@
+#!/usr/bin/env python3
+# Copyright 2024 author: Yuekai Zhang
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import argparse
+from lhotse import CutSet, load_manifest_lazy
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--fixed-transcript-path",
+        type=str,
+        default="data/fbank/text.fix",
+        help="""
+        See https://github.com/wenet-e2e/WenetSpeech/discussions/54
+        wget -nc https://huggingface.co/datasets/yuekai/wenetspeech_paraformer_fixed_transcript/resolve/main/text.fix
+        """,
+    )
+
+    parser.add_argument(
+        "--manifest-dir",
+        type=str,
+        default="data/fbank/",
+        help="Directory to store the manifest files",
+    )
+
+    parser.add_argument(
+        "--training-subset",
+        type=str,
+        default="L",
+        help="The training subset for wenetspeech.",
+    )
+
+    return parser
+
+
+def load_fixed_text(fixed_text_path):
+    """
+    fixed text format:
+    X0000016287_92761015_S00001 我是徐涛
+    X0000016287_92761015_S00002 狄更斯的PICK WEEK PAPERS斯
+    load into a dict
+    """
+    fixed_text_dict = {}
+    with open(fixed_text_path, 'r') as f:
+        for line in f:
+            cut_id, text = line.strip().split(' ', 1)
+            fixed_text_dict[cut_id] = text
+    return fixed_text_dict
+
+
+def fix_manifest(manifest, fixed_text_dict, fixed_manifest_path):
+    with CutSet.open_writer(fixed_manifest_path) as manifest_writer:
+        fixed_item = 0
+        for i, cut in enumerate(manifest):
+            if i % 10000 == 0:
+                logging.info(f'Processing cut {i}, fixed {fixed_item}')
+            cut_id_orgin = cut.id
+            if cut_id_orgin.endswith('_sp0.9'):
+                cut_id = cut_id_orgin[:-6]
+            elif cut_id_orgin.endswith('_sp1.1'):
+                cut_id = cut_id_orgin[:-6]
+            else:
+                cut_id = cut_id_orgin
+            if cut_id in fixed_text_dict:
+                assert len(cut.supervisions) == 1, f'cut {cut_id} has {len(cut.supervisions)} supervisions'
+                if cut.supervisions[0].text != fixed_text_dict[cut_id]:
+                    logging.info(f'Fixed text for cut {cut_id_orgin} from {cut.supervisions[0].text} to {fixed_text_dict[cut_id]}')
+                    cut.supervisions[0].text = fixed_text_dict[cut_id]
+                    fixed_item += 1
+            manifest_writer.write(cut)
+
+
+def main():
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    parser = get_parser()
+    args = parser.parse_args()
+    logging.info(vars(args))
+
+    fixed_text_path = args.manifest_dir + 'text.fix'
+    fixed_text_dict = load_fixed_text(fixed_text_path)
+    logging.info(f'Loaded {len(fixed_text_dict)} fixed texts')
+
+    dev_manifest_path = args.manifest_dir + 'cuts_DEV.jsonl.gz'
+    fixed_dev_manifest_path = args.manifest_dir + 'cuts_DEV_fixed.jsonl.gz'
+    logging.info(f'Loading dev manifest from {dev_manifest_path}')
+    cuts_dev_manifest = load_manifest_lazy(dev_manifest_path)
+    fix_manifest(cuts_dev_manifest, fixed_text_dict, fixed_dev_manifest_path)
+    logging.info(f'Fixed dev manifest saved to {fixed_dev_manifest_path}')
+
+    manifest_path = args.manifest_dir + f'cuts_{args.training_subset}.jsonl.gz'
+    fixed_manifest_path = args.manifest_dir + f'cuts_{args.training_subset}_fixed.jsonl.gz'
+    logging.info(f'Loading manifest from {manifest_path}')
+    cuts_manifest = load_manifest_lazy(manifest_path)
+    fix_manifest(cuts_manifest, fixed_text_dict, fixed_manifest_path)
+    logging.info(f'Fixed training manifest saved to {fixed_manifest_path}')
+
+
+if __name__ == "__main__":
+    main()
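One subtlety in fix_manifest above: speed-perturbed copies of a cut carry an `_sp0.9`/`_sp1.1` suffix, so the script strips the suffix before looking up the fixed transcript, letting one text.fix entry repair all three copies. A toy check of that logic:

```python
# standalone check of the suffix stripping in fix_manifest (no lhotse needed)
cut_id_orgin = "X0000016287_92761015_S00001_sp0.9"
cut_id = cut_id_orgin[:-6] if cut_id_orgin.endswith(("_sp0.9", "_sp1.1")) else cut_id_orgin
assert cut_id == "X0000016287_92761015_S00001"  # len("_sp0.9") == 6
```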
@@ -416,3 +416,12 @@ if [ $stage -le 22 ] && [ $stop_stage -ge 22 ]; then
     python ./local/compile_lg.py --lang-dir $lang_dir
   done
 fi
+
+if [ $stage -le 23 ] && [ $stop_stage -ge 23 ]; then
+  log "Stage 23: Modify transcript according to fixed results"
+  # See https://github.com/wenet-e2e/WenetSpeech/discussions/54
+  wget -nc https://huggingface.co/datasets/yuekai/wenetspeech_paraformer_fixed_transcript/resolve/main/text.fix -O data/fbank/text.fix
+  python local/fix_manifest.py \
+    --fixed-transcript-path data/fbank/text.fix \
+    --training-subset L
+fi
@@ -390,14 +390,14 @@ class WenetSpeechAsrDataModule:
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
         cuts_train = load_manifest_lazy(
-            self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz"
+            self.args.manifest_dir / f"cuts_{self.args.training_subset}_fixed.jsonl.gz"
         )
         return cuts_train
 
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV_fixed.jsonl.gz")
 
     @lru_cache()
     def test_net_cuts(self) -> List[CutSet]:
@@ -44,6 +44,7 @@ from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union
 
+import os
 import deepspeed
 import k2
 import optim
@@ -145,7 +146,7 @@ def get_parser():
         "--model-name",
         type=str,
         default="large-v2",
-        choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
+        choices=["large-v2", "large-v3", "medium", "small", "tiny"],
         help="""The model name to use.
         """,
     )
@@ -616,7 +617,9 @@ def train_one_epoch(
                 f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
                 tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
             )
+            os.system(
+                f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
+            )
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
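The added cleanup makes sense given how these recipes save checkpoints under DeepSpeed: the ZeRO engine writes a sharded directory named after the tag, the surrounding code (not part of this hunk) consolidates it into the single `.pt` file shown above, and the new `os.system` call then deletes the now-redundant sharded directory. A sketch of that sequence, assuming DeepSpeed's standard conversion helper is what produces the `.pt` file:

```python
import os

from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

tag = "epoch-2-checkpoint-8000"  # hypothetical tag matching the f-strings above
# consolidate the sharded ZeRO checkpoint exp_dir/<tag>/ into one fp32 .pt file
convert_zero_checkpoint_to_fp32_state_dict("exp_dir", f"exp_dir/{tag}.pt", tag=tag)
os.system(f"rm -rf exp_dir/{tag}")  # the added line: drop the sharded directory
```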
@@ -803,7 +806,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            512
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -893,6 +896,7 @@ def run(rank, world_size, args):
                 f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
                 tag=f"epoch-{params.cur_epoch}",
             )
+            os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
         else:
             save_checkpoint(
                 params=params,