use tokens.txt to replace bpe.model

This commit is contained in:
Fangjun Kuang 2023-10-11 11:26:52 +08:00
parent af2d83d8d7
commit 0945c5b379

View File

@@ -18,7 +18,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
 repo=$(basename $repo_url)
 pushd $repo
-git lfs pull --include "data/lang_bpe_500/bpe.model"
 git lfs pull --include "exp/pretrained.pt"
 cd exp
@@ -28,7 +27,7 @@ popd
 2. Export the model to ONNX
 ./pruned_transducer_stateless8/export-onnx.py \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
   --use-averaged-model 0 \
   --epoch 99 \
   --avg 1 \
@@ -50,8 +49,8 @@ import logging
 from pathlib import Path
 from typing import Dict, Tuple

+import k2
 import onnx
-import sentencepiece as spm
 import torch
 import torch.nn as nn
 from decoder import Decoder
@ -66,7 +65,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import setup_logger, str2bool from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser(): def get_parser():
@ -123,10 +122,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt",
) )
parser.add_argument( parser.add_argument(
@@ -412,12 +411,10 @@ def main():
     logging.info(f"device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-
     # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1

     logging.info(params)