fix export for aidatatang_200zh

This commit is contained in:
Fangjun Kuang 2024-01-26 12:17:48 +08:00
parent 71ee509e7d
commit 7cf3ea8b33
3 changed files with 18 additions and 9 deletions

View File

@ -1,3 +1,4 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
@ -20,7 +21,7 @@
Usage: Usage:
./pruned_transducer_stateless2/export.py \ ./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \ --tokens data/lang_char/tokens.txt \
--epoch 29 \ --epoch 29 \
--avg 19 --avg 19
@ -45,12 +46,13 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import k2
import torch import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon from icefall.utils import num_tokens, str2bool
from icefall.utils import str2bool
def get_parser(): def get_parser():
@ -85,10 +87,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--lang-dir", "--tokens",
type=str, type=str,
default="data/lang_char", default="data/lang_char/tokens.txt",
help="The lang dir", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -122,10 +124,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir) # Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = 0 # Load id of the <blk> token and the vocab size
params.vocab_size = max(lexicon.tokens) + 1 # <blk> is defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params) logging.info(params)
@ -152,6 +158,7 @@ def main():
model.eval() model.eval()
if params.jit: if params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore # We won't use the forward() method of the model in C++, so just ignore
# it here. # it here.
# Otherwise, one of its arguments is a ragged tensor and is not # Otherwise, one of its arguments is a ragged tensor and is not

View File

@ -0,0 +1 @@
../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py