mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Fix zipformer CI test (#1164)
This commit is contained in:
parent
a4402b88e6
commit
b8a17944e4
@ -856,6 +856,10 @@ def main():
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
model.encoder.__class__.non_streaming_forward = model.encoder.__class__.forward
|
||||
model.encoder.__class__.non_streaming_forward = torch.jit.export(
|
||||
model.encoder.__class__.non_streaming_forward
|
||||
)
|
||||
model.encoder.__class__.forward = model.encoder.__class__.streaming_forward
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
|
@ -252,7 +252,7 @@ def main():
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
encoder_out, encoder_out_lens = model.encoder.non_streaming_forward(
|
||||
x=features,
|
||||
x_lens=feature_lengths,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user