Fix zipformer CI test (#1164)

This commit is contained in:
Fangjun Kuang 2023-07-05 10:23:35 +08:00 committed by GitHub
parent a4402b88e6
commit b8a17944e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 1 deletions

View File

@ -856,6 +856,10 @@ def main():
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
model.encoder.__class__.non_streaming_forward = model.encoder.__class__.forward
model.encoder.__class__.non_streaming_forward = torch.jit.export(
model.encoder.__class__.non_streaming_forward
)
model.encoder.__class__.forward = model.encoder.__class__.streaming_forward
logging.info("Using torch.jit.script")
model = torch.jit.script(model)

View File

@ -252,7 +252,7 @@ def main():
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(
encoder_out, encoder_out_lens = model.encoder.non_streaming_forward(
x=features,
x_lens=feature_lengths,
)