mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
Update test_onnx.py
This commit is contained in:
parent
eb686b8da3
commit
fddfd2466f
@ -79,7 +79,9 @@ class OnnxModel:
|
||||
)
|
||||
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
|
||||
|
||||
def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
|
||||
def __call__(
|
||||
self, tokens: torch.Tensor, tokens_lens: torch.Tensor, speaker: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
tokens:
|
||||
@ -100,7 +102,8 @@ class OnnxModel:
|
||||
self.model.get_inputs()[1].name: tokens_lens.numpy(),
|
||||
self.model.get_inputs()[2].name: noise_scale.numpy(),
|
||||
self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
|
||||
self.model.get_inputs()[4].name: alpha.numpy(),
|
||||
self.model.get_inputs()[4].name: speaker.numpy(),
|
||||
self.model.get_inputs()[5].name: alpha.numpy(),
|
||||
},
|
||||
)[0]
|
||||
return torch.from_numpy(out)
|
||||
|
Loading…
x
Reference in New Issue
Block a user