minor fixes

This commit is contained in:
jinzr 2023-08-11 21:05:36 +08:00
parent 14f0cb5977
commit bf6fb9f0e2
3 changed files with 13 additions and 11 deletions

View File

@@ -40,7 +40,7 @@ for m in ctc-decoding 1best; do
--model-filename $repo/exp/jit_trace.pt \ --model-filename $repo/exp/jit_trace.pt \
--words-file $repo/data/lang_bpe_500/words.txt \ --words-file $repo/data/lang_bpe_500/words.txt \
--HLG $repo/data/lang_bpe_500/HLG.pt \ --HLG $repo/data/lang_bpe_500/HLG.pt \
--tokens $repo/data/lang_bpe_500/tokens.txt \ --bpe-model $repo/data/lang_bpe_500/bpe.model \
--G $repo/data/lm/G_4_gram.pt \ --G $repo/data/lm/G_4_gram.pt \
--method $m \ --method $m \
--sample-rate 16000 \ --sample-rate 16000 \

View File

@@ -60,7 +60,7 @@ log "Export via torch.jit.trace()"
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
--use-averaged-model 0 \ --use-averaged-model 0 \
\ --tokens $repo/data/lang_bpe_500/tokens.txt \
--num-encoder-layers 12 \ --num-encoder-layers 12 \
--chunk-length 32 \ --chunk-length 32 \
--cnn-module-kernel 31 \ --cnn-module-kernel 31 \

View File

@@ -139,8 +139,8 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import k2
import onnxruntime import onnxruntime
import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner
@@ -154,7 +154,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import str2bool from icefall.utils import num_tokens, str2bool
def get_parser(): def get_parser():
@@ -211,10 +211,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt",
) )
parser.add_argument( parser.add_argument(
@@ -675,12 +675,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.vocab_size = sp.get_piece_size() params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params) logging.info(params)