fix wenet stateless5 jit export error (#735)

This commit is contained in:
Cesc 2022-12-05 23:35:10 +08:00 committed by GitHub
parent bd7fa2253d
commit be6e08f69a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 0 deletions

View File

@ -74,6 +74,7 @@ import logging
from pathlib import Path
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
@ -184,6 +185,7 @@ def main():
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
convert_scaled_to_non_scaled(model, inplace=True)
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/lstmp.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py