minor fixes

Author: jinzr
Date:   2024-01-27 03:21:26 +08:00
Parent: b9bbdfaadc
Commit: 9644c1722a

3 changed files with 11 additions and 11 deletions

View File

@@ -28,7 +28,7 @@ popd
 2. Export the model to ONNX
 
 ./lstm_transducer_stateless2/export-onnx-zh.py \
-  --lang-dir ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char \
+  --tokens ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char/tokens.txt \
   --use-averaged-model 1 \
   --epoch 11 \
   --avg 1 \
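
For context on the new flag: instead of a lang dir, the exporter now takes the symbol table file directly. tokens.txt is a k2-style symbol table with one "symbol id" pair per line. A minimal sketch of loading it (the path matches the command above; the sample entries in the comment are illustrative, not the actual WenetSpeech inventory):

import k2

# tokens.txt contains one "symbol id" pair per line, e.g. (illustrative):
#   <blk> 0
#   <sos/eos> 1
#   啊 2
token_table = k2.SymbolTable.from_file(
    "./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14"
    "/data/lang_char/tokens.txt"
)
print(token_table["<blk>"])  # integer id of the blank symbol
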
@@ -55,6 +55,7 @@ import logging
 from pathlib import Path
 from typing import Dict, Optional, Tuple
 
+import k2
 import onnx
 import torch
 import torch.nn as nn
@@ -70,8 +71,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.lexicon import Lexicon
-from icefall.utils import setup_logger, str2bool
+from icefall.utils import num_tokens, setup_logger, str2bool
 
 
 def get_parser():
@@ -128,10 +128,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lang-dir",
+        "--tokens",
         type=str,
-        default="data/lang_char",
-        help="The lang dir",
+        default="data/lang_char/tokens.txt",
+        help="Path to the tokens.txt.",
     )
 
     parser.add_argument(
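
The renamed option also changes the default from a directory to a concrete file, so downstream code can open it directly. A self-contained stand-in for the hunk above, runnable on its own (this is not the script's full get_parser, just the one renamed argument):

import argparse

# Stand-in parser carrying only the renamed option from this hunk.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--tokens",
    type=str,
    default="data/lang_char/tokens.txt",
    help="Path to the tokens.txt.",
)

args = parser.parse_args([])  # no flags given: falls back to the default
assert args.tokens == "data/lang_char/tokens.txt"
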
@@ -441,9 +441,9 @@ def main():
     logging.info(f"device: {device}")
 
-    lexicon = Lexicon(params.lang_dir)
-    params.blank_id = 0
-    params.vocab_size = max(lexicon.tokens) + 1
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1
 
     logging.info(params)
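
These three lines move the source of truth for blank_id and vocab_size from the Lexicon to the token table: the blank id is looked up rather than assumed to be 0. A rough equivalent of what the new lines compute, under the assumption that num_tokens from icefall.utils ignores disambiguation symbols such as #0 and that ids are assigned densely from 0 (a hedged sketch, not the library code):

import k2

token_table = k2.SymbolTable.from_file("data/lang_char/tokens.txt")

# The blank id is read from the table instead of being hard-coded to 0.
blank_id = token_table["<blk>"]

# Approximation of num_tokens(token_table) + 1: collect the ids of all
# non-disambiguation symbols (skip #0, #1, ...) and add one so the model's
# output dimension covers the largest token id.
ids = [token_table[s] for s in token_table.symbols if not s.startswith("#")]
vocab_size = max(ids) + 1
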

View File

@@ -118,7 +118,7 @@ def main():
         num_features=params.feature_dim,
         nhead=params.nhead,
         d_model=params.attention_dim,
-        num_classes=num_classes,
+        num_classes=params.vocab_size,
         subsampling_factor=params.subsampling_factor,
         num_decoder_layers=params.num_decoder_layers,
         vgg_frontend=False,
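
This file and the next make the identical substitution: the Conformer's num_classes now comes from params.vocab_size, keeping the output layer in sync with the token table used elsewhere. A toy illustration of the constraint being enforced (all names, shapes, and numbers here are made up):

import torch
import torch.nn as nn

vocab_size = 5537  # hypothetical; in the real scripts it comes from params

# Toy stand-in for the model's final projection: the network must emit one
# logit per output class, so the layer width has to match the vocab size.
encoder_out = torch.randn(8, 100, 512)  # (batch, time, d_model)
projection = nn.Linear(512, vocab_size)
logits = projection(encoder_out)
assert logits.shape == (8, 100, vocab_size)
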

View File

@@ -182,7 +182,7 @@ def main():
     model = Conformer(
         num_features=params.feature_dim,
-        num_classes=num_classes,
+        num_classes=params.vocab_size,
         subsampling_factor=params.subsampling_factor,
         d_model=params.dim_model,
         nhead=params.nhead,