fix export for stateless4 (#844)

This commit is contained in:
Fangjun Kuang 2023-01-16 20:26:36 +08:00 committed by GitHub
parent 2a463a420d
commit 0af3e7beda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 0 deletions

View File

@ -50,6 +50,7 @@ from pathlib import Path
import sentencepiece as spm import sentencepiece as spm
import torch import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import ( from icefall.checkpoint import (
@ -261,6 +262,7 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore # We won't use the forward() method of the model in C++, so just ignore
# it here. # it here.
# Otherwise, one of its arguments is a ragged tensor and is not # Otherwise, one of its arguments is a ragged tensor and is not

View File

@ -0,0 +1 @@
../pruned_transducer_stateless3/lstmp.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless3/scaling_converter.py