mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 22:24:19 +00:00
fix export for aidatatang_200zh
This commit is contained in:
parent
71ee509e7d
commit
7cf3ea8b33
25
egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
Normal file → Executable file
25
egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
Normal file → Executable file
@ -1,3 +1,4 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
@ -20,7 +21,7 @@
|
||||
Usage:
|
||||
./pruned_transducer_stateless2/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--lang-dir data/lang_char \
|
||||
--tokens data/lang_char/tokens.txt \
|
||||
--epoch 29 \
|
||||
--avg 19
|
||||
|
||||
@ -45,12 +46,13 @@ import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import str2bool
|
||||
from icefall.utils import num_tokens, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -85,10 +87,10 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
"--tokens",
|
||||
type=str,
|
||||
default="data/lang_char",
|
||||
help="The lang dir",
|
||||
default="data/lang_char/tokens.txt",
|
||||
help="Path to the tokens.txt.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -122,10 +124,14 @@ def main():
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
# Load tokens.txt here
|
||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
# Load id of the <blk> token and the vocab size
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = token_table["<blk>"]
|
||||
params.unk_id = token_table["<unk>"]
|
||||
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||
|
||||
logging.info(params)
|
||||
|
||||
@ -152,6 +158,7 @@ def main():
|
||||
model.eval()
|
||||
|
||||
if params.jit:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
|
1
egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py
Symbolic link
1
egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/lstmp.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
|
Loading…
x
Reference in New Issue
Block a user