Fix torch.jit.script() export for pruned_transducer_stateless2 (#1410)

This commit is contained in:
Fangjun Kuang 2023-12-10 11:38:39 +08:00 committed by GitHub
parent df56aff31e
commit b0f70c9d04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 0 deletions

View File

@ -49,6 +49,7 @@ from pathlib import Path
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
@ -198,6 +199,7 @@ def main():
model.eval()
if params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not

View File

@ -0,0 +1 @@
../lstm_transducer_stateless2/lstmp.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless3/scaling_converter.py