mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
fix wenet stateless5 jit export error (#735)
This commit is contained in:
parent
bd7fa2253d
commit
be6e08f69a
2
egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
Normal file → Executable file
2
egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py
Normal file → Executable file
@ -74,6 +74,7 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
@ -184,6 +185,7 @@ def main():
|
|||||||
# it here.
|
# it here.
|
||||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
# torch scriptabe.
|
# torch scriptabe.
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
|
1
egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py
Symbolic link
1
egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless5/lstmp.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py
|
Loading…
x
Reference in New Issue
Block a user