icefall (mirror of <https://github.com/k2-fsa/icefall.git>)

Update results using Zipformer-large on multi-hans-zh (#1679)

parent 2d64228efa
commit 1c3d992a39

@@ -43,6 +43,61 @@ Fine-tuned models, training logs, decoding logs, tensorboard and decoding results
are available at
<https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper>

### Multi Chinese datasets char-based training results (streaming) on zipformer large model

#### Streaming (with CTC head)

Please use the [script](https://github.com/k2-fsa/icefall/blob/master/egs/speech_llm/ASR_LLM/prepare.sh) to prepare fbank features.

The training command for the large model (number of parameters: ~160M):

```
./zipformer/train.py \
  --world-size 8 \
  --num-epochs 20 \
  --use-fp16 1 \
  --max-duration 1200 \
  --num-workers 8 \
  --use-ctc 1 \
  --exp-dir zipformer/exp-large \
  --causal 1 \
  --num-encoder-layers 2,2,4,5,4,2 \
  --feedforward-dim 768,1024,1536,2048,1536,768 \
  --encoder-dim 256,384,512,768,512,256 \
  --encoder-unmasked-dim 192,192,256,320,256,192
```

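Each comma-separated model option above sets one value per Zipformer encoder stack. A minimal sketch of how such options are parsed, mirroring the `_to_int_tuple` helper that appears in the `train.py` diff further below:

```
# Mirrors the _to_int_tuple helper shown later in this commit's train.py
# diff: one integer per encoder stack.
def _to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))

assert _to_int_tuple("2,2,4,5,4,2") == (2, 2, 4, 5, 4, 2)
```
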
The decoding command for transducer greedy search:

```
./zipformer/decode.py \
  --epoch 999 \
  --avg 1 \
  --causal 1 \
  --use-averaged-model False \
  --chunk-size -1 \
  --left-context-frames -1 \
  --use-ctc 1 \
  --exp-dir zipformer/exp-large \
  --max-duration 1200 \
  --num-encoder-layers 2,2,4,5,4,2 \
  --feedforward-dim 768,1024,1536,2048,1536,768 \
  --encoder-dim 256,384,512,768,512,256 \
  --encoder-unmasked-dim 192,192,256,320,256,192
```

Character Error Rates (CERs) listed below are produced by the checkpoint of the 18th epoch using a BPE model (vocabulary size 2000, byte fallback enabled).

| Datasets                     | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech  | WenetSpeech |
|------------------------------|------------|------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|--------------|--------------|--------------|-------------|--------------|-------------|
| Zipformer CER (%)            | eval       | test       | dev       | test      | dev       | test      | test      | dev       | test      | dev phase1   | dev phase2   | test         | dev         | test meeting | test net    |
| CTC Greedy Streaming         | 26.50      | 28.10      | 1.71      | 1.97      | 3.89      | 4.06      | 17.23     | 3.69      | 2.87      | 8.14         | 3.61         | 9.51         | 6.11        | 8.13         | 10.62       |
| CTC Greedy Offline           | 23.47      | 25.02      | 1.39      | 1.50      | 3.15      | 3.41      | 15.14     | 3.07      | 2.37      | 6.06         | 2.90         | 7.13         | 5.40        | 6.52         | 9.64        |
| Transducer Greedy Offline    | 23.16      | 24.78      | 1.33      | 1.38      | 3.06      | 3.23      | 15.36     | 2.54      | 2.09      | 5.24         | 2.28         | 6.26         | 4.87        | 6.26         | 7.07        |
| Transducer Greedy Streaming  | 26.83      | 28.74      | 1.75      | 1.91      | 3.84      | 4.12      | 17.83     | 3.23      | 2.71      | 7.31         | 3.16         | 8.69         | 5.71        | 7.91         | 8.54        |

The pre-trained model can be found here: <https://huggingface.co/yuekai/icefall-asr-multi-zh-hans-zipformer-large>
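
A minimal sketch of fetching that checkpoint with `huggingface_hub` (only the repo id is taken from the link above; the exact files inside the repo are not listed here):

```
# Sketch: download the pre-trained model from the Hugging Face Hub and
# inspect the returned directory for the checkpoint/BPE files it contains.
from huggingface_hub import snapshot_download

ckpt_dir = snapshot_download(
    repo_id="yuekai/icefall-asr-multi-zh-hans-zipformer-large"
)
print(ckpt_dir)
```
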
### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model

egs/multi_zh-hans/ASR/whisper/multi_dataset.py (deleted; replaced by the symbolic link below)

@@ -1,247 +0,0 @@
```
# Copyright    2023  Xiaomi Corp.        (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
from pathlib import Path
from typing import Dict

from lhotse import CutSet, load_manifest_lazy


class MultiDataset:
    def __init__(self, fbank_dir: str):
        """
        Args:
          fbank_dir:
            It is expected to contain the following files:

            - aishell_cuts_train.jsonl.gz
            - aishell2_cuts_train.jsonl.gz
            - aishell4_cuts_train_L.jsonl.gz
            - aishell4_cuts_train_M.jsonl.gz
            - aishell4_cuts_train_S.jsonl.gz
            - alimeeting-far_cuts_train.jsonl.gz
            - magicdata_cuts_train.jsonl.gz
            - primewords_cuts_train.jsonl.gz
            - stcmds_cuts_train.jsonl.gz
            - thchs_30_cuts_train.jsonl.gz
            - kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
            - kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
            - wenetspeech/cuts_L_fixed.jsonl.gz
        """
        self.fbank_dir = Path(fbank_dir)

    def train_cuts(self) -> CutSet:
        logging.info("About to get multidataset train cuts")

        # THCHS-30
        logging.info("Loading THCHS-30 in lazy mode")
        thchs_30_cuts = load_manifest_lazy(
            self.fbank_dir / "thchs_30_cuts_train.jsonl.gz"
        )

        # AISHELL-1
        logging.info("Loading Aishell-1 in lazy mode")
        aishell_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_train.jsonl.gz"
        )

        # AISHELL-2
        logging.info("Loading Aishell-2 in lazy mode")
        aishell_2_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
        )

        # AISHELL-4
        logging.info("Loading Aishell-4 in lazy mode")
        aishell_4_L_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz"
        )
        aishell_4_M_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz"
        )
        aishell_4_S_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz"
        )

        # ST-CMDS
        logging.info("Loading ST-CMDS in lazy mode")
        stcmds_cuts = load_manifest_lazy(self.fbank_dir / "stcmds_cuts_train.jsonl.gz")

        # Primewords
        logging.info("Loading Primewords in lazy mode")
        primewords_cuts = load_manifest_lazy(
            self.fbank_dir / "primewords_cuts_train.jsonl.gz"
        )

        # MagicData
        logging.info("Loading MagicData in lazy mode")
        magicdata_cuts = load_manifest_lazy(
            self.fbank_dir / "magicdata_cuts_train.jsonl.gz"
        )

        # Ali-Meeting
        logging.info("Loading Ali-Meeting in lazy mode")
        alimeeting_cuts = load_manifest_lazy(
            self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz"
        )

        # WeNetSpeech
        logging.info("Loading WeNetSpeech in lazy mode")
        wenetspeech_L_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz"
        )

        # KeSpeech
        logging.info("Loading KeSpeech in lazy mode")
        kespeech_1_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz"
        )
        kespeech_2_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz"
        )

        # Mix all corpora; each source is sampled with probability
        # proportional to its size.
        return CutSet.mux(
            thchs_30_cuts,
            aishell_cuts,
            aishell_2_cuts,
            aishell_4_L_cuts,
            aishell_4_M_cuts,
            aishell_4_S_cuts,
            alimeeting_cuts,
            stcmds_cuts,
            primewords_cuts,
            magicdata_cuts,
            wenetspeech_L_cuts,
            kespeech_1_cuts,
            kespeech_2_cuts,
            weights=[
                len(thchs_30_cuts),
                len(aishell_cuts),
                len(aishell_2_cuts),
                len(aishell_4_L_cuts),
                len(aishell_4_M_cuts),
                len(aishell_4_S_cuts),
                len(alimeeting_cuts),
                len(stcmds_cuts),
                len(primewords_cuts),
                len(magicdata_cuts),
                len(wenetspeech_L_cuts),
                len(kespeech_1_cuts),
                len(kespeech_2_cuts),
            ],
        )

    def dev_cuts(self) -> CutSet:
        logging.info("About to get multidataset dev cuts")

        # WeNetSpeech
        logging.info("Loading WeNetSpeech DEV set in lazy mode")
        wenetspeech_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
        )

        return wenetspeech_dev_cuts

    def test_cuts(self) -> Dict[str, CutSet]:
        logging.info("About to get multidataset test cuts")

        # AISHELL
        logging.info("Loading Aishell set in lazy mode")
        aishell_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_test.jsonl.gz"
        )
        aishell_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
        )

        # AISHELL-2
        logging.info("Loading Aishell-2 set in lazy mode")
        aishell2_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
        )
        aishell2_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
        )

        # AISHELL-4
        logging.info("Loading Aishell-4 TEST set in lazy mode")
        aishell4_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_test.jsonl.gz"
        )

        # Ali-Meeting
        logging.info("Loading Ali-Meeting set in lazy mode")
        alimeeting_test_cuts = load_manifest_lazy(
            self.fbank_dir / "alimeeting-far_cuts_test.jsonl.gz"
        )
        alimeeting_eval_cuts = load_manifest_lazy(
            self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
        )

        # MagicData
        logging.info("Loading MagicData set in lazy mode")
        magicdata_test_cuts = load_manifest_lazy(
            self.fbank_dir / "magicdata_cuts_test.jsonl.gz"
        )
        magicdata_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
        )

        # KeSpeech
        logging.info("Loading KeSpeech set in lazy mode")
        kespeech_test_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_test.jsonl.gz"
        )
        kespeech_dev_phase1_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
        )
        kespeech_dev_phase2_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
        )

        # WeNetSpeech
        logging.info("Loading WeNetSpeech set in lazy mode")
        wenetspeech_test_meeting_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz"
        )
        wenetspeech_test_net_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz"
        )
        wenetspeech_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
        )

        return {
            "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
            # "aishell_test": aishell_test_cuts,
            # "aishell_dev": aishell_dev_cuts,
            # "ali-meeting_test": alimeeting_test_cuts,
            # "ali-meeting_eval": alimeeting_eval_cuts,
            # "aishell-4_test": aishell4_test_cuts,
            # "aishell-2_test": aishell2_test_cuts,
            # "aishell-2_dev": aishell2_dev_cuts,
            # "magicdata_test": magicdata_test_cuts,
            # "magicdata_dev": magicdata_dev_cuts,
            # "kespeech-asr_test": kespeech_test_cuts,
            # "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
            # "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
            # "wenetspeech-net_test": wenetspeech_test_net_cuts,
            # "wenetspeech_dev": wenetspeech_dev_cuts,
        }
```

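The `CutSet.mux` call in `train_cuts` above interleaves the corpora lazily, drawing from each source with probability proportional to its weight (here, the corpus size). A minimal sketch of the idea; the two manifest paths are placeholders:

```
# Sketch of lhotse's weighted mux; the manifest paths are placeholders.
from lhotse import CutSet, load_manifest_lazy

small = load_manifest_lazy("data/fbank/aishell_cuts_train.jsonl.gz")
large = load_manifest_lazy("data/fbank/wenetspeech/cuts_L_fixed.jsonl.gz")

# Streams are interleaved lazily; each cut is drawn from a source with
# probability proportional to its weight, so larger corpora appear
# proportionally more often without loading everything into memory.
mixed = CutSet.mux(small, large, weights=[len(small), len(large)])
```
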
egs/multi_zh-hans/ASR/whisper/multi_dataset.py (new symbolic link, 1 line)

@@ -0,0 +1 @@
+../../../speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py

```
@@ -46,7 +46,7 @@ import torch.nn as nn
 from asr_datamodule import AsrDataModule
 from lhotse.cut import Cut
 from multi_dataset import MultiDataset
-from train import add_model_arguments, get_model, get_params
+from train import add_model_arguments, get_model, get_params, normalize_text_alimeeting

 from icefall.checkpoint import (
     average_checkpoints,
```

```
@@ -367,21 +367,18 @@ def decode_dataset(
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
-            HLG=HLG,
             H=H,
             bpe_model=bpe_model,
             batch=batch,
-            word_table=word_table,
-            G=G,
         )

         for name, hyps in hyps_dict.items():
             this_batch = []
             assert len(hyps) == len(texts)
             for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-                ref_words = list(ref_text.replace(" ", ""))
+                ref_text = normalize_text_alimeeting(ref_text)
-                hyp_words = list("".join(hyp_words))
+                hyp_text = "".join(hyp_words)
-                this_batch.append((cut_id, ref_words, hyp_words))
+                this_batch.append((cut_id, ref_text, hyp_text))

             results[name].extend(this_batch)
```

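In isolation, the replaced lines change the scoring inputs from character lists to normalized strings; an illustrative contrast (the sample string is made up):

```
# Illustrative only: old vs. new handling of a reference transcript.
ref_text = "今天 天气"
old_ref = list(ref_text.replace(" ", ""))      # ['今', '天', '天', '气']
new_ref = normalize_text_alimeeting(ref_text)  # '今天天气' (helper added in train.py below)
```
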
```
@@ -583,7 +580,7 @@ def main():
     data_module = AsrDataModule(args)
     multi_dataset = MultiDataset(args.manifest_dir)

-    test_sets_cuts = multi_dataset.test_cuts()
+    test_sets_cuts = {**multi_dataset.test_cuts(), **multi_dataset.speechio_test_cuts()}

     def remove_short_utt(c: Cut):
         T = ((c.num_frames - 7) // 2 + 1) // 2
```

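The updated line merges two name-to-CutSet mappings with dict unpacking; `speechio_test_cuts()` is presumably provided by the symlinked `multi_dataset.py`. A toy illustration of the merge:

```
# Toy illustration of the dict-unpacking merge; the SpeechIO key below is
# hypothetical.
a = {"wenetspeech-meeting_test": "cuts_a"}
b = {"speechio_test": "cuts_b"}
merged = {**a, **b}  # later entries win on key collisions
assert set(merged) == {"wenetspeech-meeting_test", "speechio_test"}
```
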
```
@@ -118,7 +118,7 @@ from beam_search import (
 )
 from lhotse.cut import Cut
 from multi_dataset import MultiDataset
-from train import add_model_arguments, get_model, get_params
+from train import add_model_arguments, get_model, get_params, normalize_text_alimeeting

 from icefall.checkpoint import (
     average_checkpoints,
```

```
@@ -532,7 +532,6 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
-        texts = [list(str(text).replace(" ", "")) for text in texts]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps_dict = decode_one_batch(
```

```
@@ -548,6 +547,7 @@ def decode_dataset(
             this_batch = []
             assert len(hyps) == len(texts)
             for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                ref_text = normalize_text_alimeeting(ref_text)
                 hyp_text = "".join(hyp_words)
                 this_batch.append((cut_id, ref_text, hyp_text))

```

```
@@ -795,7 +795,7 @@ def main():
         )
         return T > 0

-    test_sets_cuts = multi_dataset.test_cuts()
+    test_sets_cuts = {**multi_dataset.test_cuts(), **multi_dataset.speechio_test_cuts()}

     test_sets = test_sets_cuts.keys()
     test_dl = [
```

egs/multi_zh-hans/ASR/zipformer/multi_dataset.py (deleted; replaced by the symbolic link below)

@@ -1,316 +0,0 @@
```
# Copyright    2023  Xiaomi Corp.        (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
from pathlib import Path
from typing import Dict

from lhotse import CutSet, load_manifest_lazy


class MultiDataset:
    def __init__(self, fbank_dir: str):
        """
        Args:
          fbank_dir:
            It is expected to contain the following files:

            - aidatatang_cuts_train.jsonl.gz
            - aishell_cuts_train.jsonl.gz
            - aishell2_cuts_train.jsonl.gz
            - aishell4_cuts_train_L.jsonl.gz
            - aishell4_cuts_train_M.jsonl.gz
            - aishell4_cuts_train_S.jsonl.gz
            - alimeeting-far_cuts_train.jsonl.gz
            - magicdata_cuts_train.jsonl.gz
            - primewords_cuts_train.jsonl.gz
            - stcmds_cuts_train.jsonl.gz
            - thchs_30_cuts_train.jsonl.gz
            - kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
            - kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
            - wenetspeech/cuts_L.jsonl.gz
        """
        self.fbank_dir = Path(fbank_dir)

    def train_cuts(self) -> CutSet:
        logging.info("About to get multidataset train cuts")

        # THCHS-30
        logging.info("Loading THCHS-30 in lazy mode")
        thchs_30_cuts = load_manifest_lazy(
            self.fbank_dir / "thchs_30_cuts_train.jsonl.gz"
        )

        # AISHELL-1
        logging.info("Loading Aishell-1 in lazy mode")
        aishell_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_train.jsonl.gz"
        )

        # AISHELL-2
        logging.info("Loading Aishell-2 in lazy mode")
        aishell_2_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
        )

        # AISHELL-4
        logging.info("Loading Aishell-4 in lazy mode")
        aishell_4_L_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_train_L.jsonl.gz"
        )
        aishell_4_M_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_train_M.jsonl.gz"
        )
        aishell_4_S_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_train_S.jsonl.gz"
        )

        # ST-CMDS
        logging.info("Loading ST-CMDS in lazy mode")
        stcmds_cuts = load_manifest_lazy(self.fbank_dir / "stcmds_cuts_train.jsonl.gz")

        # Primewords
        logging.info("Loading Primewords in lazy mode")
        primewords_cuts = load_manifest_lazy(
            self.fbank_dir / "primewords_cuts_train.jsonl.gz"
        )

        # MagicData
        logging.info("Loading MagicData in lazy mode")
        magicdata_cuts = load_manifest_lazy(
            self.fbank_dir / "magicdata_cuts_train.jsonl.gz"
        )

        # Aidatatang_200zh
        logging.info("Loading Aidatatang_200zh in lazy mode")
        aidatatang_200zh_cuts = load_manifest_lazy(
            self.fbank_dir / "aidatatang_cuts_train.jsonl.gz"
        )

        # Ali-Meeting
        logging.info("Loading Ali-Meeting in lazy mode")
        alimeeting_cuts = load_manifest_lazy(
            self.fbank_dir / "alimeeting-far_cuts_train.jsonl.gz"
        )

        # WeNetSpeech
        logging.info("Loading WeNetSpeech in lazy mode")
        wenetspeech_L_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_L.jsonl.gz"
        )

        # KeSpeech
        logging.info("Loading KeSpeech in lazy mode")
        kespeech_1_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase1.jsonl.gz"
        )
        kespeech_2_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_train_phase2.jsonl.gz"
        )

        # Mix all corpora; each source is sampled with probability
        # proportional to its size.
        return CutSet.mux(
            thchs_30_cuts,
            aishell_cuts,
            aishell_2_cuts,
            aishell_4_L_cuts,
            aishell_4_M_cuts,
            aishell_4_S_cuts,
            stcmds_cuts,
            primewords_cuts,
            magicdata_cuts,
            aidatatang_200zh_cuts,
            alimeeting_cuts,
            wenetspeech_L_cuts,
            kespeech_1_cuts,
            kespeech_2_cuts,
            weights=[
                len(thchs_30_cuts),
                len(aishell_cuts),
                len(aishell_2_cuts),
                len(aishell_4_L_cuts),
                len(aishell_4_M_cuts),
                len(aishell_4_S_cuts),
                len(stcmds_cuts),
                len(primewords_cuts),
                len(magicdata_cuts),
                len(aidatatang_200zh_cuts),
                len(alimeeting_cuts),
                len(wenetspeech_L_cuts),
                len(kespeech_1_cuts),
                len(kespeech_2_cuts),
            ],
        )

    def dev_cuts(self) -> CutSet:
        logging.info("About to get multidataset dev cuts")

        # Aidatatang_200zh
        logging.info("Loading Aidatatang_200zh DEV set in lazy mode")
        aidatatang_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aidatatang_cuts_dev.jsonl.gz"
        )

        # AISHELL
        logging.info("Loading Aishell DEV set in lazy mode")
        aishell_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
        )

        # AISHELL-2
        logging.info("Loading Aishell-2 DEV set in lazy mode")
        aishell2_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
        )

        # Ali-Meeting
        logging.info("Loading Ali-Meeting DEV set in lazy mode")
        alimeeting_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
        )

        # MagicData
        logging.info("Loading MagicData DEV set in lazy mode")
        magicdata_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
        )

        # KeSpeech
        logging.info("Loading KeSpeech DEV set in lazy mode")
        kespeech_dev_phase1_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
        )
        kespeech_dev_phase2_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
        )

        # WeNetSpeech
        logging.info("Loading WeNetSpeech DEV set in lazy mode")
        wenetspeech_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
        )

        return wenetspeech_dev_cuts
        # return [
        #     aidatatang_dev_cuts,
        #     aishell_dev_cuts,
        #     aishell2_dev_cuts,
        #     alimeeting_dev_cuts,
        #     magicdata_dev_cuts,
        #     kespeech_dev_phase1_cuts,
        #     kespeech_dev_phase2_cuts,
        #     wenetspeech_dev_cuts,
        # ]

    def test_cuts(self) -> Dict[str, CutSet]:
        logging.info("About to get multidataset test cuts")

        # Aidatatang_200zh
        logging.info("Loading Aidatatang_200zh set in lazy mode")
        aidatatang_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aidatatang_cuts_test.jsonl.gz"
        )
        aidatatang_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aidatatang_cuts_dev.jsonl.gz"
        )

        # AISHELL
        logging.info("Loading Aishell set in lazy mode")
        aishell_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_test.jsonl.gz"
        )
        aishell_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
        )

        # AISHELL-2
        logging.info("Loading Aishell-2 set in lazy mode")
        aishell2_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
        )
        aishell2_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
        )

        # AISHELL-4
        logging.info("Loading Aishell-4 TEST set in lazy mode")
        aishell4_test_cuts = load_manifest_lazy(
            self.fbank_dir / "aishell4_cuts_test.jsonl.gz"
        )

        # Ali-Meeting
        logging.info("Loading Ali-Meeting set in lazy mode")
        alimeeting_test_cuts = load_manifest_lazy(
            self.fbank_dir / "alimeeting-far_cuts_test.jsonl.gz"
        )
        alimeeting_eval_cuts = load_manifest_lazy(
            self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
        )

        # MagicData
        logging.info("Loading MagicData set in lazy mode")
        magicdata_test_cuts = load_manifest_lazy(
            self.fbank_dir / "magicdata_cuts_test.jsonl.gz"
        )
        magicdata_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
        )

        # KeSpeech
        logging.info("Loading KeSpeech set in lazy mode")
        kespeech_test_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_test.jsonl.gz"
        )
        kespeech_dev_phase1_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
        )
        kespeech_dev_phase2_cuts = load_manifest_lazy(
            self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
        )

        # WeNetSpeech
        logging.info("Loading WeNetSpeech set in lazy mode")
        wenetspeech_test_meeting_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_TEST_MEETING.jsonl.gz"
        )
        wenetspeech_test_net_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz"
        )
        wenetspeech_dev_cuts = load_manifest_lazy(
            self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
        )

        return {
            "aidatatang_test": aidatatang_test_cuts,
            "aidatatang_dev": aidatatang_dev_cuts,
            "alimeeting_test": alimeeting_test_cuts,
            "alimeeting_eval": alimeeting_eval_cuts,
            "aishell_test": aishell_test_cuts,
            "aishell_dev": aishell_dev_cuts,
            "aishell-2_test": aishell2_test_cuts,
            "aishell-2_dev": aishell2_dev_cuts,
            "aishell-4": aishell4_test_cuts,
            "magicdata_test": magicdata_test_cuts,
            "magicdata_dev": magicdata_dev_cuts,
            "kespeech-asr_test": kespeech_test_cuts,
            "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
            "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
            "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
            "wenetspeech-net_test": wenetspeech_test_net_cuts,
            "wenetspeech_dev": wenetspeech_dev_cuts,
        }
```

egs/multi_zh-hans/ASR/zipformer/multi_dataset.py (new symbolic link, 1 line)

@@ -0,0 +1 @@
+../../../speech_llm/ASR_LLM/whisper_llm_zh/multi_dataset.py

```
@@ -539,6 +539,43 @@ def get_params() -> AttributeDict:
     return params


+def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
+    """
+    Text normalization similar to that in the M2MeT challenge baseline.
+    See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
+    """
+    if normalize == "none":
+        return text
+    elif normalize == "m2met":
+        import re
+
+        text = text.replace(" ", "")
+        text = text.replace("<sil>", "")
+        text = text.replace("<%>", "")
+        text = text.replace("<->", "")
+        text = text.replace("<$>", "")
+        text = text.replace("<#>", "")
+        text = text.replace("<_>", "")
+        text = text.replace("<space>", "")
+        text = text.replace("`", "")
+        text = text.replace("&", "")
+        text = text.replace(",", "")
+        if re.search("[a-zA-Z]", text):
+            text = text.upper()
+        text = text.replace("Ａ", "A")
+        text = text.replace("ａ", "A")
+        text = text.replace("ｂ", "B")
+        text = text.replace("ｃ", "C")
+        text = text.replace("ｋ", "K")
+        text = text.replace("ｔ", "T")
+        text = text.replace("，", "")
+        text = text.replace("丶", "")
+        text = text.replace("。", "")
+        text = text.replace("、", "")
+        text = text.replace("？", "")
+        return text
+
+
 def _to_int_tuple(s: str):
     return tuple(map(int, s.split(",")))

```

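A worked example of the new helper; the sample string and expected output are derived from the replacement rules above, not from a real transcript:

```
# Worked example (illustrative): spaces and <sil> are removed, the ASCII
# comma is dropped, and Latin letters are upper-cased.
s = "今天 的 会议 <sil> 开始 了, ok"
print(normalize_text_alimeeting(s))  # -> 今天的会议开始了OK
```
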
```
@@ -788,6 +825,9 @@ def compute_loss(
     warm_step = params.warm_step

     texts = batch["supervisions"]["text"]
+    # remove spaces in texts
+    texts = [normalize_text_alimeeting(text) for text in texts]
+
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)

```

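Normalization is now applied before BPE encoding, so the BPE model (vocabulary size 2000, byte fallback enabled) never sees spaces or filler tags. A minimal sketch; the model path is a placeholder:

```
# Sketch: normalize, then BPE-encode; the model path is a placeholder.
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="data/lang_bpe_2000/bpe.model")
texts = [normalize_text_alimeeting(t) for t in ["今天 的 会议, ok"]]
y = sp.encode(texts, out_type=int)  # one list of token ids per utterance
```
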
```
@@ -114,6 +114,7 @@ def extract_hyp_ref_wavname(filename):
         for line in f:
             if "ref" in line:
                 ref = line.split("ref=")[1].strip()
+                if ref[0] == "[":
                     ref = ref[2:-2]
                 list_elements = ref.split("', '")
                 ref = "".join(list_elements)
```

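The new guard strips the list syntax only when the logged ref is a stringified Python list; the log format here is inferred from the parsing code, not from actual log files:

```
# Illustrative: undo repr()-style list formatting in a logged reference.
ref = "['今天', '天气']"
if ref[0] == "[":
    ref = ref[2:-2]               # drop the leading "['" and trailing "']"
ref = "".join(ref.split("', '"))  # -> 今天天气
```
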