add speechio results

This commit is contained in:
Yuekai Zhang 2024-03-07 14:44:38 +08:00
parent b422e7a97f
commit a00c0c5279
36 changed files with 600 additions and 1518 deletions

View File

@ -29,7 +29,14 @@ import os
from pathlib import Path
import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
@ -42,7 +49,9 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
def compute_fbank_aishell2(
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(8, os.cpu_count())
@ -69,7 +78,9 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False,
dataset_parts,
)
if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -84,7 +95,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False,
supervisions=m["supervisions"],
)
if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb")
logging.info("Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
@ -129,5 +140,7 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_aishell2(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
num_mel_bins=args.num_mel_bins,
perturb_speed=args.perturb_speed,
whisper_fbank=args.whisper_fbank,
)

View File

@ -6,8 +6,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=30
stage=1
stop_stage=1
stage=0
stop_stage=7
perturb_speed=true

View File

@ -29,7 +29,14 @@ import os
from pathlib import Path
import torch
from lhotse import LilcomChunkyWriter, CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
@ -42,7 +49,9 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
def compute_fbank_aishell4(
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests/aishell4")
output_dir = Path("data/fbank")
num_jobs = min(8, os.cpu_count())
@ -71,7 +80,9 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False,
)
if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -87,7 +98,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False,
supervisions=m["supervisions"],
)
if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb")
logging.info("Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
@ -140,5 +151,7 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_aishell4(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
num_mel_bins=args.num_mel_bins,
perturb_speed=args.perturb_speed,
whisper_fbank=args.whisper_fbank,
)

View File

@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=20
stop_stage=20
stage=-1
stop_stage=7
perturb_speed=true

View File

@ -29,7 +29,14 @@ import os
from pathlib import Path
import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
@ -42,7 +49,9 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
def compute_fbank_alimeeting(
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests/alimeeting")
output_dir = Path("data/fbank")
num_jobs = min(8, os.cpu_count())
@ -71,7 +80,9 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
)
if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -86,7 +97,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
supervisions=m["supervisions"],
)
if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb")
logging.info("Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
@ -140,5 +151,7 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_alimeeting(
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
num_mel_bins=args.num_mel_bins,
perturb_speed=args.perturb_speed,
whisper_fbank=args.whisper_fbank,
)

View File

@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=20
stop_stage=20
stage=-1
stop_stage=7
perturb_speed=true
# We assume dl_dir (download dir) contains the following

View File

@ -17,14 +17,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
import argparse
import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
from lhotse import (
CutSet,
KaldifeatFbank,
KaldifeatFbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from icefall.utils import str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
@ -32,6 +40,7 @@ from icefall.utils import str2bool
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -52,6 +61,7 @@ def get_parser():
)
return parser
def compute_fbank_kespeech_dev_test(args):
in_out_dir = Path("data/fbank/kespeech")
# number of workers in dataloader
@ -70,7 +80,9 @@ def compute_fbank_kespeech_dev_test(args):
if torch.cuda.is_available():
device = torch.device("cuda", 0)
if args.whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device=device))
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
)
else:
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))

View File

@ -25,16 +25,17 @@ from pathlib import Path
import torch
from lhotse import (
CutSet,
WhisperFbank,
WhisperFbankConfig,
KaldifeatFbank,
KaldifeatFbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
set_audio_duration_mismatch_tolerance,
set_caching_enabled,
)
from icefall.utils import str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
@ -129,7 +130,9 @@ def compute_fbank_kespeech_splits(args):
if torch.cuda.is_available():
device = torch.device("cuda", 0)
if args.whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
)
else:
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
logging.info(f"device: {device}")

View File

