Fix torch.jit.script() export for pruned_transducer_stateless2

This commit is contained in:
Fangjun Kuang 2023-12-10 10:05:52 +08:00
parent df56aff31e
commit 26b9a5a7a1
3 changed files with 4 additions and 0 deletions

View File

@ -49,6 +49,7 @@ from pathlib import Path
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
@ -198,6 +199,7 @@ def main():
model.eval()
if params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not

View File

@ -0,0 +1 @@
../lstm_transducer_stateless2/lstmp.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless3/scaling_converter.py