Fix export

This commit is contained in:
pkufool 2023-06-23 17:51:50 +08:00
parent 63e53bad59
commit ae47b739f0
9 changed files with 60 additions and 34 deletions

View File

@ -57,9 +57,9 @@ whose value is "64,128,256,-1".
It will generate the following 3 files inside $repo/exp: It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx - encoder-epoch-99-avg-1-chunk-16-left-64.onnx
- decoder-epoch-99-avg-1.onnx - decoder-epoch-99-avg-1-chunk-16-left-64.onnx
- joiner-epoch-99-avg-1.onnx - joiner-epoch-99-avg-1-chunk-16-left-64.onnx
See ./onnx_pretrained-streaming.py for how to use the exported ONNX models. See ./onnx_pretrained-streaming.py for how to use the exported ONNX models.
""" """
@ -74,6 +74,7 @@ import onnx
import torch import torch
import torch.nn as nn import torch.nn as nn
from decoder import Decoder from decoder import Decoder
from export import num_tokens
from onnxruntime.quantization import QuantType, quantize_dynamic from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params from train import add_model_arguments, get_model, get_params
@ -585,9 +586,9 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
symbol_table = k2.SymbolTable.from_file(params.tokens) token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = symbol_table["<blk>"] params.blank_id = token_table["<blk>"]
params.vocab_size = len(symbol_table) params.vocab_size = num_tokens(token_table) + 1
logging.info(params) logging.info(params)
@ -706,6 +707,8 @@ def main():
suffix = f"epoch-{params.epoch}" suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}" suffix += f"-avg-{params.avg}"
suffix += f"-chunk-{params.chunk_size}"
suffix += f"-left-{params.left_context_frames}"
opset_version = 13 opset_version = 13

View File

@ -71,6 +71,7 @@ import onnx
import torch import torch
import torch.nn as nn import torch.nn as nn
from decoder import Decoder from decoder import Decoder
from export import num_tokens
from onnxruntime.quantization import QuantType, quantize_dynamic from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params from train import add_model_arguments, get_model, get_params
@ -433,9 +434,9 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
symbol_table = k2.SymbolTable.from_file(params.tokens) token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = symbol_table["<blk>"] params.blank_id = token_table["<blk>"]
params.vocab_size = len(symbol_table) params.vocab_size = num_tokens(token_table) + 1
logging.info(params) logging.info(params)

View File

@ -160,6 +160,7 @@ with the following commands:
import argparse import argparse
import logging import logging
import re
from pathlib import Path from pathlib import Path
from typing import List, Tuple from typing import List, Tuple
@ -178,6 +179,26 @@ from icefall.checkpoint import (
from icefall.utils import make_pad_mask, str2bool from icefall.utils import make_pad_mask, str2bool
def num_tokens(
token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
) -> int:
"""Return the number of tokens excluding those from
disambiguation symbols.
Caution:
0 is not a token ID so it is excluded from the return value.
"""
symbols = token_table.symbols
ans = []
for s in symbols:
if not disambig_pattern.match(s):
ans.append(token_table[s])
num_tokens = len(ans)
if 0 in ans:
num_tokens -= 1
return num_tokens
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -402,9 +423,9 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
symbol_table = k2.SymbolTable.from_file(params.tokens) token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = symbol_table["<blk>"] params.blank_id = token_table["<blk>"]
params.vocab_size = len(symbol_table) params.vocab_size = num_tokens(token_table) + 1
logging.info(params) logging.info(params)

View File

@ -256,12 +256,12 @@ def main():
s = "\n" s = "\n"
symbol_table = k2.SymbolTable.from_file(args.tokens) token_table = k2.SymbolTable.from_file(args.tokens)
def token_ids_to_words(token_ids: List[int]) -> str: def token_ids_to_words(token_ids: List[int]) -> str:
text = "" text = ""
for i in token_ids: for i in token_ids:
text += symbol_table[i] text += token_table[i]
return text.replace("", " ").strip() return text.replace("", " ").strip()
for filename, hyp in zip(args.sound_files, hyps): for filename, hyp in zip(args.sound_files, hyps):

View File

@ -190,7 +190,7 @@ def main():
decoder = model.decoder decoder = model.decoder
joiner = model.joiner joiner = model.joiner
symbol_table = k2.SymbolTable.from_file(args.tokens) token_table = k2.SymbolTable.from_file(args.tokens)
context_size = decoder.context_size context_size = decoder.context_size
logging.info("Constructing Fbank computer") logging.info("Constructing Fbank computer")
@ -252,7 +252,7 @@ def main():
text = "" text = ""
for i in hyp[context_size:]: for i in hyp[context_size:]:
text += symbol_table[i] text += token_table[i]
text = text.replace("", " ").strip() text = text.replace("", " ").strip()
logging.info(args.sound_file) logging.info(args.sound_file)

View File

@ -524,11 +524,11 @@ def main():
hyp, hyp,
) )
symbol_table = k2.SymbolTable.from_file(args.tokens) token_table = k2.SymbolTable.from_file(args.tokens)
text = "" text = ""
for i in hyp[context_size:]: for i in hyp[context_size:]:
text += symbol_table[i] text += token_table[i]
text = text.replace("", " ").strip() text = text.replace("", " ").strip()
logging.info(args.sound_file) logging.info(args.sound_file)

