mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
Fix export
This commit is contained in:
  parent 63e53bad59
  commit ae47b739f0
@@ -57,9 +57,9 @@ whose value is "64,128,256,-1".

 It will generate the following 3 files inside $repo/exp:

-  - encoder-epoch-99-avg-1.onnx
-  - decoder-epoch-99-avg-1.onnx
-  - joiner-epoch-99-avg-1.onnx
+  - encoder-epoch-99-avg-1-chunk-16-left-64.onnx
+  - decoder-epoch-99-avg-1-chunk-16-left-64.onnx
+  - joiner-epoch-99-avg-1-chunk-16-left-64.onnx

 See ./onnx_pretrained-streaming.py for how to use the exported ONNX models.
 """
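As a quick sanity check that the renamed files load, the exported models can be opened with onnxruntime before running the full script; a minimal sketch, not part of this commit, assuming the three chunk-16/left-64 files sit in the current directory:

import onnxruntime as ort

opts = ort.SessionOptions()
# The streaming export produces one ONNX model per component.
encoder = ort.InferenceSession("encoder-epoch-99-avg-1-chunk-16-left-64.onnx", sess_options=opts)
decoder = ort.InferenceSession("decoder-epoch-99-avg-1-chunk-16-left-64.onnx", sess_options=opts)
joiner = ort.InferenceSession("joiner-epoch-99-avg-1-chunk-16-left-64.onnx", sess_options=opts)

# Inspect the encoder's I/O contract before wiring up streaming decoding.
for inp in encoder.get_inputs():
    print(inp.name, inp.shape)

./onnx_pretrained-streaming.py remains the reference for the full decoding pipeline.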
@@ -74,6 +74,7 @@ import onnx
 import torch
 import torch.nn as nn
 from decoder import Decoder
+from export import num_tokens
 from onnxruntime.quantization import QuantType, quantize_dynamic
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_model, get_params
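For context on the QuantType/quantize_dynamic import kept above: it is onnxruntime's post-training dynamic quantization API. A hedged sketch of the usual call, with placeholder file names:

from onnxruntime.quantization import QuantType, quantize_dynamic

# Rewrite float32 weights as int8 to shrink the exported model on disk.
quantize_dynamic(
    model_input="encoder-epoch-99-avg-1.onnx",
    model_output="encoder-epoch-99-avg-1.int8.onnx",
    weight_type=QuantType.QInt8,
)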
@@ -585,9 +586,9 @@ def main():

     logging.info(f"device: {device}")

-    symbol_table = k2.SymbolTable.from_file(params.tokens)
-    params.blank_id = symbol_table["<blk>"]
-    params.vocab_size = len(symbol_table)
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1

     logging.info(params)

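This hunk is the substance of the fix. len(symbol_table) counts every entry in tokens.txt, including disambiguation symbols such as #0 and #1 that are not real model outputs, so it can overstate the vocabulary size. num_tokens(token_table) counts only real tokens and excludes id 0, and the + 1 adds the blank back; the helper is added in export.py further down, with a worked example after its definition.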
@@ -706,6 +707,8 @@ def main():
     suffix = f"epoch-{params.epoch}"

     suffix += f"-avg-{params.avg}"
+    suffix += f"-chunk-{params.chunk_size}"
+    suffix += f"-left-{params.left_context_frames}"

     opset_version = 13

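With the defaults from the usage docstring (epoch 99, avg 1, chunk size 16, left context 64), the suffix evaluates to epoch-99-avg-1-chunk-16-left-64, which is exactly how the renamed output files in the first hunk get their names.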
@@ -71,6 +71,7 @@ import onnx
 import torch
 import torch.nn as nn
 from decoder import Decoder
+from export import num_tokens
 from onnxruntime.quantization import QuantType, quantize_dynamic
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_model, get_params
@@ -433,9 +434,9 @@ def main():

     logging.info(f"device: {device}")

-    symbol_table = k2.SymbolTable.from_file(params.tokens)
-    params.blank_id = symbol_table["<blk>"]
-    params.vocab_size = len(symbol_table)
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1

     logging.info(params)

@@ -160,6 +160,7 @@ with the following commands:

 import argparse
 import logging
+import re
 from pathlib import Path
 from typing import List, Tuple

@@ -178,6 +179,26 @@ from icefall.checkpoint import (
 from icefall.utils import make_pad_mask, str2bool


+def num_tokens(
+    token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
+) -> int:
+    """Return the number of tokens excluding those from
+    disambiguation symbols.
+
+    Caution:
+      0 is not a token ID so it is excluded from the return value.
+    """
+    symbols = token_table.symbols
+    ans = []
+    for s in symbols:
+        if not disambig_pattern.match(s):
+            ans.append(token_table[s])
+    num_tokens = len(ans)
+    if 0 in ans:
+        num_tokens -= 1
+    return num_tokens
+
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
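A worked example of the helper, assuming k2 is installed and num_tokens from above is in scope; the five-line token table is made up but mirrors how tokens.txt is laid out, with <blk> at id 0 and disambiguation symbols at the end:

import k2

table = k2.SymbolTable.from_str("<blk> 0\na 1\nb 2\n#0 3\n#1 4")
print(len(table))             # 5: the old vocab_size, counting <blk>, #0 and #1
print(num_tokens(table))      # 2: only "a" and "b" survive the filtering
print(num_tokens(table) + 1)  # 3: adding <blk> back gives the trained output dim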
@@ -402,9 +423,9 @@ def main():

     logging.info(f"device: {device}")

-    symbol_table = k2.SymbolTable.from_file(params.tokens)
-    params.blank_id = symbol_table["<blk>"]
-    params.vocab_size = len(symbol_table)
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1

     logging.info(params)

@@ -256,12 +256,12 @@ def main():

     s = "\n"

-    symbol_table = k2.SymbolTable.from_file(args.tokens)
+    token_table = k2.SymbolTable.from_file(args.tokens)

     def token_ids_to_words(token_ids: List[int]) -> str:
         text = ""
         for i in token_ids:
-            text += symbol_table[i]
+            text += token_table[i]
         return text.replace("▁", " ").strip()

     for filename, hyp in zip(args.sound_files, hyps):
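The replace("▁", " ") above is how BPE pieces become words again; a tiny self-contained illustration with made-up pieces:

# "▁" (U+2581) marks a word boundary in SentencePiece output, so joining
# the pieces and swapping it for a space reconstructs the plain text.
pieces = ["▁HE", "LLO", "▁WOR", "LD"]
print("".join(pieces).replace("▁", " ").strip())  # HELLO WORLD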
@@ -190,7 +190,7 @@ def main():
     decoder = model.decoder
     joiner = model.joiner

-    symbol_table = k2.SymbolTable.from_file(args.tokens)
+    token_table = k2.SymbolTable.from_file(args.tokens)
     context_size = decoder.context_size

     logging.info("Constructing Fbank computer")
@@ -252,7 +252,7 @@ def main():

     text = ""
     for i in hyp[context_size:]:
-        text += symbol_table[i]
+        text += token_table[i]
     text = text.replace("▁", " ").strip()

     logging.info(args.sound_file)
@@ -524,11 +524,11 @@ def main():
         hyp,
     )

-    symbol_table = k2.SymbolTable.from_file(args.tokens)
+    token_table = k2.SymbolTable.from_file(args.tokens)

     text = ""
     for i in hyp[context_size:]:
-        text += symbol_table[i]
+        text += token_table[i]
     text = text.replace("▁", " ").strip()

     logging.info(args.sound_file)
@@ -396,12 +396,12 @@ def main():
     )
     s = "\n"

-    symbol_table = k2.SymbolTable.from_file(args.tokens)
+    token_table = k2.SymbolTable.from_file(args.tokens)

     def token_ids_to_words(token_ids: List[int]) -> str:
         text = ""
         for i in token_ids:
-            text += symbol_table[i]
+            text += token_table[i]
         return text.replace("▁", " ").strip()

     for filename, hyp in zip(args.sound_files, hyps):
@@ -122,6 +122,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
+from export import num_tokens
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_model, get_params

@@ -262,11 +263,11 @@ def main():

     params.update(vars(args))

-    symbol_table = k2.SymbolTable.from_file(params.tokens)
+    token_table = k2.SymbolTable.from_file(params.tokens)

-    params.blank_id = symbol_table["<blk>"]
-    params.unk_id = symbol_table["<unk>"]
-    params.vocab_size = len(symbol_table)
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+    params.vocab_size = num_tokens(token_table) + 1

     logging.info(f"{params}")

@@ -328,7 +329,7 @@ def main():
     def token_ids_to_words(token_ids: List[int]) -> str:
         text = ""
         for i in token_ids:
-            text += symbol_table[i]
+            text += token_table[i]
         return text.replace("▁", " ").strip()

     if params.method == "fast_beam_search":
@@ -133,7 +133,7 @@ def get_parser():


 def decode_one_batch(
-    model: OnnxModel, symbol_table: k2.SymbolTable, batch: dict
+    model: OnnxModel, token_table: k2.SymbolTable, batch: dict
 ) -> List[List[str]]:
     """Decode one batch and return the result.
     Currently it only greedy_search is supported.
@@ -141,7 +141,7 @@ def decode_one_batch(
     Args:
       model:
         The neural model.
-      symbol_table:
+      token_table:
         Mapping ids to tokens.
       batch:
         It is the return value from iterating
@@ -164,14 +164,14 @@ def decode_one_batch(
         model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
     )

-    hyps = [[symbol_table[h] for h in hyp] for hyp in hyps]
+    hyps = [[token_table[h] for h in hyp] for hyp in hyps]
     return hyps


 def decode_dataset(
     dl: torch.utils.data.DataLoader,
     model: nn.Module,
-    symbol_table: k2.SymbolTable,
+    token_table: k2.SymbolTable,
 ) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
     """Decode dataset.

@@ -180,7 +180,7 @@ def decode_dataset(
         PyTorch's dataloader containing the dataset to decode.
       model:
         The neural model.
-      symbol_table:
+      token_table:
         Mapping ids to tokens.

     Returns:
@@ -206,7 +206,7 @@ def decode_dataset(
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])

-        hyps = decode_one_batch(model=model, symbol_table=symbol_table, batch=batch)
+        hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)

         this_batch = []
         assert len(hyps) == len(texts)
@@ -270,8 +270,8 @@ def main():
     device = torch.device("cpu")
     logging.info(f"Device: {device}")

-    symbol_table = k2.SymbolTable.from_file(args.tokens)
-    assert symbol_table[0] == "<blk>"
+    token_table = k2.SymbolTable.from_file(args.tokens)
+    assert token_table[0] == "<blk>"

     logging.info(vars(args))

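The assert keeps the existing convention explicit: tokens.txt must place <blk> at id 0, the same id that num_tokens excludes and that the + 1 in the vocab_size computation adds back.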
@@ -313,7 +313,7 @@ def main():
     for test_set, test_dl in zip(test_sets, test_dl):
         start_time = time.time()
         results, total_duration = decode_dataset(
-            dl=test_dl, model=model, symbol_table=symbol_table
+            dl=test_dl, model=model, token_table=token_table
         )
         end_time = time.time()
         elapsed_seconds = end_time - start_time
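The elapsed time and total audio duration gathered above are commonly reduced to a real-time factor; a hedged sketch of that final step, reusing the variable names from the hunk:

# RTF below 1 means decoding runs faster than the audio plays.
rtf = elapsed_seconds / total_duration
print(f"Elapsed: {elapsed_seconds:.2f} s, audio: {total_duration:.2f} s, RTF: {rtf:.4f}")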