Update test_onnx.py

This commit is contained in:
jinzr 2023-12-05 14:40:54 +08:00
parent eb686b8da3
commit fddfd2466f

View File

@ -79,7 +79,9 @@ class OnnxModel:
)
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
def __call__(
self, tokens: torch.Tensor, tokens_lens: torch.Tensor, speaker: torch.Tensor
) -> torch.Tensor:
"""
Args:
tokens:
@ -100,7 +102,8 @@ class OnnxModel:
self.model.get_inputs()[1].name: tokens_lens.numpy(),
self.model.get_inputs()[2].name: noise_scale.numpy(),
self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
self.model.get_inputs()[4].name: alpha.numpy(),
self.model.get_inputs()[4].name: speaker.numpy(),
self.model.get_inputs()[5].name: alpha.numpy(),
},
)[0]
return torch.from_numpy(out)