This commit is contained in:
jinzr 2023-10-31 10:16:46 +08:00
parent edb2bd56b2
commit fde8a2ff65
2 changed files with 19 additions and 33 deletions

View File

@@ -76,17 +76,17 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import UniqLexicon
from icefall.utils import str2bool
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -143,13 +143,10 @@ def get_parser():
)
parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_bpe_500",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
@@ -189,17 +186,9 @@ def main():
logging.info(f"device: {device}")
if "lang_bpe" in str(params.lang_dir):
sp = spm.SentencePieceProcessor()
sp.load(params.lang_dir + "/bpe.model")
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
else:
assert "lang_phone" in str(params.lang_dir)
phone_lexicon = UniqLexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(phone_lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@@ -23,6 +23,7 @@ import argparse
import logging
from pathlib import Path
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_ctc_model, get_params
@@ -33,8 +34,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@@ -90,11 +90,10 @@ def get_parser():
)
parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_bpe_500",
help="""It contains language related input files such as "lexicon.txt"
""",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
@@ -113,17 +112,15 @@ def get_parser():
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
logging.info(params)
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
params.vocab_size = num_classes
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():