mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
minor updates
This commit is contained in:
parent
7bd679f7d5
commit
686d2d9787
@ -180,7 +180,7 @@ class OnnxAudioTagger(nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> torch.Tensor:
|
||||
"""Please see the help information of Zipformer.forward
|
||||
|
||||
Args:
|
||||
@ -206,7 +206,7 @@ class OnnxAudioTagger(nn.Module):
|
||||
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
|
||||
logits
|
||||
) # normalize the logits
|
||||
|
||||
print(logits.shape)
|
||||
return logits
|
||||
|
||||
|
||||
@ -234,10 +234,10 @@ def export_audio_tagging_model_onnx(
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||
x = torch.zeros(1, 200, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([200], dtype=torch.int64)
|
||||
|
||||
model = torch.jit.trace(model, (x, x_lens))
|
||||
model = torch.jit.script(model)
|
||||
|
||||
torch.onnx.export(
|
||||
model,
|
||||
@ -250,7 +250,7 @@ def export_audio_tagging_model_onnx(
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "T"},
|
||||
"x_lens": {0: "N"},
|
||||
# "logits": {0: "N", 1: "T"},
|
||||
"logits": {0: "N"},
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -144,7 +144,7 @@ class AudioTaggingModel(nn.Module):
|
||||
before padding.
|
||||
|
||||
Returns:
|
||||
A 3-D tensor of shape (N, T, num_classes).
|
||||
A 3-D tensor of shape (N, num_classes).
|
||||
"""
|
||||
logits = self.classifier(encoder_out) # (N, T, num_classes)
|
||||
padding_mask = make_pad_mask(encoder_out_lens)
|
||||
|
Loading…
x
Reference in New Issue
Block a user