mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Update test_onnx.py
This commit is contained in:
parent
eb686b8da3
commit
fddfd2466f
@ -79,7 +79,9 @@ class OnnxModel:
|
|||||||
)
|
)
|
||||||
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
|
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
|
||||||
|
|
||||||
def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
|
def __call__(
|
||||||
|
self, tokens: torch.Tensor, tokens_lens: torch.Tensor, speaker: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
tokens:
|
tokens:
|
||||||
@ -100,7 +102,8 @@ class OnnxModel:
|
|||||||
self.model.get_inputs()[1].name: tokens_lens.numpy(),
|
self.model.get_inputs()[1].name: tokens_lens.numpy(),
|
||||||
self.model.get_inputs()[2].name: noise_scale.numpy(),
|
self.model.get_inputs()[2].name: noise_scale.numpy(),
|
||||||
self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
|
self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
|
||||||
self.model.get_inputs()[4].name: alpha.numpy(),
|
self.model.get_inputs()[4].name: speaker.numpy(),
|
||||||
|
self.model.get_inputs()[5].name: alpha.numpy(),
|
||||||
},
|
},
|
||||||
)[0]
|
)[0]
|
||||||
return torch.from_numpy(out)
|
return torch.from_numpy(out)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user