mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 15:14:18 +00:00
Update test_onnx.py
This commit is contained in:
parent
8717043bdf
commit
eb686b8da3
@ -111,6 +111,10 @@ def main():
|
|||||||
|
|
||||||
tokenizer = Tokenizer(args.tokens)
|
tokenizer = Tokenizer(args.tokens)
|
||||||
|
|
||||||
|
with open(args.speakers) as f:
|
||||||
|
speaker_map = {line.strip(): i for i, line in enumerate(f)}
|
||||||
|
args.num_spks = len(speaker_map)
|
||||||
|
|
||||||
logging.info("About to create onnx model")
|
logging.info("About to create onnx model")
|
||||||
model = OnnxModel(args.model_filename)
|
model = OnnxModel(args.model_filename)
|
||||||
|
|
||||||
@ -118,7 +122,8 @@ def main():
|
|||||||
tokens = tokenizer.texts_to_token_ids([text])
|
tokens = tokenizer.texts_to_token_ids([text])
|
||||||
tokens = torch.tensor(tokens) # (1, T)
|
tokens = torch.tensor(tokens) # (1, T)
|
||||||
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
|
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
|
||||||
audio = model(tokens, tokens_lens) # (1, T')
|
speaker = torch.tensor([1], dtype=torch.int64) # (1, )
|
||||||
|
audio = model(tokens, tokens_lens, speaker) # (1, T')
|
||||||
|
|
||||||
torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050)
|
torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050)
|
||||||
logging.info("Saved to test_onnx.wav")
|
logging.info("Saved to test_onnx.wav")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user