From b9bbdfaadc68419c1646b013ac3f132054573029 Mon Sep 17 00:00:00 2001 From: jinzr Date: Sat, 27 Jan 2024 03:18:07 +0800 Subject: [PATCH] Comply to issue #1149 https://github.com/k2-fsa/icefall/issues/1149 --- .../pruned_transducer_stateless2/export.py | 19 ++++++++------- .../pruned_transducer_stateless3/export.py | 19 ++++++++------- .../export-onnx.py | 21 ++++++++--------- .../ASR/transducer_stateless/export.py | 19 ++++++++------- .../transducer_stateless_modified-2/export.py | 18 +++++++-------- .../transducer_stateless_modified/export.py | 19 ++++++++------- .../pruned_transducer_stateless5/export.py | 20 ++++++++-------- .../pruned_transducer_stateless5/export.py | 19 +++++++-------- .../pruned_transducer_stateless2/export.py | 19 ++++++++------- .../pruned_transducer_stateless7/export.py | 23 ++++++++----------- egs/swbd/ASR/conformer_ctc/export.py | 17 +++++++------- egs/tedlium3/ASR/conformer_ctc2/export.py | 16 ++++++------- .../export-onnx-streaming.py | 19 +++++++-------- .../export-onnx.py | 18 +++++++-------- 14 files changed, 126 insertions(+), 140 deletions(-) diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/export.py b/egs/aishell/ASR/pruned_transducer_stateless2/export.py index 2ce5cfe69..c2dc0d5f3 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/export.py @@ -47,12 +47,12 @@ import argparse import logging from pathlib import Path +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -106,10 +106,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help="The lang dir", + "--tokens", + type=str, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -136,10 +136,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py index 723414167..2248c7a08 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py @@ -47,6 +47,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -57,8 +58,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -123,10 +123,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help="The lang dir", + "--tokens", + type=str, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -153,10 +153,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 params.datatang_prob = 0 logging.info(params) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py index 39d988cd0..4981fb71a 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py @@ -49,14 +49,14 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder2 import Decoder +from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled -from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model from zipformer import Zipformer from icefall.checkpoint import ( @@ -65,8 +65,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -123,12 +122,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -404,9 +401,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py index 01de5d772..bfd0ecb0c 100755 --- a/egs/aishell/ASR/transducer_stateless/export.py +++ b/egs/aishell/ASR/transducer_stateless/export.py @@ -23,7 +23,7 @@ Usage: ./transducer_stateless/export.py \ --exp-dir ./transducer_stateless/exp \ - --lang-dir data/lang_char \ + --tokens data/lang_char/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,6 +47,7 @@ import argparse import logging from pathlib import Path +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -56,8 +57,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -92,10 +92,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -192,10 +192,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py index c1081c32b..4f2c71d18 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py @@ -46,6 +46,7 @@ import argparse import logging from pathlib import Path +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -56,7 +57,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -99,10 +100,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help="The lang dir", + "--tokens", + type=str, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -190,10 +191,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py index 3e14ad69c..487748947 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/export.py +++ b/egs/aishell/ASR/transducer_stateless_modified/export.py @@ -46,6 +46,7 @@ import argparse import logging from pathlib import Path +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -55,8 +56,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -99,10 +99,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", - type=Path, - default=Path("data/lang_char"), - help="The lang dir", + "--tokens", + type=str, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -190,10 +190,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py index 8a5be94d0..c92c7ab83 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp \ - --lang-dir data/lang_char + --tokens ./data/lang_char/tokens.txt \ --epoch 25 \ --avg 5 @@ -48,6 +48,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model @@ -57,8 +58,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -115,10 +115,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -154,10 +154,10 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.unk_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py index bf9856c60..246820833 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py @@ -48,6 +48,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model @@ -57,8 +58,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -115,13 +115,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -157,9 +154,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py index 8e5cc6075..5dc73c52b 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py @@ -20,7 +20,7 @@ Usage: ./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ + --tokens ./data/lang_char/tokens.txt \ --epoch 29 \ --avg 18 @@ -45,12 +45,12 @@ import argparse import logging from pathlib import Path +import k2 import torch from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -85,10 +85,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -122,10 +122,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py index 23a88dd29..8bafaef44 100755 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_char/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_char/tokens.txt \ --epoch 20 \ --avg 10 @@ -86,9 +86,8 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch -import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -98,8 +97,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -156,10 +154,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -199,10 +197,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/swbd/ASR/conformer_ctc/export.py b/egs/swbd/ASR/conformer_ctc/export.py index 1bb6277ad..7df5a8bfa 100755 --- a/egs/swbd/ASR/conformer_ctc/export.py +++ b/egs/swbd/ASR/conformer_ctc/export.py @@ -23,12 +23,12 @@ import argparse import logging from pathlib import Path +import k2 import torch from conformer import Conformer from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -63,11 +63,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="""It contains language related input files such as "lexicon.txt" - """, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -105,9 +104,9 @@ def main(): logging.info(params) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 device = torch.device("cpu") if torch.cuda.is_available(): diff --git a/egs/tedlium3/ASR/conformer_ctc2/export.py b/egs/tedlium3/ASR/conformer_ctc2/export.py index 009bea230..13188fca1 100755 --- a/egs/tedlium3/ASR/conformer_ctc2/export.py +++ b/egs/tedlium3/ASR/conformer_ctc2/export.py @@ -45,6 +45,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from conformer import Conformer from scaling_converter import convert_scaled_to_non_scaled @@ -56,8 +57,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser() -> argparse.ArgumentParser: @@ -118,10 +118,10 @@ def get_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="The lang dir", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -166,9 +166,9 @@ def main(): params = get_params() params.update(vars(args)) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 device = torch.device("cpu") if torch.cuda.is_available(): diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py index 921766ad4..30068d01a 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py @@ -58,13 +58,13 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -from icefall.lexicon import Lexicon import torch import torch.nn as nn from conformer import Conformer -from onnxruntime.quantization import QuantType, quantize_dynamic from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -74,7 +74,8 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.lexicon import Lexicon +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -131,10 +132,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -490,9 +491,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py index 037c7adf1..1c9eb8648 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless5/export-onnx.py \ - --lang-dir $repo/data/lang_char \ + --tokens $repo/data/lang_char/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -55,6 +55,7 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx import torch import torch.nn as nn @@ -70,8 +71,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -128,10 +128,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -417,9 +417,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params)