View File

@ -396,12 +396,12 @@ def main():
) )
s = "\n" s = "\n"
symbol_table = k2.SymbolTable.from_file(args.tokens) token_table = k2.SymbolTable.from_file(args.tokens)
def token_ids_to_words(token_ids: List[int]) -> str: def token_ids_to_words(token_ids: List[int]) -> str:
text = "" text = ""
for i in token_ids: for i in token_ids:
text += symbol_table[i] text += token_table[i]
return text.replace("", " ").strip() return text.replace("", " ").strip()
for filename, hyp in zip(args.sound_files, hyps): for filename, hyp in zip(args.sound_files, hyps):

View File

@ -122,6 +122,7 @@ from beam_search import (
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
) )
from export import num_tokens
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params from train import add_model_arguments, get_model, get_params
@ -262,11 +263,11 @@ def main():
params.update(vars(args)) params.update(vars(args))
symbol_table = k2.SymbolTable.from_file(params.tokens) token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = symbol_table["<blk>"] params.blank_id = token_table["<blk>"]
params.unk_id = symbol_table["<unk>"] params.unk_id = token_table["<unk>"]
params.vocab_size = len(symbol_table) params.vocab_size = num_tokens(token_table) + 1
logging.info(f"{params}") logging.info(f"{params}")
@ -328,7 +329,7 @@ def main():
def token_ids_to_words(token_ids: List[int]) -> str: def token_ids_to_words(token_ids: List[int]) -> str:
text = "" text = ""
for i in token_ids: for i in token_ids:
text += symbol_table[i] text += token_table[i]
return text.replace("", " ").strip() return text.replace("", " ").strip()
if params.method == "fast_beam_search": if params.method == "fast_beam_search":

View File

@ -133,7 +133,7 @@ def get_parser():
def decode_one_batch( def decode_one_batch(
model: OnnxModel, symbol_table: k2.SymbolTable, batch: dict model: OnnxModel, token_table: k2.SymbolTable, batch: dict
) -> List[List[str]]: ) -> List[List[str]]:
"""Decode one batch and return the result. """Decode one batch and return the result.
Currently it only greedy_search is supported. Currently it only greedy_search is supported.
@ -141,7 +141,7 @@ def decode_one_batch(
Args: Args:
model: model:
The neural model. The neural model.
symbol_table: token_table:
Mapping ids to tokens. Mapping ids to tokens.
batch: batch:
It is the return value from iterating It is the return value from iterating
@ -164,14 +164,14 @@ def decode_one_batch(
model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
) )
hyps = [[symbol_table[h] for h in hyp] for hyp in hyps] hyps = [[token_table[h] for h in hyp] for hyp in hyps]
return hyps return hyps
def decode_dataset( def decode_dataset(
dl: torch.utils.data.DataLoader, dl: torch.utils.data.DataLoader,
model: nn.Module, model: nn.Module,
symbol_table: k2.SymbolTable, token_table: k2.SymbolTable,
) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: ) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
"""Decode dataset. """Decode dataset.
@ -180,7 +180,7 @@ def decode_dataset(
PyTorch's dataloader containing the dataset to decode. PyTorch's dataloader containing the dataset to decode.
model: model:
The neural model. The neural model.
symbol_table: token_table:
Mapping ids to tokens. Mapping ids to tokens.
Returns: Returns:
@ -206,7 +206,7 @@ def decode_dataset(
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
hyps = decode_one_batch(model=model, symbol_table=symbol_table, batch=batch) hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)
this_batch = [] this_batch = []
assert len(hyps) == len(texts) assert len(hyps) == len(texts)
@ -270,8 +270,8 @@ def main():
device = torch.device("cpu") device = torch.device("cpu")
logging.info(f"Device: {device}") logging.info(f"Device: {device}")
symbol_table = k2.SymbolTable.from_file(args.tokens) token_table = k2.SymbolTable.from_file(args.tokens)
assert symbol_table[0] == "<blk>" assert token_table[0] == "<blk>"
logging.info(vars(args)) logging.info(vars(args))
@ -313,7 +313,7 @@ def main():
for test_set, test_dl in zip(test_sets, test_dl): for test_set, test_dl in zip(test_sets, test_dl):
start_time = time.time() start_time = time.time()
results, total_duration = decode_dataset( results, total_duration = decode_dataset(
dl=test_dl, model=model, symbol_table=symbol_table dl=test_dl, model=model, token_table=token_table
) )
end_time = time.time() end_time = time.time()
elapsed_seconds = end_time - start_time elapsed_seconds = end_time - start_time