mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Update README
This commit is contained in:
parent
b33d3820db
commit
9681263c0d
@ -36,3 +36,19 @@ To inference, use:
|
||||
--epoch 1000 \
|
||||
--tokens data/tokens.txt
|
||||
```
|
||||
|
||||
## Quality vs speed
|
||||
|
||||
If you feel that the trained model is slow at runtime, you can specify the
|
||||
argument `--model-type` during training. Possible values are:
|
||||
|
||||
- `low`, means **low** quality. The resulting model is very small in file size
|
||||
and runs very fast. The following is a wave file generatd by a `low` model
|
||||
|
||||
https://github.com/k2-fsa/icefall/assets/5284924/d5758c24-470d-40ee-b089-e57fcba81633
|
||||
|
||||
The text is "Ask not what your country can do for you; ask what you can do for your country."
|
||||
|
||||
- `medium`, means **medium** quality.
|
||||
- `high`, means **high** quality
|
||||
|
||||
|
@ -143,7 +143,7 @@ class OnnxModel(nn.Module):
|
||||
Return a tuple containing:
|
||||
- audio, generated wavform tensor, (B, T_wav)
|
||||
"""
|
||||
audio, _, _ = self.model.inference(
|
||||
audio, _, _ = self.model.generator.inference(
|
||||
text=tokens,
|
||||
text_lengths=tokens_lens,
|
||||
noise_scale=noise_scale,
|
||||
@ -205,6 +205,11 @@ def export_model_onnx(
|
||||
},
|
||||
)
|
||||
|
||||
if model.model.spks is None:
|
||||
num_speakers = 1
|
||||
else:
|
||||
num_speakers = model.model.spks
|
||||
|
||||
meta_data = {
|
||||
"model_type": "vits",
|
||||
"version": "1",
|
||||
@ -213,8 +218,8 @@ def export_model_onnx(
|
||||
"language": "English",
|
||||
"voice": "en-us", # Choose your language appropriately
|
||||
"has_espeak": 1,
|
||||
"n_speakers": 1,
|
||||
"sample_rate": 22050, # Must match the real sample rate
|
||||
"n_speakers": num_speakers,
|
||||
"sample_rate": model.model.sampling_rate, # Must match the real sample rate
|
||||
}
|
||||
logging.info(f"meta_data: {meta_data}")
|
||||
|
||||
@ -240,7 +245,6 @@ def main():
|
||||
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
|
||||
model = model.generator
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
|
@ -54,6 +54,20 @@ def get_parser():
|
||||
help="""Path to vocabulary.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--text",
|
||||
type=str,
|
||||
default="Ask not what your country can do for you; ask what you can do for your country.",
|
||||
help="Text to generate speech for",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-filename",
|
||||
type=str,
|
||||
default="test_onnx.wav",
|
||||
help="Filename to save the generated wave file.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -61,7 +75,7 @@ class OnnxModel:
|
||||
def __init__(self, model_filename: str):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 4
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
@ -72,6 +86,9 @@ class OnnxModel:
|
||||
)
|
||||
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
|
||||
|
||||
metadata = self.model.get_modelmeta().custom_metadata_map
|
||||
self.sample_rate = int(metadata["sample_rate"])
|
||||
|
||||
def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -101,13 +118,14 @@ class OnnxModel:
|
||||
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
tokenizer = Tokenizer(args.tokens)
|
||||
|
||||
logging.info("About to create onnx model")
|
||||
model = OnnxModel(args.model_filename)
|
||||
|
||||
text = "I went there to see the land, the people and how their system works, end quote."
|
||||
text = args.text
|
||||
tokens = tokenizer.texts_to_token_ids(
|
||||
[text], intersperse_blank=True, add_sos=True, add_eos=True
|
||||
)
|
||||
@ -115,8 +133,9 @@ def main():
|
||||
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
|
||||
audio = model(tokens, tokens_lens) # (1, T')
|
||||
|
||||
torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050)
|
||||
logging.info("Saved to test_onnx.wav")
|
||||
output_filename = args.output_filename
|
||||
torchaudio.save(output_filename, audio, sample_rate=model.sample_rate)
|
||||
logging.info(f"Saved to {output_filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -18,7 +18,15 @@ import logging
|
||||
from typing import Dict, List
|
||||
|
||||
import tacotron_cleaner.cleaners
|
||||
from piper_phonemize import phonemize_espeak
|
||||
|
||||
try:
|
||||
from piper_phonemize import phonemize_espeak
|
||||
except Exception as ex:
|
||||
raise RuntimeError(
|
||||
f"{ex}\nPlease follow instructions in "
|
||||
"../prepare.sh to install piper-phonemize"
|
||||
)
|
||||
|
||||
from utils import intersperse
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user