Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Merge pull request #2 from csukuangfj/export_zipformer2_onnx
Use torch.jit.script() for position encoding
This commit is contained in:
commit
a2b8e3545b
@ -296,6 +296,8 @@ def export_encoder_model_onnx(
|
|||||||
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||||
x_lens = torch.tensor([100], dtype=torch.int64)
|
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||||
|
|
||||||
|
encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
|
||||||
|
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
encoder_model,
|
encoder_model,
|
||||||
(x, x_lens),
|
(x, x_lens),
|
||||||
@ -523,7 +525,7 @@ def main():
|
|||||||
model.to("cpu")
|
model.to("cpu")
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
convert_scaled_to_non_scaled(model, inplace=True)
|
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
|
||||||
|
|
||||||
encoder = OnnxEncoder(
|
encoder = OnnxEncoder(
|
||||||
encoder=model.encoder,
|
encoder=model.encoder,
|
||||||
|
|||||||
@ -27,6 +27,7 @@ from typing import List, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from scaling import Balancer, Dropout3, ScaleGrad, Whiten
|
from scaling import Balancer, Dropout3, ScaleGrad, Whiten
|
||||||
|
from zipformer import CompactRelPositionalEncoding
|
||||||
|
|
||||||
|
|
||||||
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
|
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
|
||||||
@ -51,6 +52,7 @@ def convert_scaled_to_non_scaled(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
is_pnnx: bool = False,
|
is_pnnx: bool = False,
|
||||||
|
is_onnx: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -61,6 +63,8 @@ def convert_scaled_to_non_scaled(
|
|||||||
If False, the input model is copied and we modify the copied version.
|
If False, the input model is copied and we modify the copied version.
|
||||||
is_pnnx:
|
is_pnnx:
|
||||||
True if we are going to export the model for PNNX.
|
True if we are going to export the model for PNNX.
|
||||||
|
is_onnx:
|
||||||
|
True if we are going to export the model for ONNX.
|
||||||
Return:
|
Return:
|
||||||
Return a model without scaled layers.
|
Return a model without scaled layers.
|
||||||
"""
|
"""
|
||||||
@ -71,6 +75,11 @@ def convert_scaled_to_non_scaled(
|
|||||||
for name, m in model.named_modules():
|
for name, m in model.named_modules():
|
||||||
if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
|
if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
|
||||||
d[name] = nn.Identity()
|
d[name] = nn.Identity()
|
||||||
|
elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
|
||||||
|
# We want to recreate the positional encoding vector when
|
||||||
|
# the input changes, so we have to use torch.jit.script()
|
||||||
|
# instead of torch.jit.trace()
|
||||||
|
d[name] = torch.jit.script(m)
|
||||||
|
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
if "." in k:
|
if "." in k:
|
||||||
|
|||||||
@ -1305,12 +1305,6 @@ class CompactRelPositionalEncoding(torch.nn.Module):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Construct a CompactRelPositionalEncoding object."""
|
"""Construct a CompactRelPositionalEncoding object."""
|
||||||
super(CompactRelPositionalEncoding, self).__init__()
|
super(CompactRelPositionalEncoding, self).__init__()
|
||||||
if torch.jit.is_tracing():
|
|
||||||
# It assumes that the maximum input, after downsampling, won't have more than
|
|
||||||
# 10k frames.
|
|
||||||
# The first downsampling factor is 2, so the maximum input
|
|
||||||
# should contain fewer than 20k frames, i.e., less than 200 seconds (about 3.33 minutes)
|
|
||||||
max_len = 10000
|
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
assert embed_dim % 2 == 0
|
assert embed_dim % 2 == 0
|
||||||
self.dropout = Dropout2(dropout_rate)
|
self.dropout = Dropout2(dropout_rate)
|
||||||
@ -1327,10 +1321,6 @@ class CompactRelPositionalEncoding(torch.nn.Module):
|
|||||||
# self.pe contains both positive and negative parts
|
# self.pe contains both positive and negative parts
|
||||||
# the length of self.pe is 2 * input_len - 1
|
# the length of self.pe is 2 * input_len - 1
|
||||||
if self.pe.size(0) >= T * 2 - 1:
|
if self.pe.size(0) >= T * 2 - 1:
|
||||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
|
||||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
|
||||||
x.device
|
|
||||||
):
|
|
||||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user