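Minor fixes: read tokens.txt via --tokens instead of --lang-dir in the ONNX export script; pass params.vocab_size as num_classes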

This commit is contained in:
jinzr 2024-01-27 03:21:26 +08:00
parent b9bbdfaadc
commit 9644c1722a
3 changed files with 11 additions and 11 deletions

View File

@ -28,7 +28,7 @@ popd
2. Export the model to ONNX
./lstm_transducer_stateless2/export-onnx-zh.py \
--lang-dir ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char \
--tokens ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char/tokens.txt \
--use-averaged-model 1 \
--epoch 11 \
--avg 1 \
@ -55,6 +55,7 @@ import logging
from pathlib import Path
from typing import Dict, Optional, Tuple
import k2
import onnx
import torch
import torch.nn as nn
@ -70,8 +71,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -128,10 +128,10 @@ def get_parser():
)
parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_char",
help="The lang dir",
default="data/lang_char/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -441,9 +441,9 @@ def main():
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1
logging.info(params)

View File

@ -118,7 +118,7 @@ def main():
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
num_classes=params.vocab_size,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=False,

View File

@ -182,7 +182,7 @@ def main():
model = Conformer(
num_features=params.feature_dim,
num_classes=num_classes,
num_classes=params.vocab_size,
subsampling_factor=params.subsampling_factor,
d_model=params.dim_model,
nhead=params.nhead,