Give an informative message when users provide an unsupported decoding method (#77)

This commit is contained in:
Fangjun Kuang 2021-10-14 16:20:35 +08:00 committed by GitHub
parent 39bc8cae94
commit 5016ee3c95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -20,23 +20,23 @@
import argparse import argparse
import logging import logging
import math import math
import sentencepiece as spm
from typing import List from typing import List
import k2 import k2
import kaldifeat import kaldifeat
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from conformer import Conformer from conformer import Conformer
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from icefall.lexicon import Lexicon
from icefall.decode import ( from icefall.decode import (
get_lattice, get_lattice,
one_best_decoding, one_best_decoding,
rescore_with_attention_decoder, rescore_with_attention_decoder,
rescore_with_whole_lattice, rescore_with_whole_lattice,
) )
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, get_texts from icefall.utils import AttributeDict, get_texts
@ -58,7 +58,7 @@ def get_parser():
"--lang-dir", "--lang-dir",
type=str, type=str,
required=True, required=True,
help="Path to lang bpe dir.", help="Path to lang dir.",
) )
parser.add_argument( parser.add_argument(
@ -142,7 +142,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--sos-id", "--sos-id",
type=float, type=int,
default=1, default=1,
help=""" help="""
Used only when method is attention-decoder. Used only when method is attention-decoder.
@ -152,7 +152,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--eos-id", "--eos-id",
type=float, type=int,
default=1, default=1,
help=""" help="""
Used only when method is attention-decoder. Used only when method is attention-decoder.
@ -285,9 +285,8 @@ def main():
dtype=torch.int32, dtype=torch.int32,
) )
try:
if params.method == "ctc-decoding": if params.method == "ctc-decoding":
logging.info("Building CTC topology") logging.info("Use CTC decoding")
lexicon = Lexicon(params.lang_dir) lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens) max_token_id = max(lexicon.tokens)
H = k2.ctc_topo( H = k2.ctc_topo(
@ -296,7 +295,6 @@ def main():
device=device, device=device,
) )
logging.info("Loading BPE model")
bpe_model = spm.SentencePieceProcessor() bpe_model = spm.SentencePieceProcessor()
bpe_model.load(params.lang_dir + "/bpe.model") bpe_model.load(params.lang_dir + "/bpe.model")
@ -311,15 +309,13 @@ def main():
subsampling_factor=params.subsampling_factor, subsampling_factor=params.subsampling_factor,
) )
logging.info("Use CTC decoding")
best_path = one_best_decoding( best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores lattice=lattice, use_double_scores=params.use_double_scores
) )
token_ids = get_texts(best_path) token_ids = get_texts(best_path)
hyps = bpe_model.decode(token_ids) hyps = bpe_model.decode(token_ids)
hyps = [s.split() for s in hyps] hyps = [s.split() for s in hyps]
elif params.method in [
if params.method in [
"1best", "1best",
"whole-lattice-rescoring", "whole-lattice-rescoring",
"attention-decoder", "attention-decoder",
@ -371,9 +367,7 @@ def main():
) )
best_path = next(iter(best_path_dict.values())) best_path = next(iter(best_path_dict.values()))
elif params.method == "attention-decoder": elif params.method == "attention-decoder":
logging.info( logging.info("Use HLG + LM rescoring + attention decoder rescoring")
"Use HLG + LM rescoring + attention decoder rescoring"
)
rescored_lattice = rescore_with_whole_lattice( rescored_lattice = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
) )
@ -396,6 +390,8 @@ def main():
params.lang_dir + "/words.txt" params.lang_dir + "/words.txt"
) )
hyps = [[word_sym_table[i] for i in ids] for ids in hyps] hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
else:
raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):
@ -405,9 +401,6 @@ def main():
logging.info("Decoding Done") logging.info("Decoding Done")
except Exception:
raise ValueError("Please use a supported decoding method.")
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = (