mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
support --tokens in conformer_ctc/export.py
This commit is contained in:
parent
5a877da3a0
commit
ad4af5c899
@ -23,12 +23,12 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.utils import AttributeDict, num_tokens, str2bool
|
||||||
from icefall.utils import AttributeDict, str2bool
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -63,11 +63,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang-dir",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_char",
|
required=True,
|
||||||
help="""It contains language related input files such as "lexicon.txt"
|
help="Path to the tokens.txt.",
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -98,16 +97,16 @@ def get_params() -> AttributeDict:
|
|||||||
def main():
|
def main():
|
||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
args.lang_dir = Path(args.lang_dir)
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
logging.info(params)
|
# Load tokens.txt here
|
||||||
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
num_classes = num_tokens(token_table) + 1 # +1 for the blank
|
||||||
max_token_id = max(lexicon.tokens)
|
|
||||||
num_classes = max_token_id + 1 # +1 for the blank
|
logging.info(params)
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user