Update test_onnx.py

This commit is contained in:
jinzr 2023-12-05 14:40:54 +08:00
parent eb686b8da3
commit fddfd2466f

View File

@ -79,7 +79,9 @@ class OnnxModel:
) )
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor: def __call__(
self, tokens: torch.Tensor, tokens_lens: torch.Tensor, speaker: torch.Tensor
) -> torch.Tensor:
""" """
Args: Args:
tokens: tokens:
@ -100,7 +102,8 @@ class OnnxModel:
self.model.get_inputs()[1].name: tokens_lens.numpy(), self.model.get_inputs()[1].name: tokens_lens.numpy(),
self.model.get_inputs()[2].name: noise_scale.numpy(), self.model.get_inputs()[2].name: noise_scale.numpy(),
self.model.get_inputs()[3].name: noise_scale_dur.numpy(), self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
self.model.get_inputs()[4].name: alpha.numpy(), self.model.get_inputs()[4].name: speaker.numpy(),
self.model.get_inputs()[5].name: alpha.numpy(),
}, },
)[0] )[0]
return torch.from_numpy(out) return torch.from_numpy(out)