@ -30,7 +30,14 @@ import os
from pathlib import Path
import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
@ -42,6 +49,7 @@ from icefall.utils import get_executor, str2bool
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -62,7 +70,10 @@ def get_parser():
)
return parser
def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
def compute_fbank_magicdata(
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests/magicdata")
output_dir = Path("data/fbank")
num_jobs = min(8, os.cpu_count())
@ -86,7 +97,9 @@ def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False,
)
if args.whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda"))
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -145,5 +158,7 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_magicdata(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
num_mel_bins=args.num_mel_bins,
speed_perturb=args.speed_perturb,
whisper_fbank=args.whisper_fbank,
)

View File

@ -30,7 +30,14 @@ import os
from pathlib import Path
import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
@ -43,7 +50,9 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
def compute_fbank_primewords(
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests/primewords")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
@ -67,7 +76,9 @@ def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False
)
if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -128,5 +139,7 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_primewords(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
num_mel_bins=args.num_mel_bins,
speed_perturb=args.speed_perturb,
whisper_fbank=args.whisper_fbank,
)

View File

@ -30,7 +30,14 @@ import os
from pathlib import Path
import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
@ -43,7 +50,9 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
def compute_fbank_stcmds(
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests/stcmds")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
@ -67,7 +76,9 @@ def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False, wh
)
if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -126,5 +137,7 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_stcmds(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
num_mel_bins=args.num_mel_bins,
speed_perturb=args.speed_perturb,
whisper_fbank=args.whisper_fbank,
)

View File

@ -30,7 +30,14 @@ import os
from pathlib import Path
import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
@ -43,7 +50,9 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
def compute_fbank_thchs30(
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests/thchs30")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
@ -71,7 +80,9 @@ def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False, w
)
if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -132,5 +143,7 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_thchs30(
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
num_mel_bins=args.num_mel_bins,
speed_perturb=args.speed_perturb,
whisper_fbank=args.whisper_fbank,
)

View File

@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=121
stop_stage=121
stage=-1
stop_stage=100
num_splits=100
dl_dir=$PWD/download
@ -482,5 +482,3 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
python ./local/compile_lg.py --lang-dir $lang_dir
done
fi

View File

