Update export-onnx.py

Updated `export-onnx.py` to accept `tokens.txt` for deriving `blank_id` and `vocab_size`.
jinzr 2023-07-04 12:03:57 +08:00
parent 67acaf9431
commit 95ec9efcbd

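For context, `tokens.txt` is a plain-text symbol table with one `<symbol> <id>` pair per line, which is the format `k2.SymbolTable.from_file()` reads. An illustrative excerpt (the inventory below is made up, not taken from this repo):

```
<blk> 0
<sos/eos> 1
<unk> 2
S 3
▁THE 4
...
#0 500
#1 501
```

In icefall BPE recipes `<blk>` has ID 0, which is why the script can read `blank_id` directly from the table; trailing `#N` entries, if present, are disambiguation symbols that the new `num_tokens()` helper below skips.

Changes to `pruned_transducer_stateless7/export-onnx.py`: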

```diff
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
-# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
+# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang
+#                                            Zengrui Jin)
 
 """
 This script exports a transducer model from PyTorch to ONNX.
```
```diff
@@ -28,7 +29,7 @@ popd
 2. Export the model to ONNX
 
 ./pruned_transducer_stateless7/export-onnx.py \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
   --use-averaged-model 0 \
   --epoch 99 \
   --avg 1 \
```
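If you only have the old `bpe.model`, a matching `tokens.txt` can be dumped with sentencepiece. A minimal sketch (paths are illustrative; icefall recipes normally write `tokens.txt` during data preparation):

```python
# Sketch: dump "<symbol> <id>" pairs from an existing BPE model into a
# tokens.txt that export-onnx.py can consume. Paths are illustrative.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")

with open("data/lang_bpe_500/tokens.txt", "w", encoding="utf-8") as f:
    for i in range(sp.get_piece_size()):
        f.write(f"{sp.id_to_piece(i)} {i}\n")
```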
```diff
@@ -49,9 +50,10 @@ import argparse
 import logging
 from pathlib import Path
 from typing import Dict, Tuple
+import re
 
+import k2
 import onnx
-import sentencepiece as spm
 import torch
 import torch.nn as nn
 from decoder import Decoder
```
```diff
@@ -123,10 +125,9 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        help="Path to the tokens.txt.",
     )
 
     parser.add_argument(
```
```diff
@@ -159,6 +160,26 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
     onnx.save(model, filename)
 
 
+def num_tokens(
+    token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
+) -> int:
+    """Return the number of tokens excluding those from
+    disambiguation symbols.
+
+    Caution:
+      0 is not a token ID so it is excluded from the return value.
+    """
+    symbols = token_table.symbols
+    ans = []
+    for s in symbols:
+        if not disambig_pattern.match(s):
+            ans.append(token_table[s])
+    num_tokens = len(ans)
+    if 0 in ans:
+        num_tokens -= 1
+    return num_tokens
+
+
 class OnnxEncoder(nn.Module):
     """A wrapper for Zipformer and the encoder_proj from the joiner"""
```
```diff
@@ -411,12 +432,12 @@ def main():
     logging.info(f"device: {device}")
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)
 
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    # Load id of the <blk> token and the vocab size
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>
 
     logging.info(params)
```
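To confirm the new code path matches the old sentencepiece-based one, a hedged sanity-check sketch (paths are illustrative; assumes `tokens.txt` was generated from the same `bpe.model` with no extra symbols, and that `num_tokens()` from this commit is in scope):

```python
import k2
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")

token_table = k2.SymbolTable.from_file("data/lang_bpe_500/tokens.txt")

# Both paths should agree on the blank ID and the vocabulary size.
assert token_table["<blk>"] == sp.piece_to_id("<blk>")
assert num_tokens(token_table) + 1 == sp.get_piece_size()
```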