From bf6fb9f0e2c638039d498ac0fcbe581e3658e255 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Fri, 11 Aug 2023 21:05:36 +0800 Subject: [PATCH] minor fixes --- ...n-librispeech-conformer-ctc3-2022-11-28.sh | 2 +- .github/scripts/test-ncnn-export.sh | 2 +- .../export.py | 20 ++++++++++--------- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh index 32223716d..f6fe8c9b2 100755 --- a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh +++ b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh @@ -40,7 +40,7 @@ for m in ctc-decoding 1best; do --model-filename $repo/exp/jit_trace.pt \ --words-file $repo/data/lang_bpe_500/words.txt \ --HLG $repo/data/lang_bpe_500/HLG.pt \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ --G $repo/data/lm/G_4_gram.pt \ --method $m \ --sample-rate 16000 \ diff --git a/.github/scripts/test-ncnn-export.sh b/.github/scripts/test-ncnn-export.sh index ac16131d0..5aaf5d244 100755 --- a/.github/scripts/test-ncnn-export.sh +++ b/.github/scripts/test-ncnn-export.sh @@ -60,7 +60,7 @@ log "Export via torch.jit.trace()" --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ - \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --num-encoder-layers 12 \ --chunk-length 32 \ --cnn-module-kernel 31 \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py index c191b5bcc..59a7eb589 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ -139,8 +139,8 @@ import argparse import logging from pathlib import Path +import k2 import onnxruntime -import sentencepiece as spm import torch import torch.nn as nn from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner @@ -154,7 +154,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -211,10 +211,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -675,12 +675,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params)