Fix torchscript export for aishell (#969)

This commit is contained in:
Fangjun Kuang 2023-03-27 14:08:26 +08:00 committed by GitHub
parent 8c3ea93fc8
commit 35e21a0d2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 0 deletions

View File

@ -48,6 +48,7 @@ import logging
from pathlib import Path from pathlib import Path
import torch import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import ( from icefall.checkpoint import (
@ -244,6 +245,7 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore # We won't use the forward() method of the model in C++, so just ignore
# it here. # it here.
# Otherwise, one of its arguments is a ragged tensor and is not # Otherwise, one of its arguments is a ragged tensor and is not

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless3/lstmp.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py