Fix torch.jit.script() export for streaming zipformer. (#1005)

This commit is contained in:
Fangjun Kuang 2023-04-17 16:13:30 +08:00 committed by GitHub
parent 7c7d9ab042
commit e32658e620
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 1 additions and 1 deletions

View File

@ -856,6 +856,7 @@ def main():
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
model.encoder.__class__.forward = model.encoder.__class__.streaming_forward
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"

View File

@ -570,7 +570,6 @@ class Zipformer(EncoderInterface):
return x, lengths
@torch.jit.export
def streaming_forward(
self,
x: torch.Tensor,