mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
clean infer codes
This commit is contained in:
parent
3ba6febe4f
commit
03d500a414
@ -1,45 +1,81 @@
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# import bigvan
|
||||
# sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from accelerate import Accelerator
|
||||
|
||||
# from importlib.resources import files
|
||||
# import sys
|
||||
# sys.path.append(f"/home/yuekaiz/BigVGAN/")
|
||||
# from bigvgan import BigVGAN
|
||||
from bigvganinference import BigVGANInference
|
||||
|
||||
# from f5_tts.eval.utils_eval import (
|
||||
# get_inference_prompt,
|
||||
# get_librispeech_test_clean_metainfo,
|
||||
# get_seedtts_testset_metainfo,
|
||||
# )
|
||||
# from f5_tts.infer.utils_infer import load_vocoder
|
||||
from model.cfm import CFM
|
||||
from model.dit import DiT
|
||||
from model.modules import MelSpec
|
||||
from model.utils import convert_char_to_pinyin
|
||||
from tqdm import tqdm
|
||||
from train import get_tokenizer, load_pretrained_checkpoint
|
||||
from train import (
|
||||
add_model_arguments,
|
||||
get_model,
|
||||
get_tokenizer,
|
||||
load_F5_TTS_pretrained_checkpoint,
|
||||
)
|
||||
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
|
||||
|
||||
def load_vocoder(device):
|
||||
# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x
|
||||
model = BigVGANInference.from_pretrained(
|
||||
"./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
model = model.eval().to(device)
|
||||
return model
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
default="f5-tts/vocab.txt",
|
||||
help="Path to the unique text tokens file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt",
|
||||
help="Path to the unique text tokens file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=0,
|
||||
help="The seed for random generators intended for reproducibility",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nfe",
|
||||
type=int,
|
||||
default=16,
|
||||
help="The number of steps for the neural ODE",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--manifest-file",
|
||||
type=str,
|
||||
default="/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst",
|
||||
help="The manifest file in seed_tts_eval format",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default="results",
|
||||
help="The output directory to save the generated wavs",
|
||||
)
|
||||
|
||||
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
|
||||
add_model_arguments(parser)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def get_inference_prompt(
|
||||
@ -52,7 +88,7 @@ def get_inference_prompt(
|
||||
win_length=1024,
|
||||
n_mel_channels=100,
|
||||
hop_length=256,
|
||||
mel_spec_type="vocos",
|
||||
mel_spec_type="bigvgan",
|
||||
target_rms=0.1,
|
||||
use_truth_duration=False,
|
||||
infer_batch_size=1,
|
||||
@ -209,151 +245,54 @@ def get_seedtts_testset_metainfo(metalst):
|
||||
f.close()
|
||||
metainfo = []
|
||||
for line in lines:
|
||||
if len(line.strip().split("|")) == 5:
|
||||
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
|
||||
elif len(line.strip().split("|")) == 4:
|
||||
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
||||
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
|
||||
assert len(line.strip().split("|")) == 4
|
||||
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
||||
utt = Path(utt).stem
|
||||
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
|
||||
if not os.path.isabs(prompt_wav):
|
||||
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
|
||||
metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
|
||||
return metainfo
|
||||
|
||||
|
||||
accelerator = Accelerator()
|
||||
device = f"cuda:{accelerator.process_index}"
|
||||
|
||||
|
||||
# --------------------- Dataset Settings -------------------- #
|
||||
|
||||
target_sample_rate = 24000
|
||||
n_mel_channels = 100
|
||||
hop_length = 256
|
||||
win_length = 1024
|
||||
n_fft = 1024
|
||||
target_rms = 0.1
|
||||
|
||||
# rel_path = str(files("f5_tts").joinpath("../../"))
|
||||
|
||||
|
||||
def main():
|
||||
# ---------------------- infer setting ---------------------- #
|
||||
args = get_parser()
|
||||
|
||||
parser = argparse.ArgumentParser(description="batch inference")
|
||||
|
||||
parser.add_argument("-s", "--seed", default=None, type=int)
|
||||
parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
|
||||
parser.add_argument("-n", "--expname", required=True)
|
||||
parser.add_argument("-c", "--ckptstep", default=15000, type=int)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--mel_spec_type",
|
||||
default="bigvgan",
|
||||
type=str,
|
||||
choices=["bigvgan", "vocos"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"]
|
||||
)
|
||||
|
||||
parser.add_argument("-nfe", "--nfestep", default=32, type=int)
|
||||
parser.add_argument("-o", "--odemethod", default="euler")
|
||||
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
|
||||
|
||||
parser.add_argument("-t", "--testset", required=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
seed = args.seed
|
||||
dataset_name = args.dataset
|
||||
exp_name = args.expname
|
||||
ckpt_step = args.ckptstep
|
||||
|
||||
ckpt_path = "/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"
|
||||
ckpt_path = "/home/yuekaiz/icefall_matcha/egs/wenetspeech4tts/TTS/exp/f5/checkpoint-15000.pt"
|
||||
|
||||
mel_spec_type = args.mel_spec_type
|
||||
tokenizer = args.tokenizer
|
||||
|
||||
nfe_step = args.nfestep
|
||||
ode_method = args.odemethod
|
||||
sway_sampling_coef = args.swaysampling
|
||||
|
||||
testset = args.testset
|
||||
|
||||
infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
|
||||
cfg_strength = 2.0
|
||||
speed = 1.0
|
||||
use_truth_duration = False
|
||||
no_ref_audio = False
|
||||
|
||||
model_cls = DiT
|
||||
model_cfg = dict(
|
||||
dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
|
||||
)
|
||||
metalst = "/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst"
|
||||
metainfo = get_seedtts_testset_metainfo(metalst)
|
||||
|
||||
# path to save genereted wavs
|
||||
output_dir = (
|
||||
f"./"
|
||||
f"results/{exp_name}_{ckpt_step}/{testset}/"
|
||||
f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
|
||||
f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
|
||||
f"_cfg{cfg_strength}_speed{speed}"
|
||||
f"{'_gt-dur' if use_truth_duration else ''}"
|
||||
f"{'_no-ref-audio' if no_ref_audio else ''}"
|
||||
)
|
||||
accelerator = Accelerator()
|
||||
device = f"cuda:{accelerator.process_index}"
|
||||
|
||||
metainfo = get_seedtts_testset_metainfo(args.manifest_file)
|
||||
prompts_all = get_inference_prompt(
|
||||
metainfo,
|
||||
speed=speed,
|
||||
tokenizer=tokenizer,
|
||||
target_sample_rate=target_sample_rate,
|
||||
n_mel_channels=n_mel_channels,
|
||||
hop_length=hop_length,
|
||||
mel_spec_type=mel_spec_type,
|
||||
target_rms=target_rms,
|
||||
use_truth_duration=use_truth_duration,
|
||||
infer_batch_size=infer_batch_size,
|
||||
speed=1.0,
|
||||
tokenizer="pinyin",
|
||||
target_sample_rate=24_000,
|
||||
n_mel_channels=100,
|
||||
hop_length=256,
|
||||
mel_spec_type="bigvgan",
|
||||
target_rms=0.1,
|
||||
use_truth_duration=False,
|
||||
infer_batch_size=1,
|
||||
)
|
||||
|
||||
vocoder = load_vocoder(device)
|
||||
|
||||
# Tokenizer
|
||||
vocab_char_map, vocab_size = get_tokenizer("./f5-tts/vocab.txt")
|
||||
|
||||
# Model
|
||||
model = CFM(
|
||||
transformer=model_cls(
|
||||
**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
|
||||
),
|
||||
mel_spec_kwargs=dict(
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
n_mel_channels=n_mel_channels,
|
||||
target_sample_rate=target_sample_rate,
|
||||
mel_spec_type=mel_spec_type,
|
||||
),
|
||||
odeint_kwargs=dict(
|
||||
method=ode_method,
|
||||
),
|
||||
vocab_char_map=vocab_char_map,
|
||||
).to(device)
|
||||
|
||||
dtype = torch.float32 if mel_spec_type == "bigvgan" else None
|
||||
# model = load_pretrained_checkpoint(model, ckpt_path)
|
||||
_ = load_checkpoint(
|
||||
ckpt_path,
|
||||
model=model,
|
||||
vocoder = BigVGANInference.from_pretrained(
|
||||
"./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
|
||||
)
|
||||
model = model.eval().to(device)
|
||||
vocoder = vocoder.eval().to(device)
|
||||
|
||||
if not os.path.exists(output_dir) and accelerator.is_main_process:
|
||||
os.makedirs(output_dir)
|
||||
model = get_model(args).eval().to(device)
|
||||
checkpoint = torch.load(args.model_path, map_location="cpu", weights_only=True)
|
||||
|
||||
if "model_state_dict" or "ema_model_state_dict" in checkpoint:
|
||||
model = load_F5_TTS_pretrained_checkpoint(model, args.model_path)
|
||||
else:
|
||||
_ = load_checkpoint(
|
||||
args.model_path,
|
||||
model=model,
|
||||
)
|
||||
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# start batch inference
|
||||
accelerator.wait_for_everyone()
|
||||
start = time.time()
|
||||
|
||||
@ -378,25 +317,23 @@ def main():
|
||||
text=final_text_list,
|
||||
duration=total_mel_lens,
|
||||
lens=ref_mel_lens,
|
||||
steps=nfe_step,
|
||||
cfg_strength=cfg_strength,
|
||||
sway_sampling_coef=sway_sampling_coef,
|
||||
no_ref_audio=no_ref_audio,
|
||||
seed=seed,
|
||||
steps=args.nfe,
|
||||
cfg_strength=2.0,
|
||||
sway_sampling_coef=args.swaysampling,
|
||||
no_ref_audio=False,
|
||||
seed=args.seed,
|
||||
)
|
||||
# Final result
|
||||
for i, gen in enumerate(generated):
|
||||
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
|
||||
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
|
||||
if mel_spec_type == "vocos":
|
||||
generated_wave = vocoder.decode(gen_mel_spec).cpu()
|
||||
elif mel_spec_type == "bigvgan":
|
||||
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
|
||||
|
||||
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
|
||||
target_rms = 0.1
|
||||
target_sample_rate = 24_000
|
||||
if ref_rms_list[i] < target_rms:
|
||||
generated_wave = generated_wave * ref_rms_list[i] / target_rms
|
||||
torchaudio.save(
|
||||
f"{output_dir}/{utts[i]}.wav",
|
||||
f"{args.output_dir}/{utts[i]}.wav",
|
||||
generated_wave,
|
||||
target_sample_rate,
|
||||
)
|
||||
@ -408,4 +345,6 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -409,7 +409,7 @@ def get_model(params):
|
||||
return model
|
||||
|
||||
|
||||
def load_pretrained_checkpoint(
|
||||
def load_F5_TTS_pretrained_checkpoint(
|
||||
model, ckpt_path, device: str = "cpu", dtype=torch.float32
|
||||
):
|
||||
# model = model.to(dtype)
|
||||
@ -937,7 +937,7 @@ def run(rank, world_size, args):
|
||||
logging.info("About to create model")
|
||||
|
||||
model = get_model(params)
|
||||
# model = load_pretrained_checkpoint(model, params.pretrained_model_path)
|
||||
# model = load_F5_TTS_pretrained_checkpoint(model, params.pretrained_model_path)
|
||||
model = model.to(device)
|
||||
|
||||
with open(f"{params.exp_dir}/model.txt", "w") as f:
|
||||
|
@ -1,3 +1,15 @@
|
||||
export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
|
||||
#bigvganinference
|
||||
model_path=/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt
|
||||
manifest=/home/yuekaiz/HF/valle_wenetspeech4tts_demo/wenetspeech4tts.txt
|
||||
manifest=/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst
|
||||
# get wenetspeech4tts
|
||||
manifest_base_stem=$(basename $manifest)
|
||||
mainfest_base_stem=${manifest_base_stem%.*}
|
||||
output_dir=./results/f5-tts-pretrained/$mainfest_base_stem
|
||||
|
||||
accelerate launch f5-tts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
|
||||
|
||||
pip install sherpa-onnx bigvganinference lhotse kaldialign sentencepiece
|
||||
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir || exit 1
|
||||
|
||||
bash local/compute_wer.sh $output_dir $manifest
|
||||
|
27
egs/wenetspeech4tts/TTS/local/compute_wer.sh
Normal file
27
egs/wenetspeech4tts/TTS/local/compute_wer.sh
Normal file
@ -0,0 +1,27 @@
|
||||
wav_dir=$1
|
||||
wav_files=$(ls $wav_dir/*.wav)
|
||||
# wav_files=$(echo $wav_files | cut -d " " -f 1)
|
||||
# if wav_files is empty, then exit
|
||||
if [ -z "$wav_files" ]; then
|
||||
exit 1
|
||||
fi
|
||||
label_file=$2
|
||||
model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
|
||||
|
||||
if [ ! -d $model_path ]; then
|
||||
pip install sherpa-onnx
|
||||
wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
|
||||
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
|
||||
fi
|
||||
|
||||
python3 local/offline-decode-files.py \
|
||||
--tokens=$model_path/tokens.txt \
|
||||
--paraformer=$model_path/model.int8.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding-method=greedy_search \
|
||||
--debug=false \
|
||||
--sample-rate=24000 \
|
||||
--log-dir $wav_dir \
|
||||
--feature-dim=80 \
|
||||
--label $label_file \
|
||||
$wav_files
|
495
egs/wenetspeech4tts/TTS/local/offline-decode-files.py
Executable file
495
egs/wenetspeech4tts/TTS/local/offline-decode-files.py
Executable file
@ -0,0 +1,495 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright (c) 2023 by manyeyes
|
||||
# Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
"""
|
||||
This file demonstrates how to use sherpa-onnx Python API to transcribe
|
||||
file(s) with a non-streaming model.
|
||||
|
||||
(1) For paraformer
|
||||
|
||||
./python-api-examples/offline-decode-files.py \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--paraformer=/path/to/paraformer.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding-method=greedy_search \
|
||||
--debug=false \
|
||||
--sample-rate=16000 \
|
||||
--feature-dim=80 \
|
||||
/path/to/0.wav \
|
||||
/path/to/1.wav
|
||||
|
||||
(2) For transducer models from icefall
|
||||
|
||||
./python-api-examples/offline-decode-files.py \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--encoder=/path/to/encoder.onnx \
|
||||
--decoder=/path/to/decoder.onnx \
|
||||
--joiner=/path/to/joiner.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding-method=greedy_search \
|
||||
--debug=false \
|
||||
--sample-rate=16000 \
|
||||
--feature-dim=80 \
|
||||
/path/to/0.wav \
|
||||
/path/to/1.wav
|
||||
|
||||
(3) For CTC models from NeMo
|
||||
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
--tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
|
||||
--nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding-method=greedy_search \
|
||||
--debug=false \
|
||||
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
|
||||
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
|
||||
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
|
||||
|
||||
(4) For Whisper models
|
||||
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
|
||||
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
|
||||
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
|
||||
--whisper-task=transcribe \
|
||||
--num-threads=1 \
|
||||
./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
|
||||
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
|
||||
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
|
||||
|
||||
(5) For CTC models from WeNet
|
||||
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
|
||||
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
|
||||
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
|
||||
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
|
||||
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
|
||||
|
||||
(6) For tdnn models of the yesno recipe from icefall
|
||||
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
--sample-rate=8000 \
|
||||
--feature-dim=23 \
|
||||
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
|
||||
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
|
||||
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
|
||||
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
|
||||
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/index.html
|
||||
to install sherpa-onnx and to download non-streaming pre-trained models
|
||||
used in this file.
|
||||
"""
|
||||
import argparse
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import sherpa_onnx
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="Path to tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hotwords-file",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The file containing hotwords, one words/phrases per line, like
|
||||
HELLO WORLD
|
||||
你好世界
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hotwords-score",
|
||||
type=float,
|
||||
default=1.5,
|
||||
help="""
|
||||
The hotword score of each token for biasing word/phrase. Used only if
|
||||
--hotwords-file is given.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--modeling-unit",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
|
||||
Used only when hotwords-file is given.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-vocab",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The path to the bpe vocabulary, the bpe vocabulary is generated by
|
||||
sentencepiece, you can also export the bpe vocabulary through a bpe model
|
||||
by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
|
||||
and modeling-unit is bpe or cjkchar+bpe.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the joiner model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--paraformer",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the model.onnx from Paraformer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nemo-ctc",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the model.onnx from NeMo CTC",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--wenet-ctc",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the model.onnx from WeNet CTC",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tdnn-model",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the model.onnx for the tdnn model of the yesno recipe",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-threads",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of threads for neural network computation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-encoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to whisper encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-decoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to whisper decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-language",
|
||||
default="",
|
||||
type=str,
|
||||
help="""It specifies the spoken language in the input audio file.
|
||||
Example values: en, fr, de, zh, jp.
|
||||
Available languages for multilingual models can be found at
|
||||
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
|
||||
If not specified, we infer the language from the input audio file.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-task",
|
||||
default="transcribe",
|
||||
choices=["transcribe", "translate"],
|
||||
type=str,
|
||||
help="""For multilingual models, if you specify translate, the output
|
||||
will be in English.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-tail-paddings",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="""Number of tail padding frames.
|
||||
We have removed the 30-second constraint from whisper, so you need to
|
||||
choose the amount of tail padding frames by yourself.
|
||||
Use -1 to use a default value for tail padding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--blank-penalty",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="""
|
||||
The penalty applied on blank symbol during decoding.
|
||||
Note: It is a positive value that would be applied to logits like
|
||||
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||
[batch_size, vocab] and blank id is 0).
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="Valid values are greedy_search and modified_beam_search",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="True to show debug messages",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="""Sample rate of the feature extractor. Must match the one
|
||||
expected by the model. Note: The input sound files can have a
|
||||
different sample rate from this argument.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--feature-dim",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Feature dimension. Must match the one expected by the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to decode. Each file must be of WAVE"
|
||||
"format with a single channel, and each sample has 16-bit, "
|
||||
"i.e., int16_t. "
|
||||
"The sample rate of the file can be arbitrary and does not need to "
|
||||
"be 16 kHz",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--name",
|
||||
type=str,
|
||||
default="",
|
||||
help="The directory containing the input sound files to decode",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-dir",
|
||||
type=str,
|
||||
default="",
|
||||
help="The directory containing the input sound files to decode",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--label",
|
||||
type=str,
|
||||
default=None,
|
||||
help="wav_base_name label",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def assert_file_exists(filename: str):
|
||||
assert Path(filename).is_file(), (
|
||||
f"{filename} does not exist!\n"
|
||||
"Please refer to "
|
||||
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
|
||||
)
|
||||
|
||||
|
||||
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
||||
"""
|
||||
Args:
|
||||
wave_filename:
|
||||
Path to a wave file. It should be single channel and can be of type
|
||||
32-bit floating point PCM. Its sample rate does not need to be 24kHz.
|
||||
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- A 1-D array of dtype np.float32 containing the samples,
|
||||
which are normalized to the range [-1, 1].
|
||||
- Sample rate of the wave file.
|
||||
"""
|
||||
|
||||
samples, sample_rate = sf.read(wave_filename, dtype="float32")
|
||||
assert (
|
||||
samples.ndim == 1
|
||||
), f"Expected single channel, but got {samples.ndim} channels."
|
||||
|
||||
samples_float32 = samples.astype(np.float32)
|
||||
|
||||
return samples_float32, sample_rate
|
||||
|
||||
|
||||
def normalize_text_alimeeting(text: str) -> str:
|
||||
"""
|
||||
Text normalization similar to M2MeT challenge baseline.
|
||||
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
|
||||
"""
|
||||
import re
|
||||
|
||||
text = text.replace(" ", "")
|
||||
text = text.replace("<sil>", "")
|
||||
text = text.replace("<%>", "")
|
||||
text = text.replace("<->", "")
|
||||
text = text.replace("<$>", "")
|
||||
text = text.replace("<#>", "")
|
||||
text = text.replace("<_>", "")
|
||||
text = text.replace("<space>", "")
|
||||
text = text.replace("`", "")
|
||||
text = text.replace("&", "")
|
||||
text = text.replace(",", "")
|
||||
if re.search("[a-zA-Z]", text):
|
||||
text = text.upper()
|
||||
text = text.replace("A", "A")
|
||||
text = text.replace("a", "A")
|
||||
text = text.replace("b", "B")
|
||||
text = text.replace("c", "C")
|
||||
text = text.replace("k", "K")
|
||||
text = text.replace("t", "T")
|
||||
text = text.replace(",", "")
|
||||
text = text.replace("丶", "")
|
||||
text = text.replace("。", "")
|
||||
text = text.replace("、", "")
|
||||
text = text.replace("?", "")
|
||||
return text
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
assert_file_exists(args.tokens)
|
||||
assert args.num_threads > 0, args.num_threads
|
||||
|
||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
|
||||
assert_file_exists(args.paraformer)
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
|
||||
paraformer=args.paraformer,
|
||||
tokens=args.tokens,
|
||||
num_threads=args.num_threads,
|
||||
sample_rate=args.sample_rate,
|
||||
feature_dim=args.feature_dim,
|
||||
decoding_method=args.decoding_method,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
print("Started!")
|
||||
start_time = time.time()
|
||||
|
||||
streams, results = [], []
|
||||
total_duration = 0
|
||||
|
||||
for i, wave_filename in enumerate(args.sound_files):
|
||||
assert_file_exists(wave_filename)
|
||||
samples, sample_rate = read_wave(wave_filename)
|
||||
duration = len(samples) / sample_rate
|
||||
total_duration += duration
|
||||
s = recognizer.create_stream()
|
||||
s.accept_waveform(sample_rate, samples)
|
||||
|
||||
streams.append(s)
|
||||
if i % 10 == 0:
|
||||
recognizer.decode_streams(streams)
|
||||
results += [s.result.text for s in streams]
|
||||
streams = []
|
||||
print(f"Processed {i} files")
|
||||
# process the last batch
|
||||
if streams:
|
||||
recognizer.decode_streams(streams)
|
||||
results += [s.result.text for s in streams]
|
||||
end_time = time.time()
|
||||
print("Done!")
|
||||
|
||||
results_dict = {}
|
||||
for wave_filename, result in zip(args.sound_files, results):
|
||||
print(f"{wave_filename}\n{result}")
|
||||
print("-" * 10)
|
||||
wave_basename = Path(wave_filename).stem
|
||||
results_dict[wave_basename] = result
|
||||
|
||||
elapsed_seconds = end_time - start_time
|
||||
rtf = elapsed_seconds / total_duration
|
||||
print(f"num_threads: {args.num_threads}")
|
||||
print(f"decoding_method: {args.decoding_method}")
|
||||
print(f"Wave duration: {total_duration:.3f} s")
|
||||
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||
print(
|
||||
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
||||
)
|
||||
if args.label:
|
||||
from icefall.utils import store_transcripts, write_error_stats
|
||||
|
||||
labels_dict = {}
|
||||
with open(args.label, "r") as f:
|
||||
for line in f:
|
||||
# fields = line.strip().split(" ")
|
||||
# fields = [item for item in fields if item]
|
||||
# assert len(fields) == 4
|
||||
# prompt_text, prompt_audio, text, audio_path = fields
|
||||
|
||||
fields = line.strip().split("|")
|
||||
fields = [item for item in fields if item]
|
||||
assert len(fields) == 4
|
||||
audio_path, prompt_text, prompt_audio, text = fields
|
||||
labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text)
|
||||
|
||||
final_results = []
|
||||
for key, value in results_dict.items():
|
||||
final_results.append((key, labels_dict[key], value))
|
||||
|
||||
store_transcripts(
|
||||
filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results
|
||||
)
|
||||
with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f:
|
||||
write_error_stats(f, "test-set", final_results, enable_log=True)
|
||||
|
||||
with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f:
|
||||
print(f.readline()) # WER
|
||||
print(f.readline()) # Detailed errors
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user