mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 23:24:17 +00:00
fix exporting tal_csasr recipe
This commit is contained in:
parent
7cf3ea8b33
commit
e1880b7413
@ -23,7 +23,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./pruned_transducer_stateless5/export.py \
|
./pruned_transducer_stateless5/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||||
--lang-dir ./data/lang_char \
|
--tokens ./data/lang_char/tokens.txt \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 24 \
|
--avg 24 \
|
||||||
--use-averaged-model True
|
--use-averaged-model True
|
||||||
@ -50,8 +50,9 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
@ -60,8 +61,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.utils import num_tokens, str2bool
|
||||||
from icefall.utils import str2bool
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -118,13 +118,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang-dir",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_char",
|
default="data/lang_char/tokens.txt",
|
||||||
help="""The lang dir
|
help="Path to the tokens.txt.",
|
||||||
It contains language related input files such as
|
|
||||||
"lexicon.txt"
|
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -160,13 +157,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
bpe_model = params.lang_dir + "/bpe.model"
|
# Load tokens.txt here
|
||||||
sp = spm.SentencePieceProcessor()
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
sp.load(bpe_model)
|
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
# Load id of the <blk> token and the vocab size
|
||||||
params.blank_id = lexicon.token_table["<blk>"]
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
params.blank_id = token_table["<blk>"]
|
||||||
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
@ -256,6 +254,7 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
# We won't use the forward() method of the model in C++, so just ignore
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
# it here.
|
# it here.
|
||||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
1
egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py
Symbolic link
1
egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
|
Loading…
x
Reference in New Issue
Block a user