mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-09 14:05:33 +00:00)

commit 03d500a414 (parent 3ba6febe4f): clean infer codes
@@ -1,45 +1,81 @@
 import argparse
+import logging
 import math
 import os
 import random
 import time
+from pathlib import Path
 
-# import bigvan
-# sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
 import torch
 import torch.nn.functional as F
 import torchaudio
 from accelerate import Accelerator
 
-# from importlib.resources import files
-# import sys
-# sys.path.append(f"/home/yuekaiz/BigVGAN/")
-# from bigvgan import BigVGAN
 from bigvganinference import BigVGANInference
 
-# from f5_tts.eval.utils_eval import (
-#     get_inference_prompt,
-#     get_librispeech_test_clean_metainfo,
-#     get_seedtts_testset_metainfo,
-# )
-# from f5_tts.infer.utils_infer import load_vocoder
 from model.cfm import CFM
 from model.dit import DiT
 from model.modules import MelSpec
 from model.utils import convert_char_to_pinyin
 from tqdm import tqdm
-from train import get_tokenizer, load_pretrained_checkpoint
+from train import (
+    add_model_arguments,
+    get_model,
+    get_tokenizer,
+    load_F5_TTS_pretrained_checkpoint,
+)
 
 from icefall.checkpoint import load_checkpoint
 
 
-def load_vocoder(device):
-    # huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x
-    model = BigVGANInference.from_pretrained(
-        "./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
-    model = model.eval().to(device)
-    return model
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        default="f5-tts/vocab.txt",
+        help="Path to the unique text tokens file",
+    )
+
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        default="/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt",
+        help="Path to the model checkpoint",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=0,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--nfe",
+        type=int,
+        default=16,
+        help="The number of steps for the neural ODE",
+    )
+
+    parser.add_argument(
+        "--manifest-file",
+        type=str,
+        default="/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst",
+        help="The manifest file in seed_tts_eval format",
+    )
+
+    parser.add_argument(
+        "--output-dir",
+        type=Path,
+        default="results",
+        help="The output directory to save the generated wavs",
+    )
+
+    parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
+    add_model_arguments(parser)
+    return parser.parse_args()
 
 
 def get_inference_prompt(
@@ -52,7 +88,7 @@ def get_inference_prompt(
     win_length=1024,
     n_mel_channels=100,
     hop_length=256,
-    mel_spec_type="vocos",
+    mel_spec_type="bigvgan",
     target_rms=0.1,
     use_truth_duration=False,
     infer_batch_size=1,
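For reference, the mel settings above match the bigvgan_v2_24khz_100band_256x vocoder this recipe downloads (24 kHz audio, 100 mel bands, 256-sample hop); a minimal sketch collecting the values that appear in this diff:

# Mel/vocoder configuration used throughout the recipe (values from this diff).
mel_spec_kwargs = dict(
    n_fft=1024,
    hop_length=256,  # 256x hop, per the vocoder name
    win_length=1024,
    n_mel_channels=100,  # 100 mel bands
    target_sample_rate=24_000,  # 24 kHz
    mel_spec_type="bigvgan",
)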
@@ -209,151 +245,54 @@ def get_seedtts_testset_metainfo(metalst):
     f.close()
     metainfo = []
     for line in lines:
-        if len(line.strip().split("|")) == 5:
-            utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split("|")
-        elif len(line.strip().split("|")) == 4:
-            utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
-            gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
+        assert len(line.strip().split("|")) == 4
+        utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
+        utt = Path(utt).stem
+        gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
         if not os.path.isabs(prompt_wav):
             prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
         metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
     return metainfo
 
 
-accelerator = Accelerator()
-device = f"cuda:{accelerator.process_index}"
-
-
-# --------------------- Dataset Settings -------------------- #
-
-target_sample_rate = 24000
-n_mel_channels = 100
-hop_length = 256
-win_length = 1024
-n_fft = 1024
-target_rms = 0.1
-
-# rel_path = str(files("f5_tts").joinpath("../../"))
-
-
 def main():
-    # ---------------------- infer setting ---------------------- #
-
-    parser = argparse.ArgumentParser(description="batch inference")
-
-    parser.add_argument("-s", "--seed", default=None, type=int)
-    parser.add_argument("-d", "--dataset", default="Emilia_ZH_EN")
-    parser.add_argument("-n", "--expname", required=True)
-    parser.add_argument("-c", "--ckptstep", default=15000, type=int)
-    parser.add_argument(
-        "-m",
-        "--mel_spec_type",
-        default="bigvgan",
-        type=str,
-        choices=["bigvgan", "vocos"],
-    )
-    parser.add_argument(
-        "-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"]
-    )
-
-    parser.add_argument("-nfe", "--nfestep", default=32, type=int)
-    parser.add_argument("-o", "--odemethod", default="euler")
-    parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
-
-    parser.add_argument("-t", "--testset", required=True)
-
-    args = parser.parse_args()
-
-    seed = args.seed
-    dataset_name = args.dataset
-    exp_name = args.expname
-    ckpt_step = args.ckptstep
-
-    ckpt_path = "/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"
-    ckpt_path = "/home/yuekaiz/icefall_matcha/egs/wenetspeech4tts/TTS/exp/f5/checkpoint-15000.pt"
-
-    mel_spec_type = args.mel_spec_type
-    tokenizer = args.tokenizer
-
-    nfe_step = args.nfestep
-    ode_method = args.odemethod
-    sway_sampling_coef = args.swaysampling
-
-    testset = args.testset
-
-    infer_batch_size = 1  # max frames. 1 for ddp single inference (recommended)
-    cfg_strength = 2.0
-    speed = 1.0
-    use_truth_duration = False
-    no_ref_audio = False
-
-    model_cls = DiT
-    model_cfg = dict(
-        dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
-    )
-    metalst = "/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst"
-    metainfo = get_seedtts_testset_metainfo(metalst)
-
-    # path to save genereted wavs
-    output_dir = (
-        f"./"
-        f"results/{exp_name}_{ckpt_step}/{testset}/"
-        f"seed{seed}_{ode_method}_nfe{nfe_step}_{mel_spec_type}"
-        f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
-        f"_cfg{cfg_strength}_speed{speed}"
-        f"{'_gt-dur' if use_truth_duration else ''}"
-        f"{'_no-ref-audio' if no_ref_audio else ''}"
-    )
-
+    args = get_parser()
+    accelerator = Accelerator()
+    device = f"cuda:{accelerator.process_index}"
+
+    metainfo = get_seedtts_testset_metainfo(args.manifest_file)
     prompts_all = get_inference_prompt(
         metainfo,
-        speed=speed,
-        tokenizer=tokenizer,
-        target_sample_rate=target_sample_rate,
-        n_mel_channels=n_mel_channels,
-        hop_length=hop_length,
-        mel_spec_type=mel_spec_type,
-        target_rms=target_rms,
-        use_truth_duration=use_truth_duration,
-        infer_batch_size=infer_batch_size,
+        speed=1.0,
+        tokenizer="pinyin",
+        target_sample_rate=24_000,
+        n_mel_channels=100,
+        hop_length=256,
+        mel_spec_type="bigvgan",
+        target_rms=0.1,
+        use_truth_duration=False,
+        infer_batch_size=1,
     )
 
-    vocoder = load_vocoder(device)
-
-    # Tokenizer
-    vocab_char_map, vocab_size = get_tokenizer("./f5-tts/vocab.txt")
-
-    # Model
-    model = CFM(
-        transformer=model_cls(
-            **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
-        ),
-        mel_spec_kwargs=dict(
-            n_fft=n_fft,
-            hop_length=hop_length,
-            win_length=win_length,
-            n_mel_channels=n_mel_channels,
-            target_sample_rate=target_sample_rate,
-            mel_spec_type=mel_spec_type,
-        ),
-        odeint_kwargs=dict(
-            method=ode_method,
-        ),
-        vocab_char_map=vocab_char_map,
-    ).to(device)
-
-    dtype = torch.float32 if mel_spec_type == "bigvgan" else None
-    # model = load_pretrained_checkpoint(model, ckpt_path)
-    _ = load_checkpoint(
-        ckpt_path,
-        model=model,
+    vocoder = BigVGANInference.from_pretrained(
+        "./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
     )
-    model = model.eval().to(device)
+    vocoder = vocoder.eval().to(device)
 
-    if not os.path.exists(output_dir) and accelerator.is_main_process:
-        os.makedirs(output_dir)
+    model = get_model(args).eval().to(device)
+    checkpoint = torch.load(args.model_path, map_location="cpu", weights_only=True)
+    if "model_state_dict" in checkpoint or "ema_model_state_dict" in checkpoint:
+        model = load_F5_TTS_pretrained_checkpoint(model, args.model_path)
+    else:
+        _ = load_checkpoint(
+            args.model_path,
+            model=model,
+        )
+
+    os.makedirs(args.output_dir, exist_ok=True)
 
-    # start batch inference
     accelerator.wait_for_everyone()
     start = time.time()
 
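The rewritten get_seedtts_testset_metainfo now insists on exactly four "|"-separated fields per manifest line; a sketch of one entry, with made-up values (only the field order comes from the code above):

# Illustrative seed_tts_eval manifest line (values are hypothetical):
line = "utt_0001|prompt transcript|prompts/utt_0001.wav|target transcript"
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
# The ground-truth wav is resolved as <dirname(manifest)>/wavs/<utt>.wav,
# and a relative prompt_wav is resolved against the manifest directory too.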
@@ -378,25 +317,23 @@ def main():
                 text=final_text_list,
                 duration=total_mel_lens,
                 lens=ref_mel_lens,
-                steps=nfe_step,
-                cfg_strength=cfg_strength,
-                sway_sampling_coef=sway_sampling_coef,
-                no_ref_audio=no_ref_audio,
-                seed=seed,
+                steps=args.nfe,
+                cfg_strength=2.0,
+                sway_sampling_coef=args.swaysampling,
+                no_ref_audio=False,
+                seed=args.seed,
             )
-            # Final result
             for i, gen in enumerate(generated):
                 gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
                 gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
-                if mel_spec_type == "vocos":
-                    generated_wave = vocoder.decode(gen_mel_spec).cpu()
-                elif mel_spec_type == "bigvgan":
-                    generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
 
+                generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
+                target_rms = 0.1
+                target_sample_rate = 24_000
                 if ref_rms_list[i] < target_rms:
                     generated_wave = generated_wave * ref_rms_list[i] / target_rms
                 torchaudio.save(
-                    f"{output_dir}/{utts[i]}.wav",
+                    f"{args.output_dir}/{utts[i]}.wav",
                     generated_wave,
                     target_sample_rate,
                 )
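The rescaling above undoes prompt-side loudness normalization: if the reference audio was quieter than target_rms, the generated wave is scaled back down to the reference's original level. A minimal sketch, assuming the entries of ref_rms_list are RMS values computed from the reference audio in get_inference_prompt:

import torch

def restore_reference_loudness(generated_wave, ref_rms, target_rms=0.1):
    # Assumption: ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio))),
    # and references below target_rms were boosted to target_rms before synthesis.
    if ref_rms < target_rms:
        generated_wave = generated_wave * ref_rms / target_rms
    return generated_wave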
@@ -408,4 +345,6 @@ def main():
 
 
 if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
     main()
@@ -409,7 +409,7 @@ def get_model(params):
     return model
 
 
-def load_pretrained_checkpoint(
+def load_F5_TTS_pretrained_checkpoint(
     model, ckpt_path, device: str = "cpu", dtype=torch.float32
 ):
     # model = model.to(dtype)

@@ -937,7 +937,7 @@ def run(rank, world_size, args):
     logging.info("About to create model")
 
     model = get_model(params)
-    # model = load_pretrained_checkpoint(model, params.pretrained_model_path)
+    # model = load_F5_TTS_pretrained_checkpoint(model, params.pretrained_model_path)
     model = model.to(device)
 
     with open(f"{params.exp_dir}/model.txt", "w") as f:
@@ -1,3 +1,15 @@
 export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
+# bigvganinference
+model_path=/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt
+manifest=/home/yuekaiz/HF/valle_wenetspeech4tts_demo/wenetspeech4tts.txt
+manifest=/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst
+# get wenetspeech4tts
+manifest_base_stem=$(basename $manifest)
+manifest_base_stem=${manifest_base_stem%.*}
+output_dir=./results/f5-tts-pretrained/$manifest_base_stem
 
-accelerate launch f5-tts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
+pip install sherpa-onnx bigvganinference lhotse kaldialign sentencepiece
+accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir || exit 1
+
+bash local/compute_wer.sh $output_dir $manifest
egs/wenetspeech4tts/TTS/local/compute_wer.sh  (new file, 27 lines)
@@ -0,0 +1,27 @@
+wav_dir=$1
+wav_files=$(ls $wav_dir/*.wav)
+# wav_files=$(echo $wav_files | cut -d " " -f 1)
+# if wav_files is empty, then exit
+if [ -z "$wav_files" ]; then
+  exit 1
+fi
+label_file=$2
+model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
+
+if [ ! -d $model_path ]; then
+  pip install sherpa-onnx
+  wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
+  tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
+fi
+
+python3 local/offline-decode-files.py \
+  --tokens=$model_path/tokens.txt \
+  --paraformer=$model_path/model.int8.onnx \
+  --num-threads=2 \
+  --decoding-method=greedy_search \
+  --debug=false \
+  --sample-rate=24000 \
+  --log-dir $wav_dir \
+  --feature-dim=80 \
+  --label $label_file \
+  $wav_files
egs/wenetspeech4tts/TTS/local/offline-decode-files.py  (new executable file, 495 lines)
@@ -0,0 +1,495 @@
+#!/usr/bin/env python3
+#
+# Copyright (c)  2023 by manyeyes
+# Copyright (c)  2023  Xiaomi Corporation
+
+"""
+This file demonstrates how to use sherpa-onnx Python API to transcribe
+file(s) with a non-streaming model.
+
+(1) For paraformer
+
+    ./python-api-examples/offline-decode-files.py \
+      --tokens=/path/to/tokens.txt \
+      --paraformer=/path/to/paraformer.onnx \
+      --num-threads=2 \
+      --decoding-method=greedy_search \
+      --debug=false \
+      --sample-rate=16000 \
+      --feature-dim=80 \
+      /path/to/0.wav \
+      /path/to/1.wav
+
+(2) For transducer models from icefall
+
+    ./python-api-examples/offline-decode-files.py \
+      --tokens=/path/to/tokens.txt \
+      --encoder=/path/to/encoder.onnx \
+      --decoder=/path/to/decoder.onnx \
+      --joiner=/path/to/joiner.onnx \
+      --num-threads=2 \
+      --decoding-method=greedy_search \
+      --debug=false \
+      --sample-rate=16000 \
+      --feature-dim=80 \
+      /path/to/0.wav \
+      /path/to/1.wav
+
+(3) For CTC models from NeMo
+
+    python3 ./python-api-examples/offline-decode-files.py \
+      --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
+      --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
+      --num-threads=2 \
+      --decoding-method=greedy_search \
+      --debug=false \
+      ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
+      ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
+      ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
+
+(4) For Whisper models
+
+    python3 ./python-api-examples/offline-decode-files.py \
+      --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
+      --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
+      --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
+      --whisper-task=transcribe \
+      --num-threads=1 \
+      ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
+      ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
+      ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
+
+(5) For CTC models from WeNet
+
+    python3 ./python-api-examples/offline-decode-files.py \
+      --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
+      --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
+      ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
+      ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
+      ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
+
+(6) For tdnn models of the yesno recipe from icefall
+
+    python3 ./python-api-examples/offline-decode-files.py \
+      --sample-rate=8000 \
+      --feature-dim=23 \
+      --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
+      --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
+      ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
+      ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
+      ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
+
+Please refer to
+https://k2-fsa.github.io/sherpa/onnx/index.html
+to install sherpa-onnx and to download non-streaming pre-trained models
+used in this file.
+"""
+import argparse
+import time
+import wave
+from pathlib import Path
+from typing import List, Tuple
+
+import numpy as np
+import sherpa_onnx
+import soundfile as sf
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        help="Path to tokens.txt",
+    )
+
+    parser.add_argument(
+        "--hotwords-file",
+        type=str,
+        default="",
+        help="""
+        The file containing hotwords, one word/phrase per line, like
+        HELLO WORLD
+        你好世界
+        """,
+    )
+
+    parser.add_argument(
+        "--hotwords-score",
+        type=float,
+        default=1.5,
+        help="""
+        The hotword score of each token for biasing word/phrase. Used only if
+        --hotwords-file is given.
+        """,
+    )
+
+    parser.add_argument(
+        "--modeling-unit",
+        type=str,
+        default="",
+        help="""
+        The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
+        Used only when --hotwords-file is given.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-vocab",
+        type=str,
+        default="",
+        help="""
+        The path to the bpe vocabulary. The bpe vocabulary is generated by
+        sentencepiece; you can also export the bpe vocabulary from a bpe model
+        with `scripts/export_bpe_vocab.py`. Used only when --hotwords-file is
+        given and --modeling-unit is bpe or cjkchar+bpe.
+        """,
+    )
+
+    parser.add_argument(
+        "--encoder",
+        default="",
+        type=str,
+        help="Path to the encoder model",
+    )
+
+    parser.add_argument(
+        "--decoder",
+        default="",
+        type=str,
+        help="Path to the decoder model",
+    )
+
+    parser.add_argument(
+        "--joiner",
+        default="",
+        type=str,
+        help="Path to the joiner model",
+    )
+
+    parser.add_argument(
+        "--paraformer",
+        default="",
+        type=str,
+        help="Path to the model.onnx from Paraformer",
+    )
+
+    parser.add_argument(
+        "--nemo-ctc",
+        default="",
+        type=str,
+        help="Path to the model.onnx from NeMo CTC",
+    )
+
+    parser.add_argument(
+        "--wenet-ctc",
+        default="",
+        type=str,
+        help="Path to the model.onnx from WeNet CTC",
+    )
+
+    parser.add_argument(
+        "--tdnn-model",
+        default="",
+        type=str,
+        help="Path to the model.onnx for the tdnn model of the yesno recipe",
+    )
+
+    parser.add_argument(
+        "--num-threads",
+        type=int,
+        default=1,
+        help="Number of threads for neural network computation",
+    )
+
+    parser.add_argument(
+        "--whisper-encoder",
+        default="",
+        type=str,
+        help="Path to whisper encoder model",
+    )
+
+    parser.add_argument(
+        "--whisper-decoder",
+        default="",
+        type=str,
+        help="Path to whisper decoder model",
+    )
+
+    parser.add_argument(
+        "--whisper-language",
+        default="",
+        type=str,
+        help="""It specifies the spoken language in the input audio file.
+        Example values: en, fr, de, zh, jp.
+        Available languages for multilingual models can be found at
+        https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
+        If not specified, we infer the language from the input audio file.
+        """,
+    )
+
+    parser.add_argument(
+        "--whisper-task",
+        default="transcribe",
+        choices=["transcribe", "translate"],
+        type=str,
+        help="""For multilingual models, if you specify translate, the output
+        will be in English.
+        """,
+    )
+
+    parser.add_argument(
+        "--whisper-tail-paddings",
+        default=-1,
+        type=int,
+        help="""Number of tail padding frames.
+        We have removed the 30-second constraint from whisper, so you need to
+        choose the amount of tail padding frames by yourself.
+        Use -1 to use a default value for tail padding.
+        """,
+    )
+
+    parser.add_argument(
+        "--blank-penalty",
+        type=float,
+        default=0.0,
+        help="""
+        The penalty applied on blank symbol during decoding.
+        Note: It is a positive value that would be applied to logits like
+        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
+        [batch_size, vocab] and blank id is 0).
+        """,
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="Valid values are greedy_search and modified_beam_search",
+    )
+    parser.add_argument(
+        "--debug",
+        type=bool,
+        default=False,
+        help="True to show debug messages",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="""Sample rate of the feature extractor. Must match the one
+        expected by the model. Note: The input sound files can have a
+        different sample rate from this argument.""",
+    )
+
+    parser.add_argument(
+        "--feature-dim",
+        type=int,
+        default=80,
+        help="Feature dimension. Must match the one expected by the model",
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to decode. Each file must be of WAVE "
+        "format with a single channel, and each sample has 16-bit, "
+        "i.e., int16_t. "
+        "The sample rate of the file can be arbitrary and does not need to "
+        "be 16 kHz",
+    )
+
+    parser.add_argument(
+        "--name",
+        type=str,
+        default="",
+        help="The name used in the output files recogs-<name>.txt and errs-<name>.txt",
+    )
+
+    parser.add_argument(
+        "--log-dir",
+        type=str,
+        default="",
+        help="The directory to save the recognition results and error stats",
+    )
+
+    parser.add_argument(
+        "--label",
+        type=str,
+        default=None,
+        help="Path to the label file; each line is 'audio_path|prompt_text|prompt_audio|text'",
+    )
+    return parser.parse_args()
+
+
+def assert_file_exists(filename: str):
+    assert Path(filename).is_file(), (
+        f"{filename} does not exist!\n"
+        "Please refer to "
+        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
+    )
+
+
+def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
+    """
+    Args:
+      wave_filename:
+        Path to a wave file. It should be single channel and can be of type
+        32-bit floating point PCM. Its sample rate does not need to be 24kHz.
+
+    Returns:
+      Return a tuple containing:
+       - A 1-D array of dtype np.float32 containing the samples,
+         which are normalized to the range [-1, 1].
+       - Sample rate of the wave file.
+    """
+
+    samples, sample_rate = sf.read(wave_filename, dtype="float32")
+    assert (
+        samples.ndim == 1
+    ), f"Expected single channel, but got {samples.ndim} channels."
+
+    samples_float32 = samples.astype(np.float32)
+
+    return samples_float32, sample_rate
+
+
+def normalize_text_alimeeting(text: str) -> str:
+    """
+    Text normalization similar to M2MeT challenge baseline.
+    See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
+    """
+    import re
+
+    text = text.replace(" ", "")
+    text = text.replace("<sil>", "")
+    text = text.replace("<%>", "")
+    text = text.replace("<->", "")
+    text = text.replace("<$>", "")
+    text = text.replace("<#>", "")
+    text = text.replace("<_>", "")
+    text = text.replace("<space>", "")
+    text = text.replace("`", "")
+    text = text.replace("&", "")
+    text = text.replace(",", "")
+    if re.search("[a-zA-Z]", text):
+        text = text.upper()
+    text = text.replace("Ａ", "A")
+    text = text.replace("ａ", "A")
+    text = text.replace("ｂ", "B")
+    text = text.replace("ｃ", "C")
+    text = text.replace("ｋ", "K")
+    text = text.replace("ｔ", "T")
+    text = text.replace("，", "")
+    text = text.replace("丶", "")
+    text = text.replace("。", "")
+    text = text.replace("、", "")
+    text = text.replace("？", "")
+    return text
+
+
+def main():
+    args = get_args()
+    assert_file_exists(args.tokens)
+    assert args.num_threads > 0, args.num_threads
+
+    assert len(args.nemo_ctc) == 0, args.nemo_ctc
+    assert len(args.wenet_ctc) == 0, args.wenet_ctc
+    assert len(args.whisper_encoder) == 0, args.whisper_encoder
+    assert len(args.whisper_decoder) == 0, args.whisper_decoder
+    assert len(args.tdnn_model) == 0, args.tdnn_model
+
+    assert_file_exists(args.paraformer)
+
+    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
+        paraformer=args.paraformer,
+        tokens=args.tokens,
+        num_threads=args.num_threads,
+        sample_rate=args.sample_rate,
+        feature_dim=args.feature_dim,
+        decoding_method=args.decoding_method,
+        debug=args.debug,
+    )
+
+    print("Started!")
+    start_time = time.time()
+
+    streams, results = [], []
+    total_duration = 0
+
+    for i, wave_filename in enumerate(args.sound_files):
+        assert_file_exists(wave_filename)
+        samples, sample_rate = read_wave(wave_filename)
+        duration = len(samples) / sample_rate
+        total_duration += duration
+        s = recognizer.create_stream()
+        s.accept_waveform(sample_rate, samples)
+
+        streams.append(s)
+        if i % 10 == 0:
+            recognizer.decode_streams(streams)
+            results += [s.result.text for s in streams]
+            streams = []
+            print(f"Processed {i} files")
+    # process the last batch
+    if streams:
+        recognizer.decode_streams(streams)
+        results += [s.result.text for s in streams]
+    end_time = time.time()
+    print("Done!")
+
+    results_dict = {}
+    for wave_filename, result in zip(args.sound_files, results):
+        print(f"{wave_filename}\n{result}")
+        print("-" * 10)
+        wave_basename = Path(wave_filename).stem
+        results_dict[wave_basename] = result
+
+    elapsed_seconds = end_time - start_time
+    rtf = elapsed_seconds / total_duration
+    print(f"num_threads: {args.num_threads}")
+    print(f"decoding_method: {args.decoding_method}")
+    print(f"Wave duration: {total_duration:.3f} s")
+    print(f"Elapsed time: {elapsed_seconds:.3f} s")
+    print(
+        f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
+    )
+
+    if args.label:
+        from icefall.utils import store_transcripts, write_error_stats
+
+        labels_dict = {}
+        with open(args.label, "r") as f:
+            for line in f:
+                # fields = line.strip().split(" ")
+                # fields = [item for item in fields if item]
+                # assert len(fields) == 4
+                # prompt_text, prompt_audio, text, audio_path = fields
+
+                fields = line.strip().split("|")
+                fields = [item for item in fields if item]
+                assert len(fields) == 4
+                audio_path, prompt_text, prompt_audio, text = fields
+                labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text)
+
+        final_results = []
+        for key, value in results_dict.items():
+            final_results.append((key, labels_dict[key], value))
+
+        store_transcripts(
+            filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results
+        )
+        with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f:
+            write_error_stats(f, "test-set", final_results, enable_log=True)
+
+        with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f:
+            print(f.readline())  # WER
+            print(f.readline())  # Detailed errors
+
+
+if __name__ == "__main__":
+    main()
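A note on the metric printed by the decoder above: the real time factor (RTF) is wall-clock decoding time divided by the total audio duration, so values below 1.0 mean faster than real time. For example:

# RTF as computed in offline-decode-files.py (numbers are hypothetical):
elapsed_seconds = 12.5
total_duration = 100.0  # 100 s of audio across all wavs
rtf = elapsed_seconds / total_duration  # 0.125, i.e. 8x faster than real time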