mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 13:34:20 +00:00
update the conformer_ctc
recipe to replace lang-dir with tokens
This commit is contained in:
parent
208c30c160
commit
40af5f2828
@ -23,12 +23,13 @@ import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from conformer import Conformer
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, str2bool
|
||||
from icefall.utils import AttributeDict, num_tokens, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -63,11 +64,9 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
"--tokens",
|
||||
type=str,
|
||||
default="data/lang_bpe_500",
|
||||
help="""It contains language related input files such as "lexicon.txt"
|
||||
""",
|
||||
help="Path to the tokens.txt.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -98,16 +97,16 @@ def get_params() -> AttributeDict:
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
logging.info(params)
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
num_classes = max_token_id + 1 # +1 for the blank
|
||||
# Load tokens.txt here
|
||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||
|
||||
num_classes = num_tokens(token_table) + 1 # +1 for the blank
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
|
@ -24,7 +24,6 @@ from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from conformer import Conformer
|
||||
@ -70,11 +69,9 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="""Path to bpe.model.
|
||||
Used only when method is ctc-decoding.
|
||||
""",
|
||||
help="Path to the tokens.txt.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -257,6 +254,9 @@ def main():
|
||||
params.update(vars(args))
|
||||
logging.info(f"{params}")
|
||||
|
||||
# Load tokens.txt here
|
||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
@ -297,6 +297,7 @@ def main():
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
hyps = []
|
||||
features = fbank(waves)
|
||||
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
@ -311,10 +312,14 @@ def main():
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||
text = ""
|
||||
for i in token_ids:
|
||||
text += token_table[i]
|
||||
return text.replace("▁", " ").strip()
|
||||
|
||||
if params.method == "ctc-decoding":
|
||||
logging.info("Use CTC decoding")
|
||||
bpe_model = spm.SentencePieceProcessor()
|
||||
bpe_model.load(params.bpe_model)
|
||||
max_token_id = params.num_classes - 1
|
||||
|
||||
H = k2.ctc_topo(
|
||||
@ -337,9 +342,9 @@ def main():
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
token_ids = get_texts(best_path)
|
||||
hyps = bpe_model.decode(token_ids)
|
||||
hyps = [s.split() for s in hyps]
|
||||
hyp_tokens = get_texts(best_path)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append(token_ids_to_words(hyp))
|
||||
elif params.method in [
|
||||
"1best",
|
||||
"whole-lattice-rescoring",
|
||||
@ -408,16 +413,16 @@ def main():
|
||||
)
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
||||
hyp_tokens = get_texts(best_path)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append(" ".join([word_sym_table[i] for i in hyp]))
|
||||
else:
|
||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
s += f"{filename}:\n{hyp}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
Loading…
x
Reference in New Issue
Block a user