mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
use tokens.txt to replace bpe.model
This commit is contained in:
parent
af2d83d8d7
commit
0945c5b379
@ -18,7 +18,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
|
||||||
git lfs pull --include "exp/pretrained.pt"
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -28,7 +27,7 @@ popd
|
|||||||
2. Export the model to ONNX
|
2. Export the model to ONNX
|
||||||
|
|
||||||
./pruned_transducer_stateless8/export-onnx.py \
|
./pruned_transducer_stateless8/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -50,8 +49,8 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import onnx
|
import onnx
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
@ -66,7 +65,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import setup_logger, str2bool
|
from icefall.utils import num_tokens, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -123,10 +122,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -412,12 +411,10 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
|
||||||
sp.load(params.bpe_model)
|
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.blank_id = token_table["<blk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user