mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Use torch.jit.script() for positional encoding
This commit is contained in:
parent
4cb2395186
commit
d9e7f02225
@ -296,18 +296,7 @@ def export_encoder_model_onnx(
|
|||||||
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||||
x_lens = torch.tensor([100], dtype=torch.int64)
|
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||||
|
|
||||||
# It assumes that the maximum input, after downsampling, won't have more
|
encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
|
||||||
# than 10k frames.
|
|
||||||
# The first downsampling factor is 2, so the maximum input
|
|
||||||
# should contain less than 20k frames, e.g., less than 400 seconds,
|
|
||||||
# i.e., 3.3 minutes
|
|
||||||
#
|
|
||||||
# Note: If you want to handle a longer input audio, please increase this
|
|
||||||
# value. The downside is that it will increase the size of the model.
|
|
||||||
max_len = 10000
|
|
||||||
for name, m in encoder_model.named_modules():
|
|
||||||
if isinstance(m, CompactRelPositionalEncoding):
|
|
||||||
m.extend_pe(torch.tensor(0.0).expand(max_len))
|
|
||||||
|
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
encoder_model,
|
encoder_model,
|
||||||
@ -536,7 +525,7 @@ def main():
|
|||||||
model.to("cpu")
|
model.to("cpu")
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
convert_scaled_to_non_scaled(model, inplace=True)
|
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
|
||||||
|
|
||||||
encoder = OnnxEncoder(
|
encoder = OnnxEncoder(
|
||||||
encoder=model.encoder,
|
encoder=model.encoder,
|
||||||
|
|||||||
@ -27,6 +27,7 @@ from typing import List, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from scaling import Balancer, Dropout3, ScaleGrad, Whiten
|
from scaling import Balancer, Dropout3, ScaleGrad, Whiten
|
||||||
|
from zipformer import CompactRelPositionalEncoding
|
||||||
|
|
||||||
|
|
||||||
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
|
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
|
||||||
@ -51,6 +52,7 @@ def convert_scaled_to_non_scaled(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
is_pnnx: bool = False,
|
is_pnnx: bool = False,
|
||||||
|
is_onnx: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -61,6 +63,8 @@ def convert_scaled_to_non_scaled(
|
|||||||
If False, the input model is copied and we modify the copied version.
|
If False, the input model is copied and we modify the copied version.
|
||||||
is_pnnx:
|
is_pnnx:
|
||||||
True if we are going to export the model for PNNX.
|
True if we are going to export the model for PNNX.
|
||||||
|
is_onnx:
|
||||||
|
True if we are going to export the model for ONNX.
|
||||||
Return:
|
Return:
|
||||||
Return a model without scaled layers.
|
Return a model without scaled layers.
|
||||||
"""
|
"""
|
||||||
@ -71,6 +75,11 @@ def convert_scaled_to_non_scaled(
|
|||||||
for name, m in model.named_modules():
|
for name, m in model.named_modules():
|
||||||
if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
|
if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
|
||||||
d[name] = nn.Identity()
|
d[name] = nn.Identity()
|
||||||
|
elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
|
||||||
|
# We want to recreate the positional encoding vector when
|
||||||
|
# the input changes, so we have to use torch.jit.script()
|
||||||
|
# to replace torch.jit.trace()
|
||||||
|
d[name] = torch.jit.script(m)
|
||||||
|
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
if "." in k:
|
if "." in k:
|
||||||
|
|||||||
@ -1322,10 +1322,8 @@ class CompactRelPositionalEncoding(torch.nn.Module):
|
|||||||
# the length of self.pe is 2 * input_len - 1
|
# the length of self.pe is 2 * input_len - 1
|
||||||
if self.pe.size(0) >= T * 2 - 1:
|
if self.pe.size(0) >= T * 2 - 1:
|
||||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
# if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
||||||
x.device
|
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||||
):
|
|
||||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
|
# if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user