From 9644c1722a3d812701fa6e10535cc63469d488e3 Mon Sep 17 00:00:00 2001 From: jinzr Date: Sat, 27 Jan 2024 03:21:26 +0800 Subject: [PATCH] minor fixes --- .../export-onnx-zh.py | 18 +++++++++--------- egs/swbd/ASR/conformer_ctc/export.py | 2 +- egs/tedlium3/ASR/conformer_ctc2/export.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py index 2a52e2eec..1ce770128 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./lstm_transducer_stateless2/export-onnx-zh.py \ - --lang-dir ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char \ + --tokens ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char/tokens.txt \ --use-averaged-model 1 \ --epoch 11 \ --avg 1 \ @@ -55,6 +55,7 @@ import logging from pathlib import Path from typing import Dict, Optional, Tuple +import k2 import onnx import torch import torch.nn as nn @@ -70,8 +71,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -128,10 +128,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -441,9 +441,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) diff --git a/egs/swbd/ASR/conformer_ctc/export.py b/egs/swbd/ASR/conformer_ctc/export.py index 7df5a8bfa..44b2e95d6 100755 --- a/egs/swbd/ASR/conformer_ctc/export.py +++ b/egs/swbd/ASR/conformer_ctc/export.py @@ -118,7 +118,7 @@ def main(): num_features=params.feature_dim, nhead=params.nhead, d_model=params.attention_dim, - num_classes=num_classes, + num_classes=params.vocab_size, subsampling_factor=params.subsampling_factor, num_decoder_layers=params.num_decoder_layers, vgg_frontend=False, diff --git a/egs/tedlium3/ASR/conformer_ctc2/export.py b/egs/tedlium3/ASR/conformer_ctc2/export.py index 13188fca1..b5bf911c2 100755 --- a/egs/tedlium3/ASR/conformer_ctc2/export.py +++ b/egs/tedlium3/ASR/conformer_ctc2/export.py @@ -182,7 +182,7 @@ def main(): model = Conformer( num_features=params.feature_dim, - num_classes=num_classes, + num_classes=params.vocab_size, subsampling_factor=params.subsampling_factor, d_model=params.dim_model, nhead=params.nhead,