update the conformer_ctc recipe to replace lang-dir with tokens

This commit is contained in:
jinzr 2023-07-13 14:19:14 +08:00
parent 208c30c160
commit 40af5f2828
2 changed files with 27 additions and 23 deletions

View File

@ -23,12 +23,13 @@ import argparse
import logging
from pathlib import Path
import k2
import torch
from conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
from icefall.utils import AttributeDict, num_tokens, str2bool
def get_parser():
@ -63,11 +64,9 @@ def get_parser():
)
parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_bpe_500",
help="""It contains language related input files such as "lexicon.txt"
""",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -98,16 +97,16 @@ def get_params() -> AttributeDict:
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
num_classes = num_tokens(token_table) + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():

View File

@ -24,7 +24,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from conformer import Conformer
@ -70,11 +69,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -257,6 +254,9 @@ def main():
params.update(vars(args))
logging.info(f"{params}")
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
@ -297,6 +297,7 @@ def main():
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
hyps = []
features = fbank(waves)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
@ -311,10 +312,14 @@ def main():
dtype=torch.int32,
)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
if params.method == "ctc-decoding":
logging.info("Use CTC decoding")
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(params.bpe_model)
max_token_id = params.num_classes - 1
H = k2.ctc_topo(
@ -337,9 +342,9 @@ def main():
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
token_ids = get_texts(best_path)
hyps = bpe_model.decode(token_ids)
hyps = [s.split() for s in hyps]
hyp_tokens = get_texts(best_path)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method in [
"1best",
"whole-lattice-rescoring",
@ -408,16 +413,16 @@ def main():
)
best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
hyp_tokens = get_texts(best_path)
for hyp in hyp_tokens:
hyps.append(" ".join([word_sym_table[i] for i in hyp]))
else:
raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")