mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 07:04:18 +00:00
minor fixes
This commit is contained in:
parent
b9bbdfaadc
commit
9644c1722a
@ -28,7 +28,7 @@ popd
|
|||||||
2. Export the model to ONNX
|
2. Export the model to ONNX
|
||||||
|
|
||||||
./lstm_transducer_stateless2/export-onnx-zh.py \
|
./lstm_transducer_stateless2/export-onnx-zh.py \
|
||||||
--lang-dir ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char \
|
--tokens ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char/tokens.txt \
|
||||||
--use-averaged-model 1 \
|
--use-averaged-model 1 \
|
||||||
--epoch 11 \
|
--epoch 11 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -55,6 +55,7 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import onnx
|
import onnx
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -70,8 +71,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.utils import num_tokens, setup_logger, str2bool
|
||||||
from icefall.utils import setup_logger, str2bool
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -128,10 +128,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang-dir",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_char",
|
default="data/lang_char/tokens.txt",
|
||||||
help="The lang dir",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -441,9 +441,9 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
params.blank_id = 0
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
params.vocab_size = num_tokens(token_table) + 1
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -118,7 +118,7 @@ def main():
|
|||||||
num_features=params.feature_dim,
|
num_features=params.feature_dim,
|
||||||
nhead=params.nhead,
|
nhead=params.nhead,
|
||||||
d_model=params.attention_dim,
|
d_model=params.attention_dim,
|
||||||
num_classes=num_classes,
|
num_classes=params.vocab_size,
|
||||||
subsampling_factor=params.subsampling_factor,
|
subsampling_factor=params.subsampling_factor,
|
||||||
num_decoder_layers=params.num_decoder_layers,
|
num_decoder_layers=params.num_decoder_layers,
|
||||||
vgg_frontend=False,
|
vgg_frontend=False,
|
||||||
|
@ -182,7 +182,7 @@ def main():
|
|||||||
|
|
||||||
model = Conformer(
|
model = Conformer(
|
||||||
num_features=params.feature_dim,
|
num_features=params.feature_dim,
|
||||||
num_classes=num_classes,
|
num_classes=params.vocab_size,
|
||||||
subsampling_factor=params.subsampling_factor,
|
subsampling_factor=params.subsampling_factor,
|
||||||
d_model=params.dim_model,
|
d_model=params.dim_model,
|
||||||
nhead=params.nhead,
|
nhead=params.nhead,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user