mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
update speechio whisper ft results
This commit is contained in:
parent
df36f93bd8
commit
b970ba569a
@ -1,5 +1,48 @@
|
||||
## Results
|
||||
|
||||
### Multi Chinese datasets (without datatang 200h) finetuning results on Whisper-large-v2
|
||||
#### Whisper
|
||||
[./whisper](./whisper)
|
||||
|
||||
Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search.
|
||||
|
||||
| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|
||||
|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
|
||||
| Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
|
||||
| Greedy Search | 23.45 | 25.42 | 0.78 | 0.83 | 2.75 | 2.93 | 17.11 | 2.68 | 2.33 | 4.97 | 2.02 | 6.34 | 5.06 | 8.38 | 6.94 |
|
||||
|
||||
Command for training is:
|
||||
```bash
|
||||
pip install -r whisper/requirements.txt
|
||||
|
||||
# We updated the label of wenetspeech to remove OCR deletion errors, see https://github.com/wenet-e2e/WenetSpeech/discussions/54
|
||||
|
||||
torchrun --nproc-per-node 8 ./whisper/train.py \
|
||||
--max-duration 200 \
|
||||
--exp-dir whisper/exp_large_v2 \
|
||||
--model-name large-v2 \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./whisper/ds_config_zero1.json
|
||||
```
|
||||
|
||||
Command for decoding using fine-tuned models:
|
||||
```bash
|
||||
git lfs install
|
||||
git clone https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper
|
||||
ln -s icefall_asr_multi-hans-zh_whisper/v1/epoch-2-avg4.pt whisper/exp_large_v2/epoch-999.pt
|
||||
|
||||
python3 ./whisper/decode.py \
|
||||
--exp-dir whisper/exp_large_v2 \
|
||||
--model-name large-v2 \
|
||||
--epoch 999 --avg 1 \
|
||||
--beam-size 10 --max-duration 50
|
||||
```
|
||||
|
||||
Fine-tuned models, training logs, decoding logs, tensorboard and decoding results
|
||||
are available at
|
||||
<https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper>
|
||||
|
||||
|
||||
### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model
|
||||
|
||||
This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall.
|
||||
|
@ -226,8 +226,8 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
||||
log "Stage 11: Prepare WenetSpeech"
|
||||
if [ -e ../../wenetspeech/ASR/data/fbank/.preprocess_complete ]; then
|
||||
cd data/fbank
|
||||
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV.jsonl.gz) .
|
||||
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L.jsonl.gz) .
|
||||
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV_fixed.jsonl.gz) .
|
||||
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L_fixed.jsonl.gz) .
|
||||
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) .
|
||||
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) .
|
||||
|
||||
@ -299,15 +299,6 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
|
||||
log "Compute KeSpeech fbank for test/dev"
|
||||
./local/compute_fbank_kespeech_dev_test.py
|
||||
|
||||
if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz ]; then
|
||||
pieces=$(find data/fbank/kespeech/train_phase1_split_${num_splits} -name "kespeech-asr_cuts_train_phase1.*.jsonl.gz")
|
||||
lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
|
||||
fi
|
||||
if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz ]; then
|
||||
pieces=$(find data/fbank/kespeech/train_phase2_split_${num_splits} -name "kespeech-asr_cuts_train_phase2.*.jsonl.gz")
|
||||
lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
|
||||
fi
|
||||
|
||||
touch data/fbank/.kespeech.done
|
||||
fi
|
||||
fi
|
||||
|
48
egs/multi_zh-hans/ASR/whisper/decode.py
Normal file → Executable file
48
egs/multi_zh-hans/ASR/whisper/decode.py
Normal file → Executable file
@ -58,6 +58,7 @@ from multi_dataset import MultiDataset
|
||||
from tn.chinese.normalizer import Normalizer
|
||||
from whisper.normalizers import BasicTextNormalizer
|
||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
|
||||
from zhconv import convert
|
||||
|
||||
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
|
||||
@ -214,7 +215,7 @@ def get_parser():
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="large-v2",
|
||||
choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
|
||||
choices=["large-v2", "large-v3", "medium", "small", "tiny"],
|
||||
help="""The model name to use.
|
||||
""",
|
||||
)
|
||||
@ -226,6 +227,13 @@ def get_parser():
|
||||
help="replace whisper encoder forward method to remove input length restriction",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-distill-whisper",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use architecture of distill whisper.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -289,7 +297,6 @@ def decode_one_batch(
|
||||
print(hyps)
|
||||
return {"beam-search": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
@ -307,6 +314,40 @@ def decode_dataset(
|
||||
Returns:
|
||||
Return a dict, whose key may be "beam-search".
|
||||
"""
|
||||
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
|
||||
"""
|
||||
Text normalization similar to M2MeT challenge baseline.
|
||||
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
|
||||
"""
|
||||
if normalize == "none":
|
||||
return text
|
||||
elif normalize == "m2met":
|
||||
import re
|
||||
text = text.replace(" ", "")
|
||||
text = text.replace("<sil>", "")
|
||||
text = text.replace("<%>", "")
|
||||
text = text.replace("<->", "")
|
||||
text = text.replace("<$>", "")
|
||||
text = text.replace("<#>", "")
|
||||
text = text.replace("<_>", "")
|
||||
text = text.replace("<space>", "")
|
||||
text = text.replace("`", "")
|
||||
text = text.replace("&", "")
|
||||
text = text.replace(",", "")
|
||||
if re.search("[a-zA-Z]", text):
|
||||
text = text.upper()
|
||||
text = text.replace("A", "A")
|
||||
text = text.replace("a", "A")
|
||||
text = text.replace("b", "B")
|
||||
text = text.replace("c", "C")
|
||||
text = text.replace("k", "K")
|
||||
text = text.replace("t", "T")
|
||||
text = text.replace(",", "")
|
||||
text = text.replace("丶", "")
|
||||
text = text.replace("。", "")
|
||||
text = text.replace("、", "")
|
||||
text = text.replace("?", "")
|
||||
return text
|
||||
results = []
|
||||
|
||||
num_cuts = 0
|
||||
@ -331,6 +372,7 @@ def decode_dataset(
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_text = normalize_text_alimeeting(ref_text)
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
@ -430,6 +472,8 @@ def main():
|
||||
|
||||
if params.remove_whisper_encoder_input_length_restriction:
|
||||
replace_whisper_encoder_forward()
|
||||
if params.use_distill_whisper:
|
||||
replace_whisper_decoder_forward()
|
||||
model = whisper.load_model(params.model_name, "cpu")
|
||||
if params.epoch > 0:
|
||||
if params.avg > 1:
|
||||
|
71
egs/multi_zh-hans/ASR/whisper/multi_dataset.py
Normal file → Executable file
71
egs/multi_zh-hans/ASR/whisper/multi_dataset.py
Normal file → Executable file
@ -43,7 +43,7 @@ class MultiDataset:
|
||||
- thchs_30_cuts_train.jsonl.gz
|
||||
- kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
|
||||
- kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
|
||||
- wenetspeech/cuts_L.jsonl.gz
|
||||
- wenetspeech/cuts_L_fixed.jsonl.gz
|
||||
"""
|
||||
self.fbank_dir = Path(fbank_dir)
|
||||
|
||||
@ -105,7 +105,7 @@ class MultiDataset:
|
||||
# WeNetSpeech
|
||||
logging.info("Loading WeNetSpeech in lazy mode")
|
||||
wenetspeech_L_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "wenetspeech" / "cuts_L.jsonl.gz"
|
||||
self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz"
|
||||
)
|
||||
|
||||
# KeSpeech
|
||||
@ -124,10 +124,10 @@ class MultiDataset:
|
||||
aishell_4_L_cuts,
|
||||
aishell_4_M_cuts,
|
||||
aishell_4_S_cuts,
|
||||
alimeeting_cuts,
|
||||
stcmds_cuts,
|
||||
primewords_cuts,
|
||||
magicdata_cuts,
|
||||
alimeeting_cuts,
|
||||
wenetspeech_L_cuts,
|
||||
kespeech_1_cuts,
|
||||
kespeech_2_cuts,
|
||||
@ -138,10 +138,10 @@ class MultiDataset:
|
||||
len(aishell_4_L_cuts),
|
||||
len(aishell_4_M_cuts),
|
||||
len(aishell_4_S_cuts),
|
||||
len(alimeeting_cuts),
|
||||
len(stcmds_cuts),
|
||||
len(primewords_cuts),
|
||||
len(magicdata_cuts),
|
||||
len(alimeeting_cuts),
|
||||
len(wenetspeech_L_cuts),
|
||||
len(kespeech_1_cuts),
|
||||
len(kespeech_2_cuts),
|
||||
@ -151,55 +151,13 @@ class MultiDataset:
|
||||
def dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get multidataset dev cuts")
|
||||
|
||||
# AISHELL
|
||||
logging.info("Loading Aishell DEV set in lazy mode")
|
||||
aishell_dev_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
|
||||
)
|
||||
|
||||
# AISHELL-2
|
||||
logging.info("Loading Aishell-2 DEV set in lazy mode")
|
||||
aishell2_dev_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
|
||||
)
|
||||
|
||||
# Ali-Meeting
|
||||
logging.info("Loading Ali-Meeting DEV set in lazy mode")
|
||||
alimeeting_dev_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
|
||||
)
|
||||
|
||||
# MagicData
|
||||
logging.info("Loading MagicData DEV set in lazy mode")
|
||||
magicdata_dev_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
|
||||
)
|
||||
|
||||
# KeSpeech
|
||||
logging.info("Loading KeSpeech DEV set in lazy mode")
|
||||
kespeech_dev_phase1_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
|
||||
)
|
||||
kespeech_dev_phase2_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
|
||||
)
|
||||
|
||||
# WeNetSpeech
|
||||
logging.info("Loading WeNetSpeech DEV set in lazy mode")
|
||||
wenetspeech_dev_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
|
||||
self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
|
||||
)
|
||||
|
||||
return wenetspeech_dev_cuts
|
||||
# return [
|
||||
# aishell_dev_cuts,
|
||||
# aishell2_dev_cuts,
|
||||
# alimeeting_dev_cuts,
|
||||
# magicdata_dev_cuts,
|
||||
# kespeech_dev_phase1_cuts,
|
||||
# kespeech_dev_phase2_cuts,
|
||||
# wenetspeech_dev_cuts,
|
||||
# ]
|
||||
|
||||
def test_cuts(self) -> Dict[str, CutSet]:
|
||||
logging.info("About to get multidataset test cuts")
|
||||
@ -267,30 +225,23 @@ class MultiDataset:
|
||||
self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz"
|
||||
)
|
||||
wenetspeech_dev_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
|
||||
self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
|
||||
)
|
||||
|
||||
return {
|
||||
"aishell-2_test": aishell2_test_cuts,
|
||||
"aishell-4": aishell4_test_cuts,
|
||||
"magicdata_test": magicdata_test_cuts,
|
||||
"kespeech-asr_test": kespeech_test_cuts,
|
||||
}
|
||||
|
||||
# return {
|
||||
# "alimeeting_test": alimeeting_test_cuts,
|
||||
# "alimeeting_eval": alimeeting_eval_cuts,
|
||||
"wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
|
||||
# "aishell_test": aishell_test_cuts,
|
||||
# "aishell_dev": aishell_dev_cuts,
|
||||
# "ali-meeting_test": alimeeting_test_cuts,
|
||||
# "ali-meeting_eval": alimeeting_eval_cuts,
|
||||
# "aishell-4_test": aishell4_test_cuts,
|
||||
# "aishell-2_test": aishell2_test_cuts,
|
||||
# "aishell-2_dev": aishell2_dev_cuts,
|
||||
# "aishell-4": aishell4_test_cuts,
|
||||
# "magicdata_test": magicdata_test_cuts,
|
||||
# "magicdata_dev": magicdata_dev_cuts,
|
||||
# "kespeech-asr_test": kespeech_test_cuts,
|
||||
# "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
|
||||
# "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
|
||||
# "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
|
||||
# "wenetspeech-net_test": wenetspeech_test_net_cuts,
|
||||
# "wenetspeech_dev": wenetspeech_dev_cuts,
|
||||
# }
|
||||
}
|
51
egs/multi_zh-hans/ASR/whisper/train.py
Normal file → Executable file
51
egs/multi_zh-hans/ASR/whisper/train.py
Normal file → Executable file
@ -66,6 +66,7 @@ from torch.nn.functional import pad as pad_tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
|
||||
|
||||
from icefall import diagnostics
|
||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||
@ -146,7 +147,7 @@ def get_parser():
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="large-v2",
|
||||
choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
|
||||
choices=["large-v2", "large-v3", "medium", "small", "tiny"],
|
||||
help="""The model name to use.
|
||||
""",
|
||||
)
|
||||
@ -232,6 +233,13 @@ def get_parser():
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-distill-whisper",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use architecture of distill whisper.",
|
||||
)
|
||||
|
||||
parser = deepspeed.add_config_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -441,6 +449,41 @@ def compute_loss(
|
||||
padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
|
||||
return torch.stack([tensor for tensor in padded_tensors], dim=0)
|
||||
|
||||
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
|
||||
"""
|
||||
Text normalization similar to M2MeT challenge baseline.
|
||||
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
|
||||
"""
|
||||
if normalize == "none":
|
||||
return text
|
||||
elif normalize == "m2met":
|
||||
import re
|
||||
text = text.replace(" ", "")
|
||||
text = text.replace("<sil>", "")
|
||||
text = text.replace("<%>", "")
|
||||
text = text.replace("<->", "")
|
||||
text = text.replace("<$>", "")
|
||||
text = text.replace("<#>", "")
|
||||
text = text.replace("<_>", "")
|
||||
text = text.replace("<space>", "")
|
||||
text = text.replace("`", "")
|
||||
text = text.replace("&", "")
|
||||
text = text.replace(",", "")
|
||||
if re.search("[a-zA-Z]", text):
|
||||
text = text.upper()
|
||||
text = text.replace("A", "A")
|
||||
text = text.replace("a", "A")
|
||||
text = text.replace("b", "B")
|
||||
text = text.replace("c", "C")
|
||||
text = text.replace("k", "K")
|
||||
text = text.replace("t", "T")
|
||||
text = text.replace(",", "")
|
||||
text = text.replace("丶", "")
|
||||
text = text.replace("。", "")
|
||||
text = text.replace("、", "")
|
||||
text = text.replace("?", "")
|
||||
return text
|
||||
|
||||
max_frames = params.max_duration * 1000 // params.frame_shift_ms
|
||||
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
|
||||
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
|
||||
@ -459,7 +502,7 @@ def compute_loss(
|
||||
|
||||
texts = batch["supervisions"]["text"]
|
||||
# remove spaces in texts
|
||||
texts = [text.replace(" ", "") for text in texts]
|
||||
texts = [normalize_text_alimeeting(text) for text in texts]
|
||||
|
||||
text_tokens_list = [
|
||||
list(tokenizer.sot_sequence_including_notimestamps)
|
||||
@ -759,6 +802,8 @@ def run(rank, world_size, args):
|
||||
logging.info("About to create model")
|
||||
|
||||
replace_whisper_encoder_forward()
|
||||
if params.use_distill_whisper:
|
||||
replace_whisper_decoder_forward()
|
||||
model = whisper.load_model(params.model_name, "cpu")
|
||||
del model.alignment_heads
|
||||
|
||||
@ -824,7 +869,7 @@ def run(rank, world_size, args):
|
||||
|
||||
if params.print_diagnostics:
|
||||
opts = diagnostics.TensorDiagnosticOptions(
|
||||
512
|
||||
2**22
|
||||
) # allow 4 megabytes per sub-module
|
||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||
|
||||
|
@ -0,0 +1,47 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import whisper
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
from typing import Dict, Iterable, Optional
|
||||
from whisper.model import ResidualAttentionBlock, LayerNorm
|
||||
import numpy as np
|
||||
|
||||
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
||||
"""
|
||||
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
|
||||
the text tokens
|
||||
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
|
||||
the encoded audio features to be attended on
|
||||
"""
|
||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||
x = (
|
||||
self.token_embedding(x)
|
||||
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
||||
)
|
||||
x = (
|
||||
x
|
||||
+ self.positional_embedding[offset : offset + x.shape[1]]
|
||||
)
|
||||
x = x.to(xa.dtype)
|
||||
|
||||
# for block in self.blocks:
|
||||
# x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||
# use architecture from the distill whisper model
|
||||
# see https://github.com/huggingface/distil-whisper
|
||||
x = self.blocks[0](x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||
x = self.blocks[-1](x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||
|
||||
x = self.ln(x)
|
||||
logits = (
|
||||
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
|
||||
).float()
|
||||
|
||||
return logits
|
||||
|
||||
def replace_whisper_decoder_forward():
|
||||
"""
|
||||
This function monkey patches the forward method of the whisper encoder.
|
||||
To be called before the model is loaded, it changes whisper to process audio with any length < 30s.
|
||||
"""
|
||||
whisper.model.TextDecoder.forward = forward
|
@ -2,41 +2,70 @@
|
||||
|
||||
### SpeechIO Test Set Decoding Results
|
||||
|
||||
##### Decoding results using pretrained [multi-hans-zh zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24), [whipser-large-v2](https://github.com/openai/whisper/blob/main/whisper/__init__.py#L27), [whisper-large-v2-wenetspeech-ft](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper).
|
||||
|
||||
| | zipformer_transducer | zipformer_transducer_blank_penalty_2 | whisper_large_v2 | whisper_large_v2_wenetspeech | whisper_large_v2_wenetspeech_zipformer_fusion |
|
||||
|------------------------|----------------------|--------------------------------------|------------------|------------------------------|-----------------------------------------------|
|
||||
| SPEECHIO_ASR_ZH00000 | 10.04 | 8.04 | 11.4 | 9.88 | 7.78 |
|
||||
| SPEECHIO_ASR_ZH00001 | 1.67 | 1.51 | 2.49 | 1.57 | 1.38 |
|
||||
| SPEECHIO_ASR_ZH00002 | 5.89 | 5.27 | 7.89 | 5.65 | 4.99 |
|
||||
| SPEECHIO_ASR_ZH00003 | 2.66 | 2.79 | 5.94 | 2.27 | 2.33 |
|
||||
| SPEECHIO_ASR_ZH00004 | 3.6 | 3.34 | 4.57 | 3.62 | 3.26 |
|
||||
| SPEECHIO_ASR_ZH00005 | 7.54 | 5.81 | 8.39 | 7.26 | 5.43 |
|
||||
| SPEECHIO_ASR_ZH00006 | 15.59 | 13.34 | 19.07 | 13.64 | 11.96 |
|
||||
| SPEECHIO_ASR_ZH00007 | 15.9 | 15.05 | 16.7 | 14.06 | 13.73 |
|
||||
| SPEECHIO_ASR_ZH00008 | 11.07 | 9.68 | 14.69 | 10.34 | 8.87 |
|
||||
| SPEECHIO_ASR_ZH00009 | 7.38 | 6.23 | 8.32 | 6.74 | 5.96 |
|
||||
| SPEECHIO_ASR_ZH00010 | 9.19 | 7.33 | 11.2 | 8.85 | 6.97 |
|
||||
| SPEECHIO_ASR_ZH00011 | 4.16 | 3.84 | 54.56 | 4.09 | 3.72 |
|
||||
| SPEECHIO_ASR_ZH00012 | 7.61 | 6.58 | 10.53 | 8.35 | 6.27 |
|
||||
| SPEECHIO_ASR_ZH00013 | 8.72 | 7.66 | 9.32 | 7.26 | 6.7 |
|
||||
| SPEECHIO_ASR_ZH00014 | 9.69 | 8.71 | 9.03 | 7.03 | 6.59 |
|
||||
| SPEECHIO_ASR_ZH00015 | 11.94 | 11.37 | 16.58 | 12.02 | 11.11 |
|
||||
| SPEECHIO_ASR_ZH00016 | 9.79 | 8.79 | 14.1 | 10.19 | 8.15 |
|
||||
| SPEECHIO_ASR_ZH00017 | 8 | 6.72 | 9.04 | 8.9 | 6.44 |
|
||||
| SPEECHIO_ASR_ZH00018 | 5.42 | 5.02 | 6.06 | 4.86 | 4.4 |
|
||||
| SPEECHIO_ASR_ZH00019 | 11.26 | 9.06 | 14.8 | 9.83 | 8.22 |
|
||||
| SPEECHIO_ASR_ZH00020 | 4.37 | 4.23 | 5.97 | 4.23 | 4.13 |
|
||||
| SPEECHIO_ASR_ZH00021 | 7.81 | 6.34 | 8.53 | 7.08 | 5.88 |
|
||||
| SPEECHIO_ASR_ZH00022 | 9.11 | 8.54 | 9.7 | 8.97 | 8.02 |
|
||||
| SPEECHIO_ASR_ZH00023 | 9.98 | 8.98 | 6.31 | 9.44 | 8.57 |
|
||||
| SPEECHIO_ASR_ZH00024 | 16.15 | 12.95 | 20.54 | 15.92 | 12.28 |
|
||||
| SPEECHIO_ASR_ZH00025 | 10.38 | 9.82 | 11.4 | 10.26 | 9.27 |
|
||||
| SPEECHIO_ASR_ZH00026 | 5.69 | 5.63 | 9.09 | 5.95 | 5.51 |
|
||||
| Average WER (001-026) | 8.48 | 7.48 | 12.11 | 8.01 | 6.93 |
|
||||
|
||||
|
||||
|
||||
#### **Unlocked** SpeechIO test sets (ZH00001 ~ ZH00026)
|
||||
| Rank 排名 | Model 模型 | CER 字错误率 | Date 时间 |
|
||||
| --- | --- | --- | --- |
|
||||
| 1 | ximalaya_api_zh | 1.72% | 2023.12 |
|
||||
| 2 | aliyun_ftasr_api_zh | 1.85% | 2023.12 |
|
||||
| 3 | microsoft_batch_zh | 2.40% | 2023.12 |
|
||||
| 4 | bilibili_api_zh | 2.90% | 2023.09 |
|
||||
| 5 | tencent_api_zh | 3.18% | 2023.12 |
|
||||
| 6 | iflytek_lfasr_api_zh | 3.32% | 2023.12 |
|
||||
| 7 | aispeech_api_zh | 3.62% | 2023.12 |
|
||||
| 8 | **whisper-large-ft-v1** | **4.45%** | 2024.04 |
|
||||
| 9 | **zipformer (70Mb)** | **6.17%** | 2023.10 |
|
||||
| 10 | **whisper-large-ft-v0** | **6.34%** | 2023.03 |
|
||||
| 11 | baidu_pro_api_zh | 7.29% | 2023.12 |
|
||||
|
||||
Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leaderboard). All results used the default [normalize method.](https://github.com/SpeechColab/Leaderboard/blob/master/utils/benchmark.sh#L67)
|
||||
|
||||
<details><summary> Detail all models </summary><p>
|
||||
|
||||
| Model | Training Set | Note |
|
||||
|----------------------------------------------------------------------------------------------------------|---------------|-----------------------------------------------------|
|
||||
|[zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24)| multi-hans-zh | decoding with transducer head and blank penalty 2.0 |
|
||||
|[whisper-large-ft-v0](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper)| wenetspeech | |
|
||||
|[whisper-large-ft-v1](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper)|wenetspeech(updated), other multi-hans-zh exclude datatang 200h|[wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py)|
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details><summary> Detail all results (字错误率 CER %) </summary><p>
|
||||
|
||||
| Test Set ID | 测试场景&内容领域 | bilibili_api_zh (2023.09) | whisper-large-ft-v0 | whisper-large-ft-v1 | zipformer |
|
||||
|----------------------|-------------------------------|-----------------|---------|-----------|-----------|
|
||||
| Avg (01-26) | | 2.9 | 6.34 | 4.45 | 6.17 |
|
||||
| SPEECHIO_ASR_ZH00001 | 新闻联播 | 0.54 | 1.42 | 1.11 | 1.37 |
|
||||
| SPEECHIO_ASR_ZH00002 | 访谈 鲁豫有约 | 2.78 | 4.76 | 3.25 | 4.67 |
|
||||
| SPEECHIO_ASR_ZH00003 | 电视节目 天下足球 | 0.81 | 2.17 | 1.64 | 2.71 |
|
||||
| SPEECHIO_ASR_ZH00004 | 场馆演讲 罗振宇跨年 | 1.48 | 2.53 | 1.91 | 2.54 |
|
||||
| SPEECHIO_ASR_ZH00005 | 在线教育 李永乐 科普 | 1.47 | 4.27 | 1.98 | 3.12 |
|
||||
| SPEECHIO_ASR_ZH00006 | 直播 王者荣耀 张大仙&骚白 | 5.85 | 12.55 | 9.28 | 12.86 |
|
||||
| SPEECHIO_ASR_ZH00007 | 直播 带货 李佳琪&薇娅 | 6.19 | 13.38 | 11.14 | 14.58 |
|
||||
| SPEECHIO_ASR_ZH00008 | 线下培训 老罗语录 | 3.68 | 9.56 | 7.47 | 9.05 |
|
||||
| SPEECHIO_ASR_ZH00009 | 播客 故事FM | 3.18 | 5.66 | 3.78 | 5.4 |
|
||||
| SPEECHIO_ASR_ZH00010 | 播客 创业内幕 | 3.51 | 7.84 | 4.32 | 6.4 |
|
||||
| SPEECHIO_ASR_ZH00011 | 在线教育 罗翔 刑法法考 | 1.77 | 3.22 | 2.49 | 3.12 |
|
||||
| SPEECHIO_ASR_ZH00012 | 在线教育 张雪峰 考研 | 2.11 | 5.98 | 3.09 | 4.41 |
|
||||
| SPEECHIO_ASR_ZH00013 | 短视频 影剪 谷阿莫&牛叔说电影 | 2.97 | 5.91 | 3.78 | 6.56 |
|
||||
| SPEECHIO_ASR_ZH00014 | 短视频 美式&烹饪 | 3.56 | 6.03 | 5.28 | 8.14 |
|
||||
| SPEECHIO_ASR_ZH00015 | 评书 单田芳 白眉大侠 | 4.72 | 8.77 | 7.97 | 9.1 |
|
||||
| SPEECHIO_ASR_ZH00016 | 相声 德云社专场 | 3.01 | 5.24 | 4.41 | 5.59 |
|
||||
| SPEECHIO_ASR_ZH00017 | 脱口秀 吐槽大会 | 2.93 | 7.05 | 3.27 | 5.17 |
|
||||
| SPEECHIO_ASR_ZH00018 | 少儿卡通 小猪佩奇&熊出没 | 1.98 | 3.53 | 3.21 | 4.15 |
|
||||
| SPEECHIO_ASR_ZH00019 | 体育赛事解说 NBA比赛 | 2.32 | 6.89 | 4.24 | 6.66 |
|
||||
| SPEECHIO_ASR_ZH00020 | 纪录片 篮球人物 | 1.51 | 4.16 | 2.96 | 4.2 |
|
||||
| SPEECHIO_ASR_ZH00021 | 短视频 汽车之家 汽车评测 | 1.75 | 4.77 | 2.77 | 4.17 |
|
||||
| SPEECHIO_ASR_ZH00022 | 短视频 小艾大叔 豪宅带看 | 3.29 | 6.35 | 5.66 | 6.72 |
|
||||
| SPEECHIO_ASR_ZH00023 | 短视频 开箱视频 Zeal&无聊开箱 | 2.18 | 8.99 | 4.45 | 7.94 |
|
||||
| SPEECHIO_ASR_ZH00024 | 短视频 付老师 农业种植 | 4.80 | 10.81 | 6.25 | 8.64 |
|
||||
| SPEECHIO_ASR_ZH00025 | 线下课堂 石国鹏 古希腊哲学 | 3.32 | 8.41 | 5.8 | 8.54 |
|
||||
| SPEECHIO_ASR_ZH00026 | 广播电台节目 张震鬼故事 | 3.70 | 4.52 | 4.11 | 4.67 |
|
||||
</details>
|
||||
|
||||
|
||||
Command for decoding using fine-tuned whisper:
|
||||
```bash
|
||||
@ -76,16 +105,6 @@ mv words.txt ./data/lang_bpe_2000/
|
||||
--manifest-dir data/fbank_kaldi \
|
||||
--decoding-method greedy_search
|
||||
```
|
||||
Command for fusion the above decoding results from whisper and zipformer:
|
||||
```bash
|
||||
python local/whisper_zipformer_fusion.py \
|
||||
--whisper-log-dir ./whisper/exp_large_v2_wenetspeech \
|
||||
--zipformer-log-dir ./zipformer/exp_pretrain/greedy_search \
|
||||
--output-log-dir ./results_fusion
|
||||
|
||||
```
|
||||
|
||||
See why the fusion helps [here](./local/whisper_zipformer_fusion.py).
|
||||
|
||||
SpeechIO fbank features, decoding scripts, logs, and decoding results
|
||||
are available at
|
||||
|
104
egs/speechio/ASR/local/whisper_zipformer_fusion.py → egs/speechio/ASR/local/normalize_results.py
Normal file → Executable file
104
egs/speechio/ASR/local/whisper_zipformer_fusion.py → egs/speechio/ASR/local/normalize_results.py
Normal file → Executable file
@ -21,7 +21,7 @@ Since whisper model is more likely to make deletion errors and zipformer model i
|
||||
we trust whisper model when it makes substitution and insertion errors and trust zipformer model when it makes deletion errors.
|
||||
|
||||
Usage:
|
||||
python whisper_zipformer_fusion.py --whisper-log-dir ./whisper_decoding_log_dir --zipformer-log-dir ./zipformer_decoding_log_dir --output-log-dir ./results_fusion
|
||||
python whisper_zipformer_fusion.py --model-log-dir ./whisper_decoding_log_dir --output-log-dir ./results_norm
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@ -31,38 +31,43 @@ from typing import Dict, List, Optional, Tuple
|
||||
import kaldialign
|
||||
|
||||
from icefall.utils import store_transcripts, write_error_stats
|
||||
|
||||
from speechio_norm import TextNorm
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser.add_argument(
|
||||
"--whisper-log-dir",
|
||||
"--model-log-dir",
|
||||
type=str,
|
||||
default="./recogs_whisper",
|
||||
help="The directory to store the whisper logs: e.g. recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch--1-avg-1.txt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--zipformer-log-dir",
|
||||
type=str,
|
||||
default="./recogs_zipformer",
|
||||
help="The directory to store the zipformer logs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-log-dir",
|
||||
type=str,
|
||||
default="./results_fusion",
|
||||
help="The directory to store the fusion logs",
|
||||
default="./results_whisper_norm",
|
||||
help="The directory to store the normalized whisper logs",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def save_results(
|
||||
def save_results_with_speechio_text_norm(
|
||||
res_dir: Path,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
normalizer = TextNorm()
|
||||
# normlize items in results_dict
|
||||
for key, results in results_dict.items():
|
||||
results_norm = []
|
||||
for item in results:
|
||||
wav_name, ref, hyp = item
|
||||
ref = normalizer(ref)
|
||||
hyp = normalizer(hyp)
|
||||
results_norm.append((wav_name, ref, hyp))
|
||||
results_dict[key] = results_norm
|
||||
|
||||
test_set_wers = dict()
|
||||
|
||||
suffix = "epoch-999-avg-1"
|
||||
@ -120,11 +125,9 @@ def extract_hyp_ref_wavname(filename):
|
||||
return hyps, refs, wav_name
|
||||
|
||||
|
||||
def get_pair_filenames(
|
||||
def get_filenames(
|
||||
whisper_log_dir,
|
||||
zipformer_log_dir,
|
||||
whisper_suffix="beam-search-epoch-999-avg-1",
|
||||
zipformer_suffix="greedy_search_blank_penalty_2.0-epoch-999-avg-1-context-2-max-sym-per-frame-1-blank-penalty-2.0",
|
||||
):
|
||||
results = []
|
||||
start_index, end_index = 0, 26
|
||||
@ -134,80 +137,23 @@ def get_pair_filenames(
|
||||
dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")
|
||||
for partition in dataset_parts:
|
||||
whisper_filename = f"{whisper_log_dir}/recogs-{partition}-{whisper_suffix}.txt"
|
||||
zipformer_filename = (
|
||||
f"{zipformer_log_dir}/recogs-{partition}-{zipformer_suffix}.txt"
|
||||
)
|
||||
results.append((whisper_filename, zipformer_filename))
|
||||
results.append(whisper_filename)
|
||||
return results
|
||||
|
||||
|
||||
def fusion_hyps_trust_substituion_insertion(
|
||||
hyps_whisper, hyps_zipformer, refs, ERR="*"
|
||||
):
|
||||
"""
|
||||
alignment example:
|
||||
[('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')]
|
||||
left is whisper, right is zipformer
|
||||
for whisper substitution, use left
|
||||
for whisper insertion, use left
|
||||
for whisper deletion, use right
|
||||
"""
|
||||
hyps_fusion = []
|
||||
for hyp_w, hyp_z, ref in zip(hyps_whisper, hyps_zipformer, refs):
|
||||
ali = kaldialign.align(hyp_w, hyp_z, ERR)
|
||||
hyp_f = ""
|
||||
for a in ali:
|
||||
if a[0] == ERR:
|
||||
hyp_f += a[1]
|
||||
else:
|
||||
hyp_f += a[0]
|
||||
hyps_fusion.append(hyp_f)
|
||||
return hyps_fusion
|
||||
|
||||
|
||||
def fusion_hyps_trust_substituion(hyps_whisper, hyps_zipformer, refs, ERR="*"):
|
||||
"""
|
||||
alignment example:
|
||||
[('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')]
|
||||
left is whisper, right is zipformer
|
||||
for whisper substitution, use left
|
||||
for whisper insertion, use right
|
||||
for whisper deletion, use right
|
||||
"""
|
||||
hyps_fusion = []
|
||||
for hyp_w, hyp_z, ref in zip(hyps_whisper, hyps_zipformer, refs):
|
||||
ali = kaldialign.align(hyp_w, hyp_z, ERR)
|
||||
hyp_f = ""
|
||||
for a in ali:
|
||||
if a[0] == ERR:
|
||||
hyp_f += a[1]
|
||||
elif a[1] == ERR:
|
||||
pass
|
||||
else:
|
||||
hyp_f += a[0]
|
||||
hyps_fusion.append(hyp_f)
|
||||
return hyps_fusion
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
# mkdir output_log_dir
|
||||
Path(args.output_log_dir).mkdir(parents=True, exist_ok=True)
|
||||
pair_logs = get_pair_filenames(args.whisper_log_dir, args.zipformer_log_dir)
|
||||
for pair in pair_logs:
|
||||
hyps_whisper, refs, wav_name = extract_hyp_ref_wavname(pair[0])
|
||||
hyps_zipformer, _, _ = extract_hyp_ref_wavname(pair[1])
|
||||
filenames = get_filenames(args.model_log_dir)
|
||||
for filename in filenames:
|
||||
hyps, refs, wav_name = extract_hyp_ref_wavname(filename)
|
||||
partition_name = filename.split("/")[-1].split("-")[1]
|
||||
|
||||
hyps_fusion = fusion_hyps_trust_substituion_insertion(
|
||||
hyps_whisper, hyps_zipformer, refs
|
||||
)
|
||||
|
||||
partition_name = pair[0].split("/")[-1].split("-")[1]
|
||||
save_results(
|
||||
save_results_with_speechio_text_norm(
|
||||
Path(args.output_log_dir),
|
||||
partition_name,
|
||||
{"fusion": list(zip(wav_name, refs, hyps_fusion))},
|
||||
{"norm": list(zip(wav_name, refs, hyps))},
|
||||
)
|
||||
|
||||
print(f"Processed {partition_name}")
|
1203
egs/speechio/ASR/local/speechio_norm.py
Executable file
1203
egs/speechio/ASR/local/speechio_norm.py
Executable file
File diff suppressed because it is too large
Load Diff
114
egs/wenetspeech/ASR/local/fix_manifest.py
Normal file
114
egs/wenetspeech/ASR/local/fix_manifest.py
Normal file
@ -0,0 +1,114 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 author: Yuekai Zhang
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import argparse
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fixed-transcript-path",
|
||||
type=str,
|
||||
default="data/fbank/text.fix",
|
||||
help="""
|
||||
See https://github.com/wenet-e2e/WenetSpeech/discussions/54
|
||||
wget -nc https://huggingface.co/datasets/yuekai/wenetspeech_paraformer_fixed_transcript/resolve/main/text.fix
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--manifest-dir",
|
||||
type=str,
|
||||
default="data/fbank/",
|
||||
help="Directory to store the manifest files",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--training-subset",
|
||||
type=str,
|
||||
default="L",
|
||||
help="The training subset for wenetspeech.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
def load_fixed_text(fixed_text_path):
|
||||
"""
|
||||
fixed text format
|
||||
X0000016287_92761015_S00001 我是徐涛
|
||||
X0000016287_92761015_S00002 狄更斯的PICK WEEK PAPERS斯
|
||||
load into a dict
|
||||
"""
|
||||
fixed_text_dict = {}
|
||||
with open(fixed_text_path, 'r') as f:
|
||||
for line in f:
|
||||
cut_id, text = line.strip().split(' ', 1)
|
||||
fixed_text_dict[cut_id] = text
|
||||
return fixed_text_dict
|
||||
|
||||
def fix_manifest(manifest, fixed_text_dict, fixed_manifest_path):
|
||||
with CutSet.open_writer(fixed_manifest_path) as manifest_writer:
|
||||
fixed_item = 0
|
||||
for i, cut in enumerate(manifest):
|
||||
if i % 10000 == 0:
|
||||
logging.info(f'Processing cut {i}, fixed {fixed_item}')
|
||||
cut_id_orgin = cut.id
|
||||
if cut_id_orgin.endswith('_sp0.9'):
|
||||
cut_id = cut_id_orgin[:-6]
|
||||
elif cut_id_orgin.endswith('_sp1.1'):
|
||||
cut_id = cut_id_orgin[:-6]
|
||||
else:
|
||||
cut_id = cut_id_orgin
|
||||
if cut_id in fixed_text_dict:
|
||||
assert len(cut.supervisions) == 1, f'cut {cut_id} has {len(cut.supervisions)} supervisions'
|
||||
if cut.supervisions[0].text != fixed_text_dict[cut_id]:
|
||||
logging.info(f'Fixed text for cut {cut_id_orgin} from {cut.supervisions[0].text} to {fixed_text_dict[cut_id]}')
|
||||
cut.supervisions[0].text = fixed_text_dict[cut_id]
|
||||
fixed_item += 1
|
||||
manifest_writer.write(cut)
|
||||
|
||||
def main():
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
fixed_text_path = args.manifest_dir + 'text.fix'
|
||||
fixed_text_dict = load_fixed_text(fixed_text_path)
|
||||
logging.info(f'Loaded {len(fixed_text_dict)} fixed texts')
|
||||
|
||||
dev_manifest_path = args.manifest_dir + 'cuts_DEV.jsonl.gz'
|
||||
fixed_dev_manifest_path = args.manifest_dir + 'cuts_DEV_fixed.jsonl.gz'
|
||||
logging.info(f'Loading dev manifest from {dev_manifest_path}')
|
||||
cuts_dev_manifest = load_manifest_lazy(dev_manifest_path)
|
||||
fix_manifest(cuts_dev_manifest, fixed_text_dict, fixed_dev_manifest_path)
|
||||
logging.info(f'Fixed dev manifest saved to {fixed_dev_manifest_path}')
|
||||
|
||||
manifest_path = args.manifest_dir + f'cuts_{args.training_subset}.jsonl.gz'
|
||||
fixed_manifest_path = args.manifest_dir + f'cuts_{args.training_subset}_fixed.jsonl.gz'
|
||||
logging.info(f'Loading manifest from {manifest_path}')
|
||||
cuts_manifest = load_manifest_lazy(manifest_path)
|
||||
fix_manifest(cuts_manifest, fixed_text_dict, fixed_manifest_path)
|
||||
logging.info(f'Fixed training manifest saved to {fixed_manifest_path}')
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -416,3 +416,12 @@ if [ $stage -le 22 ] && [ $stop_stage -ge 22 ]; then
|
||||
python ./local/compile_lg.py --lang-dir $lang_dir
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 23 ] && [ $stop_stage -ge 23 ]; then
|
||||
log "Stage 23: Modify transcript according to fixed results"
|
||||
# See https://github.com/wenet-e2e/WenetSpeech/discussions/54
|
||||
wget -nc https://huggingface.co/datasets/yuekai/wenetspeech_paraformer_fixed_transcript/resolve/main/text.fix -O data/fbank/text.fix
|
||||
python local/fix_manifest.py \
|
||||
--fixed-transcript-path data/fbank/text.fix \
|
||||
--training-subset L
|
||||
fi
|
@ -390,14 +390,14 @@ class WenetSpeechAsrDataModule:
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
cuts_train = load_manifest_lazy(
|
||||
self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz"
|
||||
self.args.manifest_dir / f"cuts_{self.args.training_subset}_fixed.jsonl.gz"
|
||||
)
|
||||
return cuts_train
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV_fixed.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def test_net_cuts(self) -> List[CutSet]:
|
||||
|
@ -44,6 +44,7 @@ from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import os
|
||||
import deepspeed
|
||||
import k2
|
||||
import optim
|
||||
@ -145,7 +146,7 @@ def get_parser():
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="large-v2",
|
||||
choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
|
||||
choices=["large-v2", "large-v3", "medium", "small", "tiny"],
|
||||
help="""The model name to use.
|
||||
""",
|
||||
)
|
||||
@ -616,7 +617,9 @@ def train_one_epoch(
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
|
||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||
)
|
||||
|
||||
os.system(
|
||||
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
||||
)
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
@ -803,7 +806,7 @@ def run(rank, world_size, args):
|
||||
|
||||
if params.print_diagnostics:
|
||||
opts = diagnostics.TensorDiagnosticOptions(
|
||||
512
|
||||
2**22
|
||||
) # allow 4 megabytes per sub-module
|
||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||
|
||||
@ -893,6 +896,7 @@ def run(rank, world_size, args):
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
|
||||
tag=f"epoch-{params.cur_epoch}",
|
||||
)
|
||||
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
|
||||
else:
|
||||
save_checkpoint(
|
||||
params=params,
|
||||
|
Loading…
x
Reference in New Issue
Block a user