Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-08 00:24:19 +00:00)

Commit fde8a2ff65 (parent edb2bd56b2): incorporate https://github.com/k2-fsa/icefall/pull/1162
@@ -76,17 +76,17 @@ import argparse
 import logging
 from pathlib import Path
 
-import sentencepiece as spm
+import k2
 import torch
+from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.lexicon import UniqLexicon
-from icefall.utils import str2bool
-from train import add_model_arguments, get_params, get_transducer_model
+from icefall.utils import num_tokens, str2bool
 
 
 def get_parser():
@@ -143,13 +143,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lang-dir",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500",
-        help="""The lang dir
-        It contains language related input files such as
-        "lexicon.txt"
-        """,
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to the tokens.txt",
     )
 
     parser.add_argument(
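The new --tokens flag points at the tokens.txt written during lang-dir preparation: a plain-text table with one token and its integer id per line, which is the layout that k2.SymbolTable.from_file reads. The entries below are only an illustrative sketch (the actual symbols and ids depend on the recipe and BPE model), not content taken from this commit:

<blk> 0
<sos/eos> 1
<unk> 2
▁HELLO 3
▁WORLD 4
#0 5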
@@ -189,17 +186,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    if "lang_bpe" in str(params.lang_dir):
-        sp = spm.SentencePieceProcessor()
-        sp.load(params.lang_dir + "/bpe.model")
-        # <blk> is defined in local/train_bpe_model.py
-        params.blank_id = sp.piece_to_id("<blk>")
-        params.vocab_size = sp.get_piece_size()
-    else:
-        assert "lang_phone" in str(params.lang_dir)
-        phone_lexicon = UniqLexicon(params.lang_dir)
-        params.blank_id = 0
-        params.vocab_size = max(phone_lexicon.tokens) + 1
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>
 
     logging.info(params)
 
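To make the rewritten main() logic concrete, here is a small self-contained sketch (not part of the commit). It builds a toy table with k2.SymbolTable.from_str instead of reading a real tokens.txt, and count_tokens is only a simplified stand-in for icefall.utils.num_tokens, assumed here to count the non-disambiguation ids other than 0:

import re

import k2

# Toy token table in the "symbol id" layout of tokens.txt (illustrative only).
table_str = "\n".join(
    [
        "<blk> 0",
        "<sos/eos> 1",
        "<unk> 2",
        "▁HELLO 3",
        "▁WORLD 4",
        "#0 5",  # disambiguation symbol, not a real model output
    ]
)
token_table = k2.SymbolTable.from_str(table_str)


def count_tokens(table: k2.SymbolTable) -> int:
    # Simplified stand-in for icefall.utils.num_tokens: skip #N
    # disambiguation symbols and do not count id 0 (<blk>) itself.
    disambig = re.compile(r"^#\d+$")
    ids = [table[s] for s in table.symbols if not disambig.match(s)]
    return len([i for i in ids if i != 0])


# Mirrors the replacement lines in the hunk above.
blank_id = token_table["<blk>"]             # -> 0
vocab_size = count_tokens(token_table) + 1  # +1 for <blk>, -> 5

print(blank_id, vocab_size)

Running it prints "0 5"; the real scripts do the same computation on the file given by --tokens.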
@ -23,6 +23,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_ctc_model, get_params
|
from train import add_model_arguments, get_ctc_model, get_params
|
||||||
@@ -33,8 +34,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
 
 
 def get_parser():
@@ -90,11 +90,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lang-dir",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500",
-        help="""It contains language related input files such as "lexicon.txt"
-        """,
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to the tokens.txt",
    )
 
     parser.add_argument(
@@ -113,17 +112,15 @@ def get_parser():
 def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)
-    args.lang_dir = Path(args.lang_dir)
 
     params = get_params()
     params.update(vars(args))
 
-    logging.info(params)
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>
 
-    lexicon = Lexicon(params.lang_dir)
-    max_token_id = max(lexicon.tokens)
-    num_classes = max_token_id + 1  # +1 for the blank
-    params.vocab_size = num_classes
+    logging.info(params)
 
     device = torch.device("cpu")
     if torch.cuda.is_available():
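Because both scripts previously pulled blank_id and vocab_size out of bpe.model, one way to convince yourself the switch is lossless is to compare the two sources directly. This is only a sketch under the assumption that tokens.txt was dumped from the same bpe.model; the paths are the defaults shown in the diff and may differ per recipe:

import k2
import sentencepiece as spm

# Default paths from the diff; adjust to the recipe's actual lang dir.
tokens_file = "data/lang_bpe_500/tokens.txt"
bpe_model = "data/lang_bpe_500/bpe.model"

token_table = k2.SymbolTable.from_file(tokens_file)

sp = spm.SentencePieceProcessor()
sp.load(bpe_model)

# Old code: blank_id = sp.piece_to_id("<blk>"), vocab_size = sp.get_piece_size().
# New code: both values come from tokens.txt.  If tokens.txt was generated from
# this bpe.model, every piece id should map to the same symbol in both views.
assert token_table["<blk>"] == sp.piece_to_id("<blk>")
for i in range(sp.get_piece_size()):
    assert token_table[i] == sp.id_to_piece(i), (i, token_table[i], sp.id_to_piece(i))

print(f"tokens.txt agrees with bpe.model on all {sp.get_piece_size()} pieces")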