Give an informative message when users provide an unsupported decoding method (#77)

This commit is contained in:
Fangjun Kuang 2021-10-14 16:20:35 +08:00 committed by GitHub
parent 39bc8cae94
commit 5016ee3c95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -20,23 +20,23 @@
import argparse
import logging
import math
import sentencepiece as spm
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from conformer import Conformer
from torch.nn.utils.rnn import pad_sequence
from icefall.lexicon import Lexicon
from icefall.decode import (
get_lattice,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, get_texts
@ -58,7 +58,7 @@ def get_parser():
"--lang-dir",
type=str,
required=True,
help="Path to lang bpe dir.",
help="Path to lang dir.",
)
parser.add_argument(
@ -142,7 +142,7 @@ def get_parser():
parser.add_argument(
"--sos-id",
type=float,
type=int,
default=1,
help="""
Used only when method is attention-decoder.
@ -152,7 +152,7 @@ def get_parser():
parser.add_argument(
"--eos-id",
type=float,
type=int,
default=1,
help="""
Used only when method is attention-decoder.
@ -285,9 +285,8 @@ def main():
dtype=torch.int32,
)
try:
if params.method == "ctc-decoding":
logging.info("Building CTC topology")
logging.info("Use CTC decoding")
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
H = k2.ctc_topo(
@ -296,7 +295,6 @@ def main():
device=device,
)
logging.info("Loading BPE model")
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(params.lang_dir + "/bpe.model")
@ -311,15 +309,13 @@ def main():
subsampling_factor=params.subsampling_factor,
)
logging.info("Use CTC decoding")
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
token_ids = get_texts(best_path)
hyps = bpe_model.decode(token_ids)
hyps = [s.split() for s in hyps]
if params.method in [
elif params.method in [
"1best",
"whole-lattice-rescoring",
"attention-decoder",
@ -371,9 +367,7 @@ def main():
)
best_path = next(iter(best_path_dict.values()))
elif params.method == "attention-decoder":
logging.info(
"Use HLG + LM rescoring + attention decoder rescoring"
)
logging.info("Use HLG + LM rescoring + attention decoder rescoring")
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
)
@ -396,6 +390,8 @@ def main():
params.lang_dir + "/words.txt"
)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
else:
raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
@ -405,9 +401,6 @@ def main():
logging.info("Decoding Done")
except Exception:
raise ValueError("Please use a supported decoding method.")
if __name__ == "__main__":
formatter = (