mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Fix zipformer CI test (#1164)
This commit is contained in:
parent
a4402b88e6
commit
b8a17944e4
@ -856,6 +856,10 @@ def main():
|
|||||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
# torch scriptabe.
|
# torch scriptabe.
|
||||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
|
model.encoder.__class__.non_streaming_forward = model.encoder.__class__.forward
|
||||||
|
model.encoder.__class__.non_streaming_forward = torch.jit.export(
|
||||||
|
model.encoder.__class__.non_streaming_forward
|
||||||
|
)
|
||||||
model.encoder.__class__.forward = model.encoder.__class__.streaming_forward
|
model.encoder.__class__.forward = model.encoder.__class__.streaming_forward
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
|
@ -252,7 +252,7 @@ def main():
|
|||||||
|
|
||||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||||
|
|
||||||
encoder_out, encoder_out_lens = model.encoder(
|
encoder_out, encoder_out_lens = model.encoder.non_streaming_forward(
|
||||||
x=features,
|
x=features,
|
||||||
x_lens=feature_lengths,
|
x_lens=feature_lengths,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user