mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Use torch.jit.script() for positional encoding
This commit is contained in:
parent
4cb2395186
commit
d9e7f02225
@ -296,18 +296,7 @@ def export_encoder_model_onnx(
|
||||
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||
|
||||
# It assumes that the maximum input, after downsampling, won't have more
|
||||
# than 10k frames.
|
||||
# The first downsampling factor is 2, so the maximum input
|
||||
# should contain less than 20k frames, e.g., less than 400 seconds,
|
||||
# i.e., 3.3 minutes
|
||||
#
|
||||
# Note: If you want to handle a longer input audio, please increase this
|
||||
# value. The downside is that it will increase the size of the model.
|
||||
max_len = 10000
|
||||
for name, m in encoder_model.named_modules():
|
||||
if isinstance(m, CompactRelPositionalEncoding):
|
||||
m.extend_pe(torch.tensor(0.0).expand(max_len))
|
||||
encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
|
||||
|
||||
torch.onnx.export(
|
||||
encoder_model,
|
||||
@ -536,7 +525,7 @@ def main():
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
|
||||
|
||||
encoder = OnnxEncoder(
|
||||
encoder=model.encoder,
|
||||
|
||||
@ -27,6 +27,7 @@ from typing import List, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import Balancer, Dropout3, ScaleGrad, Whiten
|
||||
from zipformer import CompactRelPositionalEncoding
|
||||
|
||||
|
||||
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
|
||||
@ -51,6 +52,7 @@ def convert_scaled_to_non_scaled(
|
||||
model: nn.Module,
|
||||
inplace: bool = False,
|
||||
is_pnnx: bool = False,
|
||||
is_onnx: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -61,6 +63,8 @@ def convert_scaled_to_non_scaled(
|
||||
If False, the input model is copied and we modify the copied version.
|
||||
is_pnnx:
|
||||
True if we are going to export the model for PNNX.
|
||||
is_onnx:
|
||||
True if we are going to export the model for ONNX.
|
||||
Return:
|
||||
Return a model without scaled layers.
|
||||
"""
|
||||
@ -71,6 +75,11 @@ def convert_scaled_to_non_scaled(
|
||||
for name, m in model.named_modules():
|
||||
if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
|
||||
d[name] = nn.Identity()
|
||||
elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
|
||||
# We want to recreate the positional encoding vector when
|
||||
# the input changes, so we have to use torch.jit.script()
|
||||
# to replace torch.jit.trace()
|
||||
d[name] = torch.jit.script(m)
|
||||
|
||||
for k, v in d.items():
|
||||
if "." in k:
|
||||
|
||||
@ -1322,10 +1322,8 @@ class CompactRelPositionalEncoding(torch.nn.Module):
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.size(0) >= T * 2 - 1:
|
||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||
x.device
|
||||
):
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
# if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
|
||||
# if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user