mirror of https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00

Fix export

commit ae47b739f0
parent 63e53bad59
@@ -57,9 +57,9 @@ whose value is "64,128,256,-1".
 
 It will generate the following 3 files inside $repo/exp:
 
-  - encoder-epoch-99-avg-1.onnx
-  - decoder-epoch-99-avg-1.onnx
-  - joiner-epoch-99-avg-1.onnx
+  - encoder-epoch-99-avg-1-chunk-16-left-64.onnx
+  - decoder-epoch-99-avg-1-chunk-16-left-64.onnx
+  - joiner-epoch-99-avg-1-chunk-16-left-64.onnx
 
 See ./onnx_pretrained-streaming.py for how to use the exported ONNX models.
 """
@@ -74,6 +74,7 @@ import onnx
 import torch
 import torch.nn as nn
 from decoder import Decoder
+from export import num_tokens
 from onnxruntime.quantization import QuantType, quantize_dynamic
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_model, get_params
@@ -585,9 +586,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    symbol_table = k2.SymbolTable.from_file(params.tokens)
-    params.blank_id = symbol_table["<blk>"]
-    params.vocab_size = len(symbol_table)
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1
 
     logging.info(params)
 
@@ -706,6 +707,8 @@ def main():
     suffix = f"epoch-{params.epoch}"
 
     suffix += f"-avg-{params.avg}"
+    suffix += f"-chunk-{params.chunk_size}"
+    suffix += f"-left-{params.left_context_frames}"
 
     opset_version = 13
 
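
The two added suffix lines are what put the new "-chunk-16-left-64" part into the
file names promised by the docstring. A standalone sketch of that logic, assuming
the docstring's values (epoch 99, avg 1, chunk size 16, left context 64):

    # Sketch only: the values are assumed from the docstring above,
    # not read from command-line arguments.
    epoch, avg, chunk_size, left_context_frames = 99, 1, 16, 64

    suffix = f"epoch-{epoch}"
    suffix += f"-avg-{avg}"
    suffix += f"-chunk-{chunk_size}"
    suffix += f"-left-{left_context_frames}"

    print(f"encoder-{suffix}.onnx")
    # -> encoder-epoch-99-avg-1-chunk-16-left-64.onnx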

@@ -71,6 +71,7 @@ import onnx
 import torch
 import torch.nn as nn
 from decoder import Decoder
+from export import num_tokens
 from onnxruntime.quantization import QuantType, quantize_dynamic
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_model, get_params
@@ -433,9 +434,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    symbol_table = k2.SymbolTable.from_file(params.tokens)
-    params.blank_id = symbol_table["<blk>"]
-    params.vocab_size = len(symbol_table)
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1
 
     logging.info(params)
 

@@ -160,6 +160,7 @@ with the following commands:
 
 import argparse
 import logging
+import re
 from pathlib import Path
 from typing import List, Tuple
 
@@ -178,6 +179,26 @@ from icefall.checkpoint import (
 from icefall.utils import make_pad_mask, str2bool
 
 
+def num_tokens(
+    token_table: k2.SymbolTable, disambig_pattern: re.Pattern = re.compile(r"^#\d+$")
+) -> int:
+    """Return the number of tokens excluding those from
+    disambiguation symbols.
+
+    Caution:
+      0 is not a token ID so it is excluded from the return value.
+    """
+    symbols = token_table.symbols
+    ans = []
+    for s in symbols:
+        if not disambig_pattern.match(s):
+            ans.append(token_table[s])
+    num_tokens = len(ans)
+    if 0 in ans:
+        num_tokens -= 1
+    return num_tokens
+
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -402,9 +423,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    symbol_table = k2.SymbolTable.from_file(params.tokens)
-    params.blank_id = symbol_table["<blk>"]
-    params.vocab_size = len(symbol_table)
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1
 
     logging.info(params)
 
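
The num_tokens() helper added above is what the export scripts now build
params.vocab_size from: it counts tokens that are neither disambiguation symbols
(#0, #1, ...) nor ID 0, and the callers add 1 to account for <blk> at ID 0. A toy
sanity check with a made-up tokens table (not from any real recipe):

    import re

    import k2

    # Hypothetical table: <blk> at ID 0 and one disambiguation symbol.
    token_table = k2.SymbolTable.from_str("<blk> 0\na 1\nb 2\n#0 3")

    disambig = re.compile(r"^#\d+$")
    ids = [token_table[s] for s in token_table.symbols if not disambig.match(s)]
    n = len(ids) - (1 if 0 in ids else 0)  # mirrors num_tokens()
    print(n)      # 2: only "a" and "b" count
    print(n + 1)  # 3: the value assigned to params.vocab_size

With the old len(symbol_table) this toy table would have given 4, wrongly counting
the disambiguation symbol in the vocabulary size.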

@@ -256,12 +256,12 @@ def main():
 
     s = "\n"
 
-    symbol_table = k2.SymbolTable.from_file(args.tokens)
+    token_table = k2.SymbolTable.from_file(args.tokens)
 
     def token_ids_to_words(token_ids: List[int]) -> str:
         text = ""
         for i in token_ids:
-            text += symbol_table[i]
+            text += token_table[i]
         return text.replace("▁", " ").strip()
 
     for filename, hyp in zip(args.sound_files, hyps):
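
token_ids_to_words() itself is unchanged by the rename. It leans on the BPE
convention that "▁" marks the start of a word, so concatenating pieces and then
replacing "▁" with a space recovers the text. A hypothetical run with made-up
pieces standing in for the token table:

    # Made-up id -> piece mapping; real IDs come from the tokens file.
    pieces = {1: "▁HELLO", 2: "▁WOR", 3: "LD"}

    def token_ids_to_words(token_ids):
        text = ""
        for i in token_ids:
            text += pieces[i]
        return text.replace("▁", " ").strip()

    print(token_ids_to_words([1, 2, 3]))  # HELLO WORLD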

@@ -190,7 +190,7 @@ def main():
     decoder = model.decoder
     joiner = model.joiner
 
-    symbol_table = k2.SymbolTable.from_file(args.tokens)
+    token_table = k2.SymbolTable.from_file(args.tokens)
     context_size = decoder.context_size
 
     logging.info("Constructing Fbank computer")
@@ -252,7 +252,7 @@ def main():
 
     text = ""
     for i in hyp[context_size:]:
-        text += symbol_table[i]
+        text += token_table[i]
     text = text.replace("▁", " ").strip()
 
     logging.info(args.sound_file)

@@ -524,11 +524,11 @@ def main():
         hyp,
     )
 
-    symbol_table = k2.SymbolTable.from_file(args.tokens)
+    token_table = k2.SymbolTable.from_file(args.tokens)
 
     text = ""
     for i in hyp[context_size:]:
-        text += symbol_table[i]
+        text += token_table[i]
     text = text.replace("▁", " ").strip()
 
     logging.info(args.sound_file)

@@ -396,12 +396,12 @@ def main():
     )
     s = "\n"
 
-    symbol_table = k2.SymbolTable.from_file(args.tokens)
+    token_table = k2.SymbolTable.from_file(args.tokens)
 
     def token_ids_to_words(token_ids: List[int]) -> str:
         text = ""
         for i in token_ids:
-            text += symbol_table[i]
+            text += token_table[i]
         return text.replace("▁", " ").strip()
 
     for filename, hyp in zip(args.sound_files, hyps):

@@ -122,6 +122,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
+from export import num_tokens
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_model, get_params
 
@@ -262,11 +263,11 @@ def main():
 
     params.update(vars(args))
 
-    symbol_table = k2.SymbolTable.from_file(params.tokens)
+    token_table = k2.SymbolTable.from_file(params.tokens)
 
-    params.blank_id = symbol_table["<blk>"]
-    params.unk_id = symbol_table["<unk>"]
-    params.vocab_size = len(symbol_table)
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+    params.vocab_size = num_tokens(token_table) + 1
 
     logging.info(f"{params}")
 
@@ -328,7 +329,7 @@ def main():
     def token_ids_to_words(token_ids: List[int]) -> str:
         text = ""
         for i in token_ids:
-            text += symbol_table[i]
+            text += token_table[i]
         return text.replace("▁", " ").strip()
 
     if params.method == "fast_beam_search":
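
All of these lookups ride on k2.SymbolTable's two-way indexing: a string key
returns a token ID, an integer key returns the symbol. A small sketch with a toy
table (the entries are illustrative only):

    import k2

    token_table = k2.SymbolTable.from_str("<blk> 0\n<unk> 1\n▁HI 2")

    print(token_table["<blk>"])  # 0    (str -> id, as in params.blank_id)
    print(token_table["<unk>"])  # 1    (str -> id, as in params.unk_id)
    print(token_table[2])        # ▁HI  (id -> str, as in token_ids_to_words)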

@@ -133,7 +133,7 @@ def get_parser():
 
 
 def decode_one_batch(
-    model: OnnxModel, symbol_table: k2.SymbolTable, batch: dict
+    model: OnnxModel, token_table: k2.SymbolTable, batch: dict
 ) -> List[List[str]]:
     """Decode one batch and return the result.
     Currently it only greedy_search is supported.
@@ -141,7 +141,7 @@ def decode_one_batch(
     Args:
       model:
        The neural model.
-      symbol_table:
+      token_table:
        Mapping ids to tokens.
      batch:
        It is the return value from iterating
@@ -164,14 +164,14 @@ def decode_one_batch(
         model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
     )
 
-    hyps = [[symbol_table[h] for h in hyp] for hyp in hyps]
+    hyps = [[token_table[h] for h in hyp] for hyp in hyps]
     return hyps
 
 
 def decode_dataset(
     dl: torch.utils.data.DataLoader,
     model: nn.Module,
-    symbol_table: k2.SymbolTable,
+    token_table: k2.SymbolTable,
 ) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
     """Decode dataset.
 
@@ -180,7 +180,7 @@ def decode_dataset(
        PyTorch's dataloader containing the dataset to decode.
      model:
        The neural model.
-      symbol_table:
+      token_table:
        Mapping ids to tokens.
 
     Returns:
@@ -206,7 +206,7 @@ def decode_dataset(
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
         total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
 
-        hyps = decode_one_batch(model=model, symbol_table=symbol_table, batch=batch)
+        hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)
 
         this_batch = []
         assert len(hyps) == len(texts)
@@ -270,8 +270,8 @@ def main():
     device = torch.device("cpu")
     logging.info(f"Device: {device}")
 
-    symbol_table = k2.SymbolTable.from_file(args.tokens)
-    assert symbol_table[0] == "<blk>"
+    token_table = k2.SymbolTable.from_file(args.tokens)
+    assert token_table[0] == "<blk>"
 
     logging.info(vars(args))
 
@@ -313,7 +313,7 @@ def main():
     for test_set, test_dl in zip(test_sets, test_dl):
         start_time = time.time()
         results, total_duration = decode_dataset(
-            dl=test_dl, model=model, symbol_table=symbol_table
+            dl=test_dl, model=model, token_table=token_table
         )
         end_time = time.time()
         elapsed_seconds = end_time - start_time