minor updates

This commit is contained in:
marcoyang 2024-03-29 19:08:21 +08:00
parent 7bd679f7d5
commit 686d2d9787
2 changed files with 7 additions and 7 deletions

View File

@@ -180,7 +180,7 @@ class OnnxAudioTagger(nn.Module):
self, self,
x: torch.Tensor, x: torch.Tensor,
x_lens: torch.Tensor, x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor:
"""Please see the help information of Zipformer.forward """Please see the help information of Zipformer.forward
Args: Args:
@@ -206,7 +206,7 @@ class OnnxAudioTagger(nn.Module):
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as( logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
logits logits
) # normalize the logits ) # normalize the logits
print(logits.shape)
return logits return logits
@@ -234,10 +234,10 @@ def export_audio_tagging_model_onnx(
opset_version: opset_version:
The opset version to use. The opset version to use.
""" """
x = torch.zeros(1, 100, 80, dtype=torch.float32) x = torch.zeros(1, 200, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64) x_lens = torch.tensor([200], dtype=torch.int64)
model = torch.jit.trace(model, (x, x_lens)) model = torch.jit.script(model)
torch.onnx.export( torch.onnx.export(
model, model,
@@ -250,7 +250,7 @@ def export_audio_tagging_model_onnx(
dynamic_axes={ dynamic_axes={
"x": {0: "N", 1: "T"}, "x": {0: "N", 1: "T"},
"x_lens": {0: "N"}, "x_lens": {0: "N"},
# "logits": {0: "N", 1: "T"}, "logits": {0: "N"},
}, },
) )

View File

@@ -144,7 +144,7 @@ class AudioTaggingModel(nn.Module):
before padding. before padding.
Returns: Returns:
A 3-D tensor of shape (N, T, num_classes). A 2-D tensor of shape (N, num_classes).
""" """
logits = self.classifier(encoder_out) # (N, T, num_classes) logits = self.classifier(encoder_out) # (N, T, num_classes)
padding_mask = make_pad_mask(encoder_out_lens) padding_mask = make_pad_mask(encoder_out_lens)