mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Minor fixes to the onnx inference script for ljspeech matcha-tts. (#1838)
This commit is contained in:
parent
92ed1708c0
commit
ad966fb81d
20
.github/scripts/ljspeech/TTS/run-matcha.sh
vendored
20
.github/scripts/ljspeech/TTS/run-matcha.sh
vendored
@ -57,6 +57,7 @@ function infer() {
|
|||||||
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
|
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
|
||||||
|
|
||||||
./matcha/infer.py \
|
./matcha/infer.py \
|
||||||
|
--num-buckets 2 \
|
||||||
--epoch 1 \
|
--epoch 1 \
|
||||||
--exp-dir ./matcha/exp \
|
--exp-dir ./matcha/exp \
|
||||||
--tokens data/tokens.txt \
|
--tokens data/tokens.txt \
|
||||||
@ -97,19 +98,23 @@ function export_onnx() {
|
|||||||
python3 ./matcha/export_onnx_hifigan.py
|
python3 ./matcha/export_onnx_hifigan.py
|
||||||
else
|
else
|
||||||
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v1.onnx
|
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v1.onnx
|
||||||
|
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v2.onnx
|
||||||
|
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v3.onnx
|
||||||
fi
|
fi
|
||||||
|
|
||||||
ls -lh *.onnx
|
ls -lh *.onnx
|
||||||
|
|
||||||
python3 ./matcha/onnx_pretrained.py \
|
for v in v1 v2 v3; do
|
||||||
--acoustic-model ./model-steps-6.onnx \
|
python3 ./matcha/onnx_pretrained.py \
|
||||||
--vocoder ./hifigan_v1.onnx \
|
--acoustic-model ./model-steps-6.onnx \
|
||||||
--tokens ./data/tokens.txt \
|
--vocoder ./hifigan_$v.onnx \
|
||||||
--input-text "how are you doing?" \
|
--tokens ./data/tokens.txt \
|
||||||
--output-wav /icefall/generated-matcha-tts-steps-6-v1.wav
|
--input-text "how are you doing?" \
|
||||||
|
--output-wav /icefall/generated-matcha-tts-steps-6-$v.wav
|
||||||
|
done
|
||||||
|
|
||||||
ls -lh /icefall/*.wav
|
ls -lh /icefall/*.wav
|
||||||
soxi /icefall/generated-matcha-tts-steps-6-v1.wav
|
soxi /icefall/generated-matcha-tts-steps-6-*.wav
|
||||||
}
|
}
|
||||||
|
|
||||||
prepare_data
|
prepare_data
|
||||||
@ -118,3 +123,4 @@ infer
|
|||||||
export_onnx
|
export_onnx
|
||||||
|
|
||||||
rm -rfv generator_v* matcha/exp
|
rm -rfv generator_v* matcha/exp
|
||||||
|
git checkout .
|
||||||
|
@ -163,7 +163,7 @@ def main():
|
|||||||
(x, x_lengths, temperature, length_scale),
|
(x, x_lengths, temperature, length_scale),
|
||||||
filename,
|
filename,
|
||||||
opset_version=opset_version,
|
opset_version=opset_version,
|
||||||
input_names=["x", "x_length", "temperature", "length_scale"],
|
input_names=["x", "x_length", "noise_scale", "length_scale"],
|
||||||
output_names=["mel"],
|
output_names=["mel"],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"x": {0: "N", 1: "L"},
|
"x": {0: "N", 1: "L"},
|
||||||
|
@ -89,6 +89,7 @@ class OnnxHifiGANModel:
|
|||||||
self.model.get_inputs()[0].name: x.numpy(),
|
self.model.get_inputs()[0].name: x.numpy(),
|
||||||
},
|
},
|
||||||
)[0]
|
)[0]
|
||||||
|
# audio: (batch_size, num_samples)
|
||||||
|
|
||||||
return torch.from_numpy(audio)
|
return torch.from_numpy(audio)
|
||||||
|
|
||||||
@ -97,19 +98,24 @@ class OnnxModel:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
filename: str,
|
filename: str,
|
||||||
|
tokens: str,
|
||||||
):
|
):
|
||||||
session_opts = ort.SessionOptions()
|
session_opts = ort.SessionOptions()
|
||||||
session_opts.inter_op_num_threads = 1
|
session_opts.inter_op_num_threads = 1
|
||||||
session_opts.intra_op_num_threads = 2
|
session_opts.intra_op_num_threads = 2
|
||||||
|
|
||||||
self.session_opts = session_opts
|
self.session_opts = session_opts
|
||||||
self.tokenizer = Tokenizer("./data/tokens.txt")
|
self.tokenizer = Tokenizer(tokens)
|
||||||
self.model = ort.InferenceSession(
|
self.model = ort.InferenceSession(
|
||||||
filename,
|
filename,
|
||||||
sess_options=self.session_opts,
|
sess_options=self.session_opts,
|
||||||
providers=["CPUExecutionProvider"],
|
providers=["CPUExecutionProvider"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
|
||||||
|
metadata = self.model.get_modelmeta().custom_metadata_map
|
||||||
|
self.sample_rate = int(metadata["sample_rate"])
|
||||||
|
|
||||||
for i in self.model.get_inputs():
|
for i in self.model.get_inputs():
|
||||||
print(i)
|
print(i)
|
||||||
|
|
||||||
@ -138,6 +144,7 @@ class OnnxModel:
|
|||||||
self.model.get_inputs()[3].name: length_scale.numpy(),
|
self.model.get_inputs()[3].name: length_scale.numpy(),
|
||||||
},
|
},
|
||||||
)[0]
|
)[0]
|
||||||
|
# mel: (batch_size, feat_dim, num_frames)
|
||||||
|
|
||||||
return torch.from_numpy(mel)
|
return torch.from_numpy(mel)
|
||||||
|
|
||||||
@ -147,7 +154,7 @@ def main():
|
|||||||
params = get_parser().parse_args()
|
params = get_parser().parse_args()
|
||||||
logging.info(vars(params))
|
logging.info(vars(params))
|
||||||
|
|
||||||
model = OnnxModel(params.acoustic_model)
|
model = OnnxModel(params.acoustic_model, params.tokens)
|
||||||
vocoder = OnnxHifiGANModel(params.vocoder)
|
vocoder = OnnxHifiGANModel(params.vocoder)
|
||||||
text = params.input_text
|
text = params.input_text
|
||||||
x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
|
x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
|
||||||
@ -164,15 +171,17 @@ def main():
|
|||||||
print("audio", audio.shape) # (1, 1, num_samples)
|
print("audio", audio.shape) # (1, 1, num_samples)
|
||||||
audio = audio.squeeze()
|
audio = audio.squeeze()
|
||||||
|
|
||||||
|
sample_rate = model.sample_rate
|
||||||
|
|
||||||
t = (end_t - start_t).total_seconds()
|
t = (end_t - start_t).total_seconds()
|
||||||
t2 = (end_t2 - start_t2).total_seconds()
|
t2 = (end_t2 - start_t2).total_seconds()
|
||||||
rtf_am = t * 22050 / audio.shape[-1]
|
rtf_am = t * sample_rate / audio.shape[-1]
|
||||||
rtf_vocoder = t2 * 22050 / audio.shape[-1]
|
rtf_vocoder = t2 * sample_rate / audio.shape[-1]
|
||||||
print("RTF for acoustic model ", rtf_am)
|
print("RTF for acoustic model ", rtf_am)
|
||||||
print("RTF for vocoder", rtf_vocoder)
|
print("RTF for vocoder", rtf_vocoder)
|
||||||
|
|
||||||
# skip denoiser
|
# skip denoiser
|
||||||
sf.write(params.output_wav, audio, 22050, "PCM_16")
|
sf.write(params.output_wav, audio, sample_rate, "PCM_16")
|
||||||
logging.info(f"Saved to {params.output_wav}")
|
logging.info(f"Saved to {params.output_wav}")
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user