From 0945c5b3793f243678f4be4e51864ec96894e1f7 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 11 Oct 2023 11:26:52 +0800
Subject: [PATCH] use tokens.txt to replace bpe.model

---
 .../export-onnx.py | 21 +++++++++------------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export-onnx.py
index cf3bd78ba..3fef231a1 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/export-onnx.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export-onnx.py
@@ -18,7 +18,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
 repo=$(basename $repo_url)
 
 pushd $repo
-git lfs pull --include "data/lang_bpe_500/bpe.model"
 git lfs pull --include "exp/pretrained.pt"
 
 cd exp
@@ -28,7 +27,7 @@ popd
 2. Export the model to ONNX
 
 ./pruned_transducer_stateless8/export-onnx.py \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
   --use-averaged-model 0 \
   --epoch 99 \
   --avg 1 \
@@ -50,8 +49,8 @@ import logging
 from pathlib import Path
 from typing import Dict, Tuple
 
+import k2
 import onnx
-import sentencepiece as spm
 import torch
 import torch.nn as nn
 from decoder import Decoder
@@ -66,7 +65,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.utils import setup_logger, str2bool
+from icefall.utils import num_tokens, setup_logger, str2bool
 
 
 def get_parser():
@@ -123,10 +122,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to the tokens.txt",
     )
 
     parser.add_argument(
@@ -412,12 +411,10 @@ def main():
 
     logging.info(f"device: {device}")
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1
 
     logging.info(params)