update the conformer_ctc recipe to replace lang-dir with tokens

This commit is contained in:
jinzr 2023-07-13 14:19:14 +08:00
parent 208c30c160
commit 40af5f2828
2 changed files with 27 additions and 23 deletions

View File

@ -23,12 +23,13 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import k2
import torch import torch
from conformer import Conformer from conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool from icefall.utils import AttributeDict, num_tokens, str2bool
def get_parser(): def get_parser():
@ -63,11 +64,9 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--lang-dir", "--tokens",
type=str, type=str,
default="data/lang_bpe_500", help="Path to the tokens.txt.",
help="""It contains language related input files such as "lexicon.txt"
""",
) )
parser.add_argument( parser.add_argument(
@ -98,16 +97,16 @@ def get_params() -> AttributeDict:
def main(): def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
logging.info(params) logging.info(params)
lexicon = Lexicon(params.lang_dir) # Load tokens.txt here
max_token_id = max(lexicon.tokens) token_table = k2.SymbolTable.from_file(params.tokens)
num_classes = max_token_id + 1 # +1 for the blank
num_classes = num_tokens(token_table) + 1 # +1 for the blank
device = torch.device("cpu") device = torch.device("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():

View File

@ -24,7 +24,6 @@ from typing import List
import k2 import k2
import kaldifeat import kaldifeat
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from conformer import Conformer from conformer import Conformer
@ -70,11 +69,9 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
help="""Path to bpe.model. help="Path to the tokens.txt.",
Used only when method is ctc-decoding.
""",
) )
parser.add_argument( parser.add_argument(
@ -257,6 +254,9 @@ def main():
params.update(vars(args)) params.update(vars(args))
logging.info(f"{params}") logging.info(f"{params}")
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
device = torch.device("cpu") device = torch.device("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda", 0) device = torch.device("cuda", 0)
@ -297,6 +297,7 @@ def main():
waves = [w.to(device) for w in waves] waves = [w.to(device) for w in waves]
logging.info("Decoding started") logging.info("Decoding started")
hyps = []
features = fbank(waves) features = fbank(waves)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
@ -311,10 +312,14 @@ def main():
dtype=torch.int32, dtype=torch.int32,
) )
def token_ids_to_words(token_ids: List[int]) -> str:
    """Convert the token IDs of one hypothesis into a plain-text transcript.

    Each ID is looked up in the enclosing ``token_table`` (the symbol table
    loaded from ``--tokens``) and the resulting pieces are concatenated.
    The word-boundary marker "▁" (U+2581) is then replaced by a space.
    NOTE(review): this assumes tokens.txt holds SentencePiece/BPE pieces
    that use "▁" as the word-boundary marker -- confirm for other token sets.

    Args:
      token_ids: Token IDs of a single hypothesis.

    Returns:
      The decoded text with leading and trailing whitespace removed.
    """
    # The marker must be spelled out explicitly: replacing the empty string
    # would insert a space between every character of the transcript.
    text = "".join(token_table[i] for i in token_ids)
    return text.replace("▁", " ").strip()
if params.method == "ctc-decoding": if params.method == "ctc-decoding":
logging.info("Use CTC decoding") logging.info("Use CTC decoding")
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(params.bpe_model)
max_token_id = params.num_classes - 1 max_token_id = params.num_classes - 1
H = k2.ctc_topo( H = k2.ctc_topo(
@ -337,9 +342,9 @@ def main():
best_path = one_best_decoding( best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores lattice=lattice, use_double_scores=params.use_double_scores
) )
token_ids = get_texts(best_path) hyp_tokens = get_texts(best_path)
hyps = bpe_model.decode(token_ids) for hyp in hyp_tokens:
hyps = [s.split() for s in hyps] hyps.append(token_ids_to_words(hyp))
elif params.method in [ elif params.method in [
"1best", "1best",
"whole-lattice-rescoring", "whole-lattice-rescoring",
@ -408,16 +413,16 @@ def main():
) )
best_path = next(iter(best_path_dict.values())) best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file) word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps] hyp_tokens = get_texts(best_path)
for hyp in hyp_tokens:
hyps.append(" ".join([word_sym_table[i] for i in hyp]))
else: else:
raise ValueError(f"Unsupported decoding method: {params.method}") raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp) s += f"{filename}:\n{hyp}\n\n"
s += f"{filename}:\n{words}\n\n"
logging.info(s) logging.info(s)
logging.info("Decoding Done") logging.info("Decoding Done")