mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
minor updates
This commit is contained in:
parent
7bd679f7d5
commit
686d2d9787
@ -180,7 +180,7 @@ class OnnxAudioTagger(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
x_lens: torch.Tensor,
|
x_lens: torch.Tensor,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
"""Please see the help information of Zipformer.forward
|
"""Please see the help information of Zipformer.forward
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -206,7 +206,7 @@ class OnnxAudioTagger(nn.Module):
|
|||||||
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
|
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
|
||||||
logits
|
logits
|
||||||
) # normalize the logits
|
) # normalize the logits
|
||||||
|
print(logits.shape)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
@ -234,10 +234,10 @@ def export_audio_tagging_model_onnx(
|
|||||||
opset_version:
|
opset_version:
|
||||||
The opset version to use.
|
The opset version to use.
|
||||||
"""
|
"""
|
||||||
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
x = torch.zeros(1, 200, 80, dtype=torch.float32)
|
||||||
x_lens = torch.tensor([100], dtype=torch.int64)
|
x_lens = torch.tensor([200], dtype=torch.int64)
|
||||||
|
|
||||||
model = torch.jit.trace(model, (x, x_lens))
|
model = torch.jit.script(model)
|
||||||
|
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
model,
|
model,
|
||||||
@ -250,7 +250,7 @@ def export_audio_tagging_model_onnx(
|
|||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"x": {0: "N", 1: "T"},
|
"x": {0: "N", 1: "T"},
|
||||||
"x_lens": {0: "N"},
|
"x_lens": {0: "N"},
|
||||||
# "logits": {0: "N", 1: "T"},
|
"logits": {0: "N"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -144,7 +144,7 @@ class AudioTaggingModel(nn.Module):
|
|||||||
before padding.
|
before padding.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A 3-D tensor of shape (N, T, num_classes).
|
A 2-D tensor of shape (N, num_classes).
|
||||||
"""
|
"""
|
||||||
logits = self.classifier(encoder_out) # (N, T, num_classes)
|
logits = self.classifier(encoder_out) # (N, T, num_classes)
|
||||||
padding_mask = make_pad_mask(encoder_out_lens)
|
padding_mask = make_pad_mask(encoder_out_lens)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user