Update export-onnx.py
Updated `export-onnx.py` to accept `tokens.txt` for obtaining `blank_id` and `vocab_size`.
This commit is contained in: parent 67acaf9431, commit 95ec9efcbd
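In short, the export script now derives blank_id and vocab_size from a k2 symbol table (tokens.txt) instead of a SentencePiece model. Below is a minimal, self-contained sketch of the new logic; the path is illustrative, and the vocab count mirrors the num_tokens() helper added later in this diff:

import re

import k2

# tokens.txt uses the usual k2/Kaldi symbol-table layout: one "<symbol> <id>" pair per line.
token_table = k2.SymbolTable.from_file("data/lang_bpe_500/tokens.txt")

blank_id = token_table["<blk>"]  # usually 0

# Count tokens while skipping disambiguation symbols (#0, #1, ...) and id 0,
# then add 1 back for <blk>, matching num_tokens(token_table) + 1 in the patch.
disambig = re.compile(r"^#\d+$")
ids = [token_table[s] for s in token_table.symbols if not disambig.match(s)]
vocab_size = len(ids) - (1 if 0 in ids else 0) + 1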
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
-# Copyright    2023  Xiaomi Corporation (Author: Fangjun Kuang)
+# Copyright    2023  Xiaomi Corporation (Author: Fangjun Kuang,
+#                                                Zengrui Jin)
 
 """
 This script exports a transducer model from PyTorch to ONNX.
@@ -28,7 +29,7 @@ popd
 2. Export the model to ONNX
 
 ./pruned_transducer_stateless7/export-onnx.py \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
   --use-averaged-model 0 \
   --epoch 99 \
   --avg 1 \
@@ -49,9 +50,10 @@ import argparse
 import logging
 from pathlib import Path
 from typing import Dict, Tuple
+import re
 
+import k2
 import onnx
-import sentencepiece as spm
 import torch
 import torch.nn as nn
 from decoder import Decoder
@@ -123,10 +125,9 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        help="Path to the tokens.txt.",
     )
 
     parser.add_argument(
@@ -159,6 +160,26 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
     onnx.save(model, filename)
 
 
+def num_tokens(
+    token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
+) -> int:
+    """Return the number of tokens excluding those from
+    disambiguation symbols.
+
+    Caution:
+      0 is not a token ID so it is excluded from the return value.
+    """
+    symbols = token_table.symbols
+    ans = []
+    for s in symbols:
+        if not disambig_pattern.match(s):
+            ans.append(token_table[s])
+    num_tokens = len(ans)
+    if 0 in ans:
+        num_tokens -= 1
+    return num_tokens
+
+
 class OnnxEncoder(nn.Module):
     """A wrapper for Zipformer and the encoder_proj from the joiner"""
 
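As a quick sanity check of the helper added above, here is what it computes on a toy symbol table; the table contents are made up for illustration:

import re

import k2

# Toy table: <blk>, two real tokens, and one disambiguation symbol.
table = k2.SymbolTable.from_str("<blk> 0\nA 1\nB 2\n#0 3")

disambig = re.compile(r"^#\d+$")
ids = [table[s] for s in table.symbols if not disambig.match(s)]
# "#0" is skipped and id 0 (<blk>) is dropped, so num_tokens(table) returns 2;
# main() then uses num_tokens(table) + 1 == 3 as the vocab size.
print(len(ids) - (1 if 0 in ids else 0))  # prints 2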
@@ -411,12 +432,12 @@ def main():
 
     logging.info(f"device: {device}")
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)
 
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    # Load id of the <blk> token and the vocab size
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>
 
     logging.info(params)
 