@ -52,14 +52,14 @@ import k2
import torch
import torch.nn as nn
import whisper
from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (

View File

@ -34,10 +34,10 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
--model-name medium
"""
import os
import argparse
import copy
import logging
import os
import random
import warnings
from pathlib import Path
@ -52,13 +52,13 @@ import torch.multiprocessing as mp
import torch.nn as nn
import whisper
from asr_datamodule import AsrDataModule
from multi_dataset import MultiDataset
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from multi_dataset import MultiDataset
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
@ -626,7 +626,9 @@ def train_one_epoch(
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
)
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}")
os.system(
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
)
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
@ -761,9 +763,7 @@ def run(rank, world_size, args):
del model.alignment_heads
if params.pretrained_model_path:
checkpoint = torch.load(
params.pretrained_model_path, map_location="cpu"
)
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
if "model" not in checkpoint:
model.load_state_dict(checkpoint, strict=True)
else:

View File

@ -0,0 +1,15 @@
# Introduction
This recipe reports the decoding results of several pretrained ASR models on the [SpeechIO](https://github.com/SpeechColab/Leaderboard) test sets.
[./RESULTS.md](./RESULTS.md) contains the latest results.
# Pretrained Models
The following table lists the pretrained models.
| | Huggingface | Comment |
|---------------------------------------|--------------------|-----------------------------|
| `zipformer` | zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 | Trained with the [multi_zh-hans recipe](../../multi_zh-hans/ASR/zipformer/) |
| `whisper` | yuekai/icefall_asr_wenetspeech_whisper | Trained with the [wenetspeech recipe](../../wenetspeech/ASR/whisper/) |

View File

@ -0,0 +1,92 @@
## Results
### SpeechIO Test Set Decoding Results
##### Decoding results using pretrained [multi-hans-zh zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24), [whisper-large-v2](https://github.com/openai/whisper/blob/main/whisper/__init__.py#L27), [whisper-large-v2-wenetspeech-ft](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper).
| | zipformer_transducer | zipformer_transducer_blank_penalty_2 | whisper_large_v2 | whisper_large_v2_wenetspeech | whisper_large_v2_wenetspeech_zipformer_fusion |
|------------------------|----------------------|--------------------------------------|------------------|------------------------------|-----------------------------------------------|
| SPEECHIO_ASR_ZH00000 | 10.04 | 8.04 | 11.4 | 9.88 | 7.78 |
| SPEECHIO_ASR_ZH00001 | 1.67 | 1.51 | 2.49 | 1.57 | 1.38 |
| SPEECHIO_ASR_ZH00002 | 5.89 | 5.27 | 7.89 | 5.65 | 4.99 |
| SPEECHIO_ASR_ZH00003 | 2.66 | 2.79 | 5.94 | 2.27 | 2.33 |
| SPEECHIO_ASR_ZH00004 | 3.6 | 3.34 | 4.57 | 3.62 | 3.26 |
| SPEECHIO_ASR_ZH00005 | 7.54 | 5.81 | 8.39 | 7.26 | 5.43 |
| SPEECHIO_ASR_ZH00006 | 15.59 | 13.34 | 19.07 | 13.64 | 11.96 |
| SPEECHIO_ASR_ZH00007 | 15.9 | 15.05 | 16.7 | 14.06 | 13.73 |
| SPEECHIO_ASR_ZH00008 | 11.07 | 9.68 | 14.69 | 10.34 | 8.87 |
| SPEECHIO_ASR_ZH00009 | 7.38 | 6.23 | 8.32 | 6.74 | 5.96 |
| SPEECHIO_ASR_ZH00010 | 9.19 | 7.33 | 11.2 | 8.85 | 6.97 |
| SPEECHIO_ASR_ZH00011 | 4.16 | 3.84 | 54.56 | 4.09 | 3.72 |
| SPEECHIO_ASR_ZH00012 | 7.61 | 6.58 | 10.53 | 8.35 | 6.27 |
| SPEECHIO_ASR_ZH00013 | 8.72 | 7.66 | 9.32 | 7.26 | 6.7 |
| SPEECHIO_ASR_ZH00014 | 9.69 | 8.71 | 9.03 | 7.03 | 6.59 |
| SPEECHIO_ASR_ZH00015 | 11.94 | 11.37 | 16.58 | 12.02 | 11.11 |
| SPEECHIO_ASR_ZH00016 | 9.79 | 8.79 | 14.1 | 10.19 | 8.15 |
| SPEECHIO_ASR_ZH00017 | 8 | 6.72 | 9.04 | 8.9 | 6.44 |
| SPEECHIO_ASR_ZH00018 | 5.42 | 5.02 | 6.06 | 4.86 | 4.4 |
| SPEECHIO_ASR_ZH00019 | 11.26 | 9.06 | 14.8 | 9.83 | 8.22 |
| SPEECHIO_ASR_ZH00020 | 4.37 | 4.23 | 5.97 | 4.23 | 4.13 |
| SPEECHIO_ASR_ZH00021 | 7.81 | 6.34 | 8.53 | 7.08 | 5.88 |
| SPEECHIO_ASR_ZH00022 | 9.11 | 8.54 | 9.7 | 8.97 | 8.02 |
| SPEECHIO_ASR_ZH00023 | 9.98 | 8.98 | 6.31 | 9.44 | 8.57 |
| SPEECHIO_ASR_ZH00024 | 16.15 | 12.95 | 20.54 | 15.92 | 12.28 |
| SPEECHIO_ASR_ZH00025 | 10.38 | 9.82 | 11.4 | 10.26 | 9.27 |
| SPEECHIO_ASR_ZH00026 | 5.69 | 5.63 | 9.09 | 5.95 | 5.51 |
| Average WER (001-026) | 8.48 | 7.48 | 12.11 | 8.01 | 6.93 |
Command for decoding using fine-tuned whisper:
```bash
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper
ln -s icefall_asr_wenetspeech_whisper/exp_large_v2/epoch-4-avg3.pt whisper/exp_large_v2_wenetspeech/epoch-999.pt
python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2_wenetspeech \
--model-name large-v2 \
--epoch 999 --avg 1 \
--start-index 0 --end-index 26 \
--remove-whisper-encoder-input-length-restriction True \
--manifest-dir data/fbank \
--beam-size 1 --max-duration 50
```
Command for decoding using pretrained zipformer:
```bash
git lfs install
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24
cd icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24
git lfs pull --include "exp/pretrained.pt"
git lfs pull --include "data/lang_bpe_2000/*"
ln -s ../icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/exp/pretrained.pt zipformer/exp_pretrain/epoch-999.pt
ln -s ../icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/data/lang_bpe_2000/ ./data
wget https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615/resolve/main/data/lang_char/words.txt
mv words.txt ./data/lang_bpe_2000/
./zipformer/decode.py \
--epoch 999 \
--avg 1 \
--blank-penalty 2.0 \
--use-averaged-model false \
--exp-dir ./zipformer/exp_pretrain \
--max-duration 600 \
--start-index 0 --end-index 26 \
--manifest-dir data/fbank_kaldi \
--decoding-method greedy_search
```
Command for fusing the above decoding results from whisper and zipformer:
```bash
python local/whisper_zipformer_fusion.py \
--whisper-log-dir ./whisper/exp_large_v2_wenetspeech \
--zipformer-log-dir ./zipformer/exp_pretrain/greedy_search \
--output-log-dir ./results_fusion
```
See why the fusion helps [here](./local/whisper_zipformer_fusion.py).
SpeechIO fbank features, decoding scripts, logs, and decoding results
are available at
<https://huggingface.co/yuekai/icefall_asr_speechio>

View File

@ -30,7 +30,14 @@ import os
from pathlib import Path
import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
@ -44,7 +51,13 @@ torch.set_num_interop_threads(1)
SPEECHIO_TESTSET_INDEX = 26 # Currently, from 0 - 26 test sets are open source.
def compute_fbank_speechio(num_mel_bins: int = 80, speed_perturb: bool = False, fbank_dir: str = "data/fbank", whisper_fbank: bool = False):
def compute_fbank_speechio(
num_mel_bins: int = 80,
speed_perturb: bool = False,
fbank_dir: str = "data/fbank",
whisper_fbank: bool = False,
):
src_dir = Path("data/manifests")
output_dir = Path(fbank_dir)
num_jobs = min(8, os.cpu_count())
@ -72,7 +85,9 @@ def compute_fbank_speechio(num_mel_bins: int = 80, speed_perturb: bool = False,
)
if whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -127,5 +142,7 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_speechio(
num_mel_bins=args.num_mel_bins, fbank_dir=args.fbank_dir, whisper_fbank=args.whisper_fbank
num_mel_bins=args.num_mel_bins,
fbank_dir=args.fbank_dir,
whisper_fbank=args.whisper_fbank,
)

View File

@ -41,9 +41,12 @@ def main():
for partition in dataset_parts:
path = f"./data/fbank/{prefix}_cuts_{partition}.{suffix}"
cuts = load_manifest_lazy(path)
print(f"===================Duration statistics of {partition}===================")
print(
f"===================Duration statistics of {partition}==================="
)
cuts.describe()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,217 @@
#!/usr/bin/env python3
#
# Copyright 2024 Author: Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file uses whisper and zipformer decoding results to generate fusion decoding results.
Since whisper model is more likely to make deletion errors and zipformer model is more likely to make substitution and insertion errors,
we trust whisper model when it makes substitution and insertion errors and trust zipformer model when it makes deletion errors.
Usage:
python whisper_zipformer_fusion.py --whisper-log-dir ./whisper_decoding_log_dir --zipformer-log-dir ./zipformer_decoding_log_dir --output-log-dir ./results_fusion
"""
import argparse
import ast
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import kaldialign

from icefall.utils import store_transcripts, write_error_stats
def get_parser():
    """Build the command-line parser of the fusion script.

    Returns:
      An ``argparse.ArgumentParser`` exposing three string options:
      ``--whisper-log-dir``, ``--zipformer-log-dir`` and ``--output-log-dir``.
      Defaults are shown in ``--help`` via ``ArgumentDefaultsHelpFormatter``.
    """
    # One (flag, default, help) entry per directory option.
    directory_options = (
        (
            "--whisper-log-dir",
            "./recogs_whisper",
            "The directory to store the whisper logs: e.g. recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch--1-avg-1.txt",
        ),
        (
            "--zipformer-log-dir",
            "./recogs_zipformer",
            "The directory to store the zipformer logs",
        ),
        (
            "--output-log-dir",
            "./results_fusion",
            "The directory to store the fusion logs",
        ),
    )

    fusion_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    for flag, default_dir, help_text in directory_options:
        fusion_parser.add_argument(
            flag, type=str, default=default_dir, help=help_text
        )
    return fusion_parser
def save_results(
    res_dir: Path,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Write transcripts, error statistics and a WER summary for one test set.

    For every ``(key, results)`` entry of ``results_dict`` the sorted results
    are passed to ``store_transcripts`` (``recogs-*.txt``) and to
    ``write_error_stats`` (``errs-*.txt``); the value returned by
    ``write_error_stats`` is kept as the WER of that key. Afterwards a
    ``wer-summary-*.txt`` file listing all keys sorted by ascending WER is
    written and the same summary is printed.

    Args:
      res_dir:
        Directory receiving all output files. It must already exist.
      test_set_name:
        Name of the test set; used in file names and in messages.
      results_dict:
        Maps a setting name to its list of ``(utterance id, ref, hyp)``
        tuples.

    Returns:
      None. If ``results_dict`` is empty nothing is written.
    """
    if not results_dict:
        # The summary file name is derived from a key of results_dict, so
        # there is no meaningful output for an empty dict (this used to
        # fail with a NameError on the leaked loop variable).
        print(f"No results to save for {test_set_name}")
        return

    test_set_wers = dict()
    suffix = "epoch-999-avg-1"
    for key, results in results_dict.items():
        recog_path = res_dir / f"recogs-{test_set_name}-{key}-{suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        print(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = res_dir / f"errs-{test_set_name}-{key}-{suffix}.txt"
        # The transcripts are Chinese text: do not depend on the locale's
        # default encoding.
        with open(errs_filename, "w", encoding="utf-8") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        print("Wrote detailed error stats to {}".format(errs_filename))

    # Keep the historical file name, which carries the last key of
    # results_dict, but compute it explicitly instead of relying on the
    # loop variable leaking out of the ``for`` statement.
    summary_key = list(results_dict)[-1]
    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = res_dir / f"wer-summary-{test_set_name}-{summary_key}-{suffix}.txt"
    with open(errs_info, "w", encoding="utf-8") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    # Only the first (lowest-WER) line is tagged as the best one.
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    print(s)
# Matches the "<utterance id>:<whitespace>ref=" / "...hyp=" marker of a
# transcript line. Anchoring on the colon avoids false hits on "ref"/"hyp"
# substrings inside the utterance id or the recognized text.
_RECOG_MARKER = re.compile(r":\s*(ref|hyp)=")


def _join_ref_tokens(ref_repr):
    """Concatenate the tokens of a printed token list such as "['a', 'b']"."""
    try:
        tokens = ast.literal_eval(ref_repr)
    except (ValueError, TypeError, SyntaxError):
        tokens = None
    if isinstance(tokens, list) and all(isinstance(t, str) for t in tokens):
        return "".join(tokens)
    # Not a parsable list of strings: fall back to the plain textual
    # handling, i.e. drop the enclosing "['" and "']" and the "', '"
    # separators.
    return "".join(ref_repr[2:-2].split("', '"))


def extract_hyp_ref_wavname(filename):
    """Read hypotheses, references and utterance names from a recogs file.

    The expected line format is

      0Phqz8RWYuE_0007-5: ref=['R', 'Y', 'Y', 'B', '它最大的优势就是进光量或者说是对光线利用率的提升']
      0Phqz8RWYuE_0007-5: hyp=而YB它最大的优势是近光量或者说是对光线利用率的提升

    Lines that carry neither a ``ref=`` nor a ``hyp=`` marker are ignored.

    Args:
      filename:
        Path of the transcript file. It is read as UTF-8.

    Returns:
      A tuple ``(hyps, refs, wav_name)`` of three lists:
      ``hyps`` holds the text after ``hyp=`` unchanged, ``refs`` holds the
      reference tokens concatenated into one string, and ``wav_name`` holds
      the utterance id of every ``hyp`` line.
    """
    hyps, refs, wav_name = [], [], []
    # Explicit encoding: the transcripts contain Chinese characters.
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            marker = _RECOG_MARKER.search(line)
            if marker is None:
                continue
            payload = line[marker.end():].strip()
            if marker.group(1) == "ref":
                refs.append(_join_ref_tokens(payload))
            else:
                hyps.append(payload)
                # Names are collected from hyp lines only, so that they stay
                # in one-to-one correspondence with hyps.
                wav_name.append(line[: marker.start()])
    return hyps, refs, wav_name
def get_pair_filenames(
    whisper_log_dir,
    zipformer_log_dir,
    whisper_suffix="beam-search-epoch-999-avg-1",
    zipformer_suffix="greedy_search_blank_penalty_2.0-epoch-999-avg-1-context-2-max-sym-per-frame-1-blank-penalty-2.0",
    start_index=0,
    end_index=26,
):
    """List the (whisper, zipformer) transcript file pairs of the test sets.

    Args:
      whisper_log_dir:
        Directory containing the whisper ``recogs-*.txt`` files.
      zipformer_log_dir:
        Directory containing the zipformer ``recogs-*.txt`` files.
      whisper_suffix:
        Part of the whisper file name that follows the test set name.
      zipformer_suffix:
        Part of the zipformer file name that follows the test set name.
      start_index:
        Index of the first test set (inclusive).
      end_index:
        Index of the last test set (inclusive).

    Returns:
      A list with one ``(whisper_filename, zipformer_filename)`` tuple of
      path strings per test set ``SPEECHIO_ASR_ZH<index>``, in increasing
      index order. The files are not checked for existence.
    """
    results = []
    for i in range(start_index, end_index + 1):
        # Test set names carry a 5-digit, zero-padded index.
        partition = f"SPEECHIO_ASR_ZH{i:05d}"
        whisper_filename = f"{whisper_log_dir}/recogs-{partition}-{whisper_suffix}.txt"
        zipformer_filename = (
            f"{zipformer_log_dir}/recogs-{partition}-{zipformer_suffix}.txt"
        )
        results.append((whisper_filename, zipformer_filename))
    return results
def fusion_hyps_trust_substituion_insertion(
    hyps_whisper, hyps_zipformer, refs, ERR="*"
):
    """Fuse hypotheses, keeping whisper's output except where it has a gap.

    Each whisper hypothesis is aligned against the matching zipformer
    hypothesis with ``kaldialign.align``; in every aligned pair the first
    element comes from whisper and the second from zipformer, and ``ERR``
    stands for "nothing on this side".

      - whisper substitution -> take the whisper symbol
      - whisper insertion    -> take the whisper symbol
      - whisper deletion     -> take the zipformer symbol

    Args:
      hyps_whisper:
        Whisper hypotheses, one per utterance.
      hyps_zipformer:
        Zipformer hypotheses in the same utterance order.
      refs:
        References; only their count matters, since the three lists are
        zipped together and iteration stops at the shortest one.
      ERR:
        Placeholder symbol handed to the aligner.

    Returns:
      A list with one fused hypothesis string per processed utterance.
    """
    fused = []
    for whisper_hyp, zipformer_hyp, _ in zip(hyps_whisper, hyps_zipformer, refs):
        aligned = kaldialign.align(whisper_hyp, zipformer_hyp, ERR)
        fused.append(
            "".join(
                z_sym if w_sym == ERR else w_sym
                for w_sym, z_sym in aligned
            )
        )
    return fused
def fusion_hyps_trust_substituion(hyps_whisper, hyps_zipformer, refs, ERR="*"):
    """Fuse hypotheses, trusting whisper only where both systems emit a symbol.

    Each whisper hypothesis is aligned against the matching zipformer
    hypothesis with ``kaldialign.align``; in every aligned pair the first
    element comes from whisper and the second from zipformer, and ``ERR``
    stands for "nothing on this side".

      - whisper substitution -> take the whisper symbol
      - whisper insertion    -> follow zipformer, i.e. drop the symbol
      - whisper deletion     -> take the zipformer symbol

    Args:
      hyps_whisper:
        Whisper hypotheses, one per utterance.
      hyps_zipformer:
        Zipformer hypotheses in the same utterance order.
      refs:
        References; only their count matters, since the three lists are
        zipped together and iteration stops at the shortest one.
      ERR:
        Placeholder symbol handed to the aligner.

    Returns:
      A list with one fused hypothesis string per processed utterance.
    """
    fused = []
    for whisper_hyp, zipformer_hyp, _ in zip(hyps_whisper, hyps_zipformer, refs):
        kept = []
        for w_sym, z_sym in kaldialign.align(whisper_hyp, zipformer_hyp, ERR):
            if w_sym == ERR:
                # Whisper has nothing here: fill in from zipformer.
                kept.append(z_sym)
            elif z_sym != ERR:
                # Both sides emitted a symbol: whisper wins.
                kept.append(w_sym)
            # Otherwise only whisper emitted a symbol; it is discarded.
        fused.append("".join(kept))
    return fused
def main():
    """Fuse the whisper and zipformer transcripts of every test set.

    Parses the command line, creates the output directory if needed, and
    then, for each pair of transcript files returned by
    ``get_pair_filenames``: reads both files, fuses the hypotheses with
    ``fusion_hyps_trust_substituion_insertion`` and stores the result under
    the key ``fusion`` via ``save_results``. References and utterance names
    are taken from the whisper file.
    """
    args = get_parser().parse_args()

    # mkdir output_log_dir
    output_dir = Path(args.output_log_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    log_pairs = get_pair_filenames(args.whisper_log_dir, args.zipformer_log_dir)
    for whisper_log, zipformer_log in log_pairs:
        hyps_whisper, refs, wav_name = extract_hyp_ref_wavname(whisper_log)
        hyps_zipformer, _, _ = extract_hyp_ref_wavname(zipformer_log)

        fused = fusion_hyps_trust_substituion_insertion(
            hyps_whisper, hyps_zipformer, refs
        )

        # File names look like "recogs-<partition>-<suffix>.txt"; the
        # partition is the second "-"-separated field of the base name.
        base_name = whisper_log.split("/")[-1]
        partition_name = base_name.split("-")[1]

        save_results(
            output_dir,
            partition_name,
            {"fusion": list(zip(wav_name, refs, fused))},
        )
        print(f"Processed {partition_name}")


if __name__ == "__main__":
    main()

View File

@ -34,9 +34,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
)
from lhotse.dataset.input_strategies import AudioSamples # noqa F401 For AudioSamples
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

View File

@ -53,14 +53,14 @@ import k2
import torch
import torch.nn as nn
import whisper
from asr_datamodule import AsrDataModule
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert
from lhotse.cut import Cut
from multi_dataset import MultiDataset
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (

View File

@ -53,9 +53,7 @@ class MultiDataset:
path = f"{prefix}_cuts_{partition}.{suffix}"
logging.info(f"Loading {path} set in lazy mode")
test_cuts = load_manifest_lazy(
self.fbank_dir / path
)
test_cuts = load_manifest_lazy(self.fbank_dir / path)
results_dict[partition] = test_cuts
return results_dict

View File

@ -303,6 +303,17 @@ def get_parser():
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
add_model_arguments(parser)
return parser
@ -431,6 +442,7 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -455,6 +467,7 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
blank_penalty=params.blank_penalty,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
@ -468,8 +481,9 @@ def decode_one_batch(
)
hyps.append(sp.decode(hyp).split())
key = f"blank_penalty_{params.blank_penalty}"
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
return {"greedy_search_" + key: hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
@ -657,6 +671,7 @@ def main():
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-blank-penalty-{params.blank_penalty}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../multi_zh-hans/ASR/zipformer/train.py

View File

@ -16,11 +16,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
import argparse
import torch
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
from lhotse import (
CutSet,
KaldifeatFbank,
KaldifeatFbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
# Torch's multithreaded behavior needs to be disabled, or
# it wastes a lot of CPU and slows things down.
@ -32,6 +40,7 @@ torch.multiprocessing.set_sharing_strategy("file_system")
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -52,6 +61,7 @@ def get_parser():
)
return parser
def compute_fbank_wenetspeech_dev_test(args):
in_out_dir = Path("data/fbank")
# number of workers in dataloader
@ -66,7 +76,9 @@ def compute_fbank_wenetspeech_dev_test(args):
if torch.cuda.is_available():
device = torch.device("cuda", 0)
if args.whisper_fbank:
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
)
else:
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))

View File

@ -22,20 +22,19 @@ from datetime import datetime
from pathlib import Path
import torch
from lhotse import (
from lhotse import ( # KaldifeatWhisperFbank,; KaldifeatWhisperFbankConfig,
CutSet,
WhisperFbank,
WhisperFbankConfig,
# KaldifeatWhisperFbank,
# KaldifeatWhisperFbankConfig,
KaldifeatFbank,
KaldifeatFbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
set_audio_duration_mismatch_tolerance,
set_caching_enabled,
)
from icefall.utils import str2bool, get_executor
from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled, or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
@ -177,13 +176,6 @@ def compute_fbank_wenetspeech_splits(args):
storage_type=LilcomChunkyWriter,
overwrite=True,
)
# cut_set = cut_set.compute_and_store_features(
# extractor=extractor,
# storage_path=f"{output_dir}/feats_{subset}_{idx}",
# num_jobs=args.num_workers,
# executor=ex,
# storage_type=LilcomChunkyWriter,
# )
logging.info(f"Saving to {cuts_path}")
cut_set.to_file(cuts_path)

View File

@ -6,8 +6,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=15
stage=131
stop_stage=131
stage=0
stop_stage=100
# Split the L subset into this number of pieces.
# This is to avoid OOM during feature extraction.

View File

@ -74,10 +74,10 @@ It will generate the following 3 files inside $repo/exp:
import argparse
import logging
from icefall import is_module_available
import torch
from onnx_pretrained import OnnxModel
import torch
from icefall import is_module_available
def get_parser():

View File

@ -52,13 +52,13 @@ import k2
import torch
import torch.nn as nn
import whisper
from asr_datamodule import WenetSpeechAsrDataModule
from lhotse.cut import Cut
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert
from lhotse.cut import Cut
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (

View File

@ -834,6 +834,7 @@ def run(rank, world_size, args):
# )
return False
return True
train_cuts = wenetspeech.train_cuts()
train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_dl = wenetspeech.train_dataloaders(train_cuts)