This commit is contained in:
marcoyang 2023-02-13 18:19:52 +08:00
parent a57c54124a
commit 56c2474c0d

View File

@ -79,7 +79,6 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
@ -91,6 +90,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
@ -181,8 +181,7 @@ def get_parser():
def export_encoder_model_jit_trace(
encoder_model: nn.Module,
encoder_filename: str,
encoder_model: nn.Module, encoder_filename: str,
) -> None:
"""Export the given encoder model with torch.jit.trace()
@ -204,8 +203,7 @@ def export_encoder_model_jit_trace(
def export_decoder_model_jit_trace(
decoder_model: nn.Module,
decoder_filename: str,
decoder_model: nn.Module, decoder_filename: str,
) -> None:
"""Export the given decoder model with torch.jit.trace()
@ -226,8 +224,7 @@ def export_decoder_model_jit_trace(
def export_joiner_model_jit_trace(
joiner_model: nn.Module,
joiner_filename: str,
joiner_model: nn.Module, joiner_filename: str,
) -> None:
"""Export the given joiner model with torch.jit.trace()