From fddfd2466ff18f5ed28e5d29206d86d110edcf2c Mon Sep 17 00:00:00 2001 From: jinzr Date: Tue, 5 Dec 2023 14:40:54 +0800 Subject: [PATCH] Update test_onnx.py --- egs/vctk/TTS/vits/test_onnx.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/egs/vctk/TTS/vits/test_onnx.py b/egs/vctk/TTS/vits/test_onnx.py index 4f8e9da19..757e67fc1 100755 --- a/egs/vctk/TTS/vits/test_onnx.py +++ b/egs/vctk/TTS/vits/test_onnx.py @@ -79,7 +79,9 @@ class OnnxModel: ) logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") - def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor: + def __call__( + self, tokens: torch.Tensor, tokens_lens: torch.Tensor, speaker: torch.Tensor + ) -> torch.Tensor: """ Args: tokens: @@ -100,7 +102,8 @@ class OnnxModel: self.model.get_inputs()[1].name: tokens_lens.numpy(), self.model.get_inputs()[2].name: noise_scale.numpy(), self.model.get_inputs()[3].name: noise_scale_dur.numpy(), - self.model.get_inputs()[4].name: alpha.numpy(), + self.model.get_inputs()[4].name: speaker.numpy(), + self.model.get_inputs()[5].name: alpha.numpy(), }, )[0] return torch.from_numpy(out)