Fix export

pkufool 2023-06-23 17:51:50 +08:00
parent 63e53bad59
commit ae47b739f0
9 changed files with 60 additions and 34 deletions

View File

@@ -57,9 +57,9 @@ whose value is "64,128,256,-1".
 It will generate the following 3 files inside $repo/exp:
 
-  - encoder-epoch-99-avg-1.onnx
-  - decoder-epoch-99-avg-1.onnx
-  - joiner-epoch-99-avg-1.onnx
+  - encoder-epoch-99-avg-1-chunk-16-left-64.onnx
+  - decoder-epoch-99-avg-1-chunk-16-left-64.onnx
+  - joiner-epoch-99-avg-1-chunk-16-left-64.onnx
 
 See ./onnx_pretrained-streaming.py for how to use the exported ONNX models.
 """
@@ -74,6 +74,7 @@ import onnx
 import torch
 import torch.nn as nn
 from decoder import Decoder
+from export import num_tokens
 from onnxruntime.quantization import QuantType, quantize_dynamic
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_model, get_params
@@ -585,9 +586,9 @@ def main():
     logging.info(f"device: {device}")
 
-    symbol_table = k2.SymbolTable.from_file(params.tokens)
-    params.blank_id = symbol_table["<blk>"]
-    params.vocab_size = len(symbol_table)
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1
 
     logging.info(params)
@@ -706,6 +707,8 @@ def main():
     suffix = f"epoch-{params.epoch}"
     suffix += f"-avg-{params.avg}"
+    suffix += f"-chunk-{params.chunk_size}"
+    suffix += f"-left-{params.left_context_frames}"
 
     opset_version = 13

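The two added lines fold the streaming configuration into the exported file names. A minimal sketch of the resulting suffix, using made-up stand-ins for params.epoch, params.avg, params.chunk_size, and params.left_context_frames:

# Hypothetical values; the real ones come from the command-line arguments.
epoch, avg, chunk_size, left_context_frames = 99, 1, 16, 64

suffix = f"epoch-{epoch}"
suffix += f"-avg-{avg}"
suffix += f"-chunk-{chunk_size}"
suffix += f"-left-{left_context_frames}"

# Prints "encoder-epoch-99-avg-1-chunk-16-left-64.onnx",
# matching the file names now listed in the docstring above.
print(f"encoder-{suffix}.onnx")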
View File

@@ -71,6 +71,7 @@ import onnx
 import torch
 import torch.nn as nn
 from decoder import Decoder
+from export import num_tokens
 from onnxruntime.quantization import QuantType, quantize_dynamic
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_model, get_params
@@ -433,9 +434,9 @@ def main():
     logging.info(f"device: {device}")
 
-    symbol_table = k2.SymbolTable.from_file(params.tokens)
-    params.blank_id = symbol_table["<blk>"]
-    params.vocab_size = len(symbol_table)
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1
 
     logging.info(params)

View File

@@ -160,6 +160,7 @@ with the following commands:
 import argparse
 import logging
+import re
 
 from pathlib import Path
 from typing import List, Tuple
@@ -178,6 +179,26 @@ from icefall.checkpoint import (
 from icefall.utils import make_pad_mask, str2bool
 
 
+def num_tokens(
+    token_table: k2.SymbolTable, disambig_pattern: re.Pattern = re.compile(r"^#\d+$")
+) -> int:
+    """Return the number of tokens, excluding those introduced by
+    disambiguation symbols.
+
+    Caution:
+      0 is not a token ID, so it is excluded from the return value.
+    """
+    symbols = token_table.symbols
+    ans = []
+    for s in symbols:
+        if not disambig_pattern.match(s):
+            ans.append(token_table[s])
+    num_tokens = len(ans)
+    if 0 in ans:
+        num_tokens -= 1
+    return num_tokens
+
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -402,9 +423,9 @@ def main():
     logging.info(f"device: {device}")
 
-    symbol_table = k2.SymbolTable.from_file(params.tokens)
-    params.blank_id = symbol_table["<blk>"]
-    params.vocab_size = len(symbol_table)
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1
 
     logging.info(params)

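Why num_tokens(token_table) + 1 instead of len(symbol_table): a tokens.txt produced with a lexicon can contain disambiguation symbols (#0, #1, ...) that the model never predicts, so counting table entries directly would inflate the vocabulary size. A self-contained sketch of the same arithmetic on a made-up token table (assumes k2 is installed; the table contents are hypothetical):

import re

import k2

# Hypothetical tokens.txt: three real tokens plus one disambiguation symbol.
token_table = k2.SymbolTable.from_str("<blk> 0\na 1\nb 2\n#0 3")

# Same logic as num_tokens() above: skip "#0", then exclude token ID 0.
pattern = re.compile(r"^#\d+$")
ids = [token_table[s] for s in token_table.symbols if not pattern.match(s)]
n = len(ids) - (1 if 0 in ids else 0)  # 2

vocab_size = n + 1  # 3, whereas len(token_table.symbols) would give 4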
View File

@@ -256,12 +256,12 @@ def main():
     s = "\n"
 
-    symbol_table = k2.SymbolTable.from_file(args.tokens)
+    token_table = k2.SymbolTable.from_file(args.tokens)
 
     def token_ids_to_words(token_ids: List[int]) -> str:
         text = ""
         for i in token_ids:
-            text += symbol_table[i]
+            text += token_table[i]
         return text.replace("▁", " ").strip()
 
     for filename, hyp in zip(args.sound_files, hyps):

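The helper above works because SentencePiece BPE marks the start of each word with the "▁" character; concatenating the pieces returned by token_table[i] and replacing "▁" with a space recovers the words. A standalone illustration with made-up pieces:

# Hypothetical BPE pieces, as token_table[i] might return them.
pieces = ["▁HEL", "LO", "▁WOR", "LD"]

text = "".join(pieces)  # "▁HELLO▁WORLD"
text = text.replace("▁", " ").strip()
assert text == "HELLO WORLD"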
View File

@@ -190,7 +190,7 @@ def main():
     decoder = model.decoder
     joiner = model.joiner
 
-    symbol_table = k2.SymbolTable.from_file(args.tokens)
+    token_table = k2.SymbolTable.from_file(args.tokens)
     context_size = decoder.context_size
 
     logging.info("Constructing Fbank computer")
@@ -252,7 +252,7 @@ def main():
     text = ""
     for i in hyp[context_size:]:
-        text += symbol_table[i]
+        text += token_table[i]
     text = text.replace("▁", " ").strip()
 
     logging.info(args.sound_file)

View File

@@ -524,11 +524,11 @@ def main():
         hyp,
     )
 
-    symbol_table = k2.SymbolTable.from_file(args.tokens)
+    token_table = k2.SymbolTable.from_file(args.tokens)
 
     text = ""
     for i in hyp[context_size:]:
-        text += symbol_table[i]
+        text += token_table[i]
     text = text.replace("▁", " ").strip()
 
     logging.info(args.sound_file)

View File

@@ -396,12 +396,12 @@ def main():
     )
     s = "\n"
 
-    symbol_table = k2.SymbolTable.from_file(args.tokens)
+    token_table = k2.SymbolTable.from_file(args.tokens)
 
     def token_ids_to_words(token_ids: List[int]) -> str:
         text = ""
         for i in token_ids:
-            text += symbol_table[i]
+            text += token_table[i]
         return text.replace("▁", " ").strip()
 
     for filename, hyp in zip(args.sound_files, hyps):

View File

@@ -122,6 +122,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
+from export import num_tokens
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_model, get_params
@@ -262,11 +263,11 @@ def main():
     params.update(vars(args))
 
-    symbol_table = k2.SymbolTable.from_file(params.tokens)
+    token_table = k2.SymbolTable.from_file(params.tokens)
 
-    params.blank_id = symbol_table["<blk>"]
-    params.unk_id = symbol_table["<unk>"]
-    params.vocab_size = len(symbol_table)
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+    params.vocab_size = num_tokens(token_table) + 1
 
     logging.info(f"{params}")
@@ -328,7 +329,7 @@ def main():
     def token_ids_to_words(token_ids: List[int]) -> str:
         text = ""
         for i in token_ids:
-            text += symbol_table[i]
+            text += token_table[i]
         return text.replace("▁", " ").strip()
 
     if params.method == "fast_beam_search":

View File

@@ -133,7 +133,7 @@ def get_parser():
 def decode_one_batch(
-    model: OnnxModel, symbol_table: k2.SymbolTable, batch: dict
+    model: OnnxModel, token_table: k2.SymbolTable, batch: dict
 ) -> List[List[str]]:
     """Decode one batch and return the result.
 
     Currently only greedy_search is supported.
@@ -141,7 +141,7 @@ def decode_one_batch(
     Args:
       model:
        The neural model.
-      symbol_table:
+      token_table:
        Mapping ids to tokens.
      batch:
        It is the return value from iterating
@@ -164,14 +164,14 @@
         model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
     )
 
-    hyps = [[symbol_table[h] for h in hyp] for hyp in hyps]
+    hyps = [[token_table[h] for h in hyp] for hyp in hyps]
 
     return hyps
 
 
 def decode_dataset(
     dl: torch.utils.data.DataLoader,
     model: nn.Module,
-    symbol_table: k2.SymbolTable,
+    token_table: k2.SymbolTable,
 ) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
     """Decode dataset.
@@ -180,7 +180,7 @@ def decode_dataset(
        PyTorch's dataloader containing the dataset to decode.
      model:
        The neural model.
-      symbol_table:
+      token_table:
        Mapping ids to tokens.
 
     Returns:
@@ -206,7 +206,7 @@
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
 
-        hyps = decode_one_batch(model=model, symbol_table=symbol_table, batch=batch)
+        hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)
 
         this_batch = []
         assert len(hyps) == len(texts)
@@ -270,8 +270,8 @@ def main():
     device = torch.device("cpu")
     logging.info(f"Device: {device}")
 
-    symbol_table = k2.SymbolTable.from_file(args.tokens)
-    assert symbol_table[0] == "<blk>"
+    token_table = k2.SymbolTable.from_file(args.tokens)
+    assert token_table[0] == "<blk>"
 
     logging.info(vars(args))
@@ -313,7 +313,7 @@ def main():
     for test_set, test_dl in zip(test_sets, test_dl):
         start_time = time.time()
         results, total_duration = decode_dataset(
-            dl=test_dl, model=model, symbol_table=symbol_table
+            dl=test_dl, model=model, token_table=token_table
         )
         end_time = time.time()
         elapsed_seconds = end_time - start_time
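decode_dataset returns the total audio duration alongside the hypotheses, so the timing code here can report decoding speed as well as accuracy. Presumably the two measurements combine into a real-time factor along these lines (a sketch with made-up numbers; the RTF calculation and logging call are assumptions, not shown in the hunk):

import logging

# Stand-ins for the measurements taken above.
elapsed_seconds = 120.0  # wall-clock decoding time
total_duration = 3600.0  # seconds of audio decoded

# Real-time factor: compute time per second of audio (lower is faster).
rtf = elapsed_seconds / total_duration  # 0.0333
logging.info(f"elapsed: {elapsed_seconds:.1f}s, RTF: {rtf:.4f}")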