From ae47b739f0c8091c952df18088d4f055d25ed556 Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 23 Jun 2023 17:51:50 +0800 Subject: [PATCH] Fix export --- .../ASR/zipformer/export-onnx-streaming.py | 15 ++++++----- egs/librispeech/ASR/zipformer/export-onnx.py | 7 ++--- egs/librispeech/ASR/zipformer/export.py | 27 ++++++++++++++++--- .../ASR/zipformer/jit_pretrained.py | 4 +-- .../ASR/zipformer/jit_pretrained_streaming.py | 4 +-- .../zipformer/onnx_pretrained-streaming.py | 4 +-- .../ASR/zipformer/onnx_pretrained.py | 4 +-- egs/librispeech/ASR/zipformer/pretrained.py | 11 ++++---- egs/wenetspeech/ASR/zipformer/onnx_decode.py | 18 ++++++------- 9 files changed, 60 insertions(+), 34 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index a2e82f162..80dc19b37 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -57,9 +57,9 @@ whose value is "64,128,256,-1". It will generate the following 3 files inside $repo/exp: - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx + - encoder-epoch-99-avg-1-chunk-16-left-64.onnx + - decoder-epoch-99-avg-1-chunk-16-left-64.onnx + - joiner-epoch-99-avg-1-chunk-16-left-64.onnx See ./onnx_pretrained-streaming.py for how to use the exported ONNX models. 
""" @@ -74,6 +74,7 @@ import onnx import torch import torch.nn as nn from decoder import Decoder +from export import num_tokens from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params @@ -585,9 +586,9 @@ def main(): logging.info(f"device: {device}") - symbol_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = symbol_table[""] - params.vocab_size = len(symbol_table) + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) @@ -706,6 +707,8 @@ def main(): suffix = f"epoch-{params.epoch}" suffix += f"-avg-{params.avg}" + suffix += f"-chunk-{params.chunk_size}" + suffix += f"-left-{params.left_context_frames}" opset_version = 13 diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index dfbbb0a02..1bc10c896 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -71,6 +71,7 @@ import onnx import torch import torch.nn as nn from decoder import Decoder +from export import num_tokens from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params @@ -433,9 +434,9 @@ def main(): logging.info(f"device: {device}") - symbol_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = symbol_table[""] - params.vocab_size = len(symbol_table) + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/librispeech/ASR/zipformer/export.py b/egs/librispeech/ASR/zipformer/export.py index da1942254..4a48d5bad 100755 --- a/egs/librispeech/ASR/zipformer/export.py +++ 
b/egs/librispeech/ASR/zipformer/export.py @@ -160,6 +160,7 @@ with the following commands: import argparse import logging +import re from pathlib import Path from typing import List, Tuple @@ -178,6 +179,26 @@ from icefall.checkpoint import ( from icefall.utils import make_pad_mask, str2bool +def num_tokens( + token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$") +) -> int: + """Return the number of tokens excluding those from + disambiguation symbols. + + Caution: + 0 is not a token ID so it is excluded from the return value. + """ + symbols = token_table.symbols + ans = [] + for s in symbols: + if not disambig_pattern.match(s): + ans.append(token_table[s]) + num_tokens = len(ans) + if 0 in ans: + num_tokens -= 1 + return num_tokens + + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -402,9 +423,9 @@ def main(): logging.info(f"device: {device}") - symbol_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = symbol_table["<blk>"] - params.vocab_size = len(symbol_table) + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table["<blk>"] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained.py b/egs/librispeech/ASR/zipformer/jit_pretrained.py index 9e280a99f..a41fbc1c9 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained.py @@ -256,12 +256,12 @@ def main(): s = "\n" - symbol_table = k2.SymbolTable.from_file(args.tokens) + token_table = k2.SymbolTable.from_file(args.tokens) def token_ids_to_words(token_ids: List[int]) -> str: text = "" for i in token_ids: - text += symbol_table[i] + text += token_table[i] return text.replace("▁", " ").strip() for filename, hyp in zip(args.sound_files, hyps): diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py b/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py 
index a6822f3d8..d4ceacefd 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py @@ -190,7 +190,7 @@ def main(): decoder = model.decoder joiner = model.joiner - symbol_table = k2.SymbolTable.from_file(args.tokens) + token_table = k2.SymbolTable.from_file(args.tokens) context_size = decoder.context_size logging.info("Constructing Fbank computer") @@ -252,7 +252,7 @@ def main(): text = "" for i in hyp[context_size:]: - text += symbol_table[i] + text += token_table[i] text = text.replace("▁", " ").strip() logging.info(args.sound_file) diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py index 273f883df..2ce4506a8 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py @@ -524,11 +524,11 @@ def main(): hyp, ) - symbol_table = k2.SymbolTable.from_file(args.tokens) + token_table = k2.SymbolTable.from_file(args.tokens) text = "" for i in hyp[context_size:]: - text += symbol_table[i] + text += token_table[i] text = text.replace("▁", " ").strip() logging.info(args.sound_file) diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained.py b/egs/librispeech/ASR/zipformer/onnx_pretrained.py index f5cb8decd..b821c4e19 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained.py @@ -396,12 +396,12 @@ def main(): ) s = "\n" - symbol_table = k2.SymbolTable.from_file(args.tokens) + token_table = k2.SymbolTable.from_file(args.tokens) def token_ids_to_words(token_ids: List[int]) -> str: text = "" for i in token_ids: - text += symbol_table[i] + text += token_table[i] return text.replace("▁", " ").strip() for filename, hyp in zip(args.sound_files, hyps): diff --git a/egs/librispeech/ASR/zipformer/pretrained.py b/egs/librispeech/ASR/zipformer/pretrained.py index 3a119763f..3104b6084 100755 --- 
a/egs/librispeech/ASR/zipformer/pretrained.py +++ b/egs/librispeech/ASR/zipformer/pretrained.py @@ -122,6 +122,7 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) +from export import num_tokens from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_model, get_params @@ -262,11 +263,11 @@ def main(): params.update(vars(args)) - symbol_table = k2.SymbolTable.from_file(params.tokens) + token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = symbol_table["<blk>"] - params.unk_id = symbol_table["<unk>"] - params.vocab_size = len(symbol_table) + params.blank_id = token_table["<blk>"] + params.unk_id = token_table["<unk>"] + params.vocab_size = num_tokens(token_table) + 1 logging.info(f"{params}") @@ -328,7 +329,7 @@ def main(): def token_ids_to_words(token_ids: List[int]) -> str: text = "" for i in token_ids: - text += symbol_table[i] + text += token_table[i] return text.replace("▁", " ").strip() if params.method == "fast_beam_search": diff --git a/egs/wenetspeech/ASR/zipformer/onnx_decode.py b/egs/wenetspeech/ASR/zipformer/onnx_decode.py index 5fcd30147..ed5f6db08 100755 --- a/egs/wenetspeech/ASR/zipformer/onnx_decode.py +++ b/egs/wenetspeech/ASR/zipformer/onnx_decode.py @@ -133,7 +133,7 @@ def get_parser(): def decode_one_batch( - model: OnnxModel, symbol_table: k2.SymbolTable, batch: dict + model: OnnxModel, token_table: k2.SymbolTable, batch: dict ) -> List[List[str]]: """Decode one batch and return the result. Currently it only greedy_search is supported. @@ -141,7 +141,7 @@ Args: model: The neural model. - symbol_table: + token_table: Mapping ids to tokens. 
batch: It is the return value from iterating @@ -164,14 +164,14 @@ def decode_one_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens ) - hyps = [[symbol_table[h] for h in hyp] for hyp in hyps] + hyps = [[token_table[h] for h in hyp] for hyp in hyps] return hyps def decode_dataset( dl: torch.utils.data.DataLoader, model: nn.Module, - symbol_table: k2.SymbolTable, + token_table: k2.SymbolTable, ) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: """Decode dataset. @@ -180,7 +180,7 @@ PyTorch's dataloader containing the dataset to decode. model: The neural model. - symbol_table: + token_table: Mapping ids to tokens. Returns: @@ -206,7 +206,7 @@ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) - hyps = decode_one_batch(model=model, symbol_table=symbol_table, batch=batch) + hyps = decode_one_batch(model=model, token_table=token_table, batch=batch) this_batch = [] assert len(hyps) == len(texts) @@ -270,8 +270,8 @@ def main(): device = torch.device("cpu") logging.info(f"Device: {device}") - symbol_table = k2.SymbolTable.from_file(args.tokens) - assert symbol_table[0] == "<blk>" + token_table = k2.SymbolTable.from_file(args.tokens) + assert token_table[0] == "<blk>" logging.info(vars(args)) @@ -313,7 +313,7 @@ def main(): for test_set, test_dl in zip(test_sets, test_dl): start_time = time.time() results, total_duration = decode_dataset( - dl=test_dl, model=model, symbol_table=symbol_table + dl=test_dl, model=model, token_table=token_table ) end_time = time.time() elapsed_seconds = end_time - start_time