mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 02:22:17 +00:00
Fix torch.jit.script() export for streaming zipformer. (#1005)
This commit is contained in:
parent
7c7d9ab042
commit
e32658e620
@ -856,6 +856,7 @@ def main():
|
|||||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
# torch scriptabe.
|
# torch scriptabe.
|
||||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
|
model.encoder.__class__.forward = model.encoder.__class__.streaming_forward
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
@ -570,7 +570,6 @@ class Zipformer(EncoderInterface):
|
|||||||
|
|
||||||
return x, lengths
|
return x, lengths
|
||||||
|
|
||||||
@torch.jit.export
|
|
||||||
def streaming_forward(
|
def streaming_forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user