updated the pruned_stateless_emformer_rnnt2 recipe

This commit is contained in:
jinzr 2023-07-23 01:00:34 +08:00
parent d6f4805226
commit 8dcb6da8c7

View File

@ -22,7 +22,7 @@
Usage: Usage:
./pruned_stateless_emformer_rnnt/export.py \ ./pruned_stateless_emformer_rnnt/export.py \
--exp-dir ./pruned_stateless_emformer_rnnt/exp \ --exp-dir ./pruned_stateless_emformer_rnnt/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@ -48,7 +48,7 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import sentencepiece as spm import k2
import torch import torch
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
@ -58,7 +58,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import str2bool from icefall.utils import num_tokens, str2bool
def get_parser(): def get_parser():
@ -115,10 +115,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -154,13 +154,12 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> and <unk> are defined in local/train_bpe_model.py # Load id of the <blk> token and the vocab size
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.unk_id = sp.piece_to_id("<unk>") params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
params.vocab_size = sp.get_piece_size()
logging.info(params) logging.info(params)