diff --git a/egs/vctk/TTS/vits/test_onnx.py b/egs/vctk/TTS/vits/test_onnx.py index 4f8e9da19..757e67fc1 100755 --- a/egs/vctk/TTS/vits/test_onnx.py +++ b/egs/vctk/TTS/vits/test_onnx.py @@ -79,7 +79,9 @@ class OnnxModel: ) logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") - def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor: + def __call__( + self, tokens: torch.Tensor, tokens_lens: torch.Tensor, speaker: torch.Tensor + ) -> torch.Tensor: """ Args: tokens: @@ -100,7 +102,8 @@ class OnnxModel: self.model.get_inputs()[1].name: tokens_lens.numpy(), self.model.get_inputs()[2].name: noise_scale.numpy(), self.model.get_inputs()[3].name: noise_scale_dur.numpy(), - self.model.get_inputs()[4].name: alpha.numpy(), + self.model.get_inputs()[4].name: speaker.numpy(), + self.model.get_inputs()[5].name: alpha.numpy(), }, )[0] return torch.from_numpy(out)