From 95ec9efcbd193f0e70b392ca195bf6cec136554b Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Tue, 4 Jul 2023 12:03:57 +0800 Subject: [PATCH] Update export-onnx.py updated `export-oonx.py` to accept `tokens.txt` for blank_id and vocab_size --- .../export-onnx.py | 43 ++++++++++++++----- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py index d2db92820..46924a726 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang +# Zengrui Jin) """ This script exports a transducer model from PyTorch to ONNX. @@ -28,7 +29,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless7/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -49,9 +50,10 @@ import argparse import logging from pathlib import Path from typing import Dict, Tuple +import re +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder @@ -123,10 +125,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -159,6 +160,26 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]): onnx.save(model, filename) +def num_tokens( + token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$") +) -> int: + """Return the number of tokens excluding those from + disambiguation symbols. + + Caution: + 0 is not a token ID so it is excluded from the return value. + """ + symbols = token_table.symbols + ans = [] + for s in symbols: + if not disambig_pattern.match(s): + ans.append(token_table[s]) + num_tokens = len(ans) + if 0 in ans: + num_tokens -= 1 + return num_tokens + + class OnnxEncoder(nn.Module): """A wrapper for Zipformer and the encoder_proj from the joiner""" @@ -411,12 +432,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params)