mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
minor fixes
This commit is contained in:
parent
14f0cb5977
commit
bf6fb9f0e2
@ -40,7 +40,7 @@ for m in ctc-decoding 1best; do
|
|||||||
--model-filename $repo/exp/jit_trace.pt \
|
--model-filename $repo/exp/jit_trace.pt \
|
||||||
--words-file $repo/data/lang_bpe_500/words.txt \
|
--words-file $repo/data/lang_bpe_500/words.txt \
|
||||||
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
||||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
--G $repo/data/lm/G_4_gram.pt \
|
--G $repo/data/lm/G_4_gram.pt \
|
||||||
--method $m \
|
--method $m \
|
||||||
--sample-rate 16000 \
|
--sample-rate 16000 \
|
||||||
|
2
.github/scripts/test-ncnn-export.sh
vendored
2
.github/scripts/test-ncnn-export.sh
vendored
@ -60,7 +60,7 @@ log "Export via torch.jit.trace()"
|
|||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
\
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--num-encoder-layers 12 \
|
--num-encoder-layers 12 \
|
||||||
--chunk-length 32 \
|
--chunk-length 32 \
|
||||||
--cnn-module-kernel 31 \
|
--cnn-module-kernel 31 \
|
||||||
|
@ -139,8 +139,8 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
import onnxruntime
|
import onnxruntime
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner
|
from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner
|
||||||
@ -154,7 +154,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -211,10 +211,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -675,12 +675,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user