Update export-onnx.py

updated `export-onnx.py` to accept `tokens.txt` for `blank_id` and `vocab_size`
jinzr 2023-07-04 12:03:57 +08:00
parent 67acaf9431
commit 95ec9efcbd
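
For context: a `tokens.txt` file maps each token symbol to its integer ID, one pair per line. The sketch below is illustrative only — the exact entries depend on the trained BPE vocabulary; in these recipes `<blk>` is pinned to ID 0 (see `local/train_bpe_model.py`) and the trailing `#0`, `#1`, ... are disambiguation symbols:

<blk> 0
<sos/eos> 1
<unk> 2
▁THE 3
...
#0 500
#1 501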

@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
-# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
+# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                            Zengrui Jin)

 """
 This script exports a transducer model from PyTorch to ONNX.
@@ -28,7 +29,7 @@ popd
 2. Export the model to ONNX

 ./pruned_transducer_stateless7/export-onnx.py \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
   --use-averaged-model 0 \
   --epoch 99 \
   --avg 1 \
@@ -49,9 +50,10 @@ import argparse
 import logging
 from pathlib import Path
 from typing import Dict, Tuple
+import re

+import k2
 import onnx
-import sentencepiece as spm
 import torch
 import torch.nn as nn
 from decoder import Decoder
@@ -123,10 +125,9 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        help="Path to the tokens.txt.",
     )

     parser.add_argument(
@@ -159,6 +160,26 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
     onnx.save(model, filename)


+def num_tokens(
+    token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
+) -> int:
+    """Return the number of tokens excluding those from
+    disambiguation symbols.
+
+    Caution:
+      0 is not a token ID so it is excluded from the return value.
+    """
+    symbols = token_table.symbols
+    ans = []
+    for s in symbols:
+        if not disambig_pattern.match(s):
+            ans.append(token_table[s])
+    num_tokens = len(ans)
+    if 0 in ans:
+        num_tokens -= 1
+    return num_tokens
+
+
 class OnnxEncoder(nn.Module):
     """A wrapper for Zipformer and the encoder_proj from the joiner"""
@@ -411,12 +432,12 @@ def main():
     logging.info(f"device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)

-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    # Load id of the <blk> token and the vocab size
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(params)
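
Because `tokens.txt` is generated from `bpe.model`, the new lookup should reproduce what the removed SentencePiece calls returned. A hedged sanity-check sketch (paths follow the usage example above, and `num_tokens()` is the function added in this commit):

import k2
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")

token_table = k2.SymbolTable.from_file("data/lang_bpe_500/tokens.txt")

# <blk> keeps ID 0 in both representations.
assert token_table["<blk>"] == sp.piece_to_id("<blk>")

# num_tokens() drops the disambiguation symbols (#0, #1, ...) and ID 0,
# so adding 1 back for <blk> should match the SentencePiece vocab size.
assert num_tokens(token_table) + 1 == sp.get_piece_size()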