This commit is contained in:
marcoyang 2023-02-13 18:19:52 +08:00
parent a57c54124a
commit 56c2474c0d

View File

@ -79,7 +79,6 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
@ -91,6 +90,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.lexicon import Lexicon
from icefall.utils import str2bool from icefall.utils import str2bool
@ -181,8 +181,7 @@ def get_parser():
def export_encoder_model_jit_trace( def export_encoder_model_jit_trace(
encoder_model: nn.Module, encoder_model: nn.Module, encoder_filename: str,
encoder_filename: str,
) -> None: ) -> None:
"""Export the given encoder model with torch.jit.trace() """Export the given encoder model with torch.jit.trace()
@ -204,8 +203,7 @@ def export_encoder_model_jit_trace(
def export_decoder_model_jit_trace( def export_decoder_model_jit_trace(
decoder_model: nn.Module, decoder_model: nn.Module, decoder_filename: str,
decoder_filename: str,
) -> None: ) -> None:
"""Export the given decoder model with torch.jit.trace() """Export the given decoder model with torch.jit.trace()
@ -226,8 +224,7 @@ def export_decoder_model_jit_trace(
def export_joiner_model_jit_trace( def export_joiner_model_jit_trace(
joiner_model: nn.Module, joiner_model: nn.Module, joiner_filename: str,
joiner_filename: str,
) -> None: ) -> None:
"""Export the given joiner model with torch.jit.trace() """Export the given joiner model with torch.jit.trace()