Give an informative message when users provide an unsupported decoding method (#77)

Fangjun Kuang 2021-10-14 16:20:35 +08:00 committed by GitHub
parent 39bc8cae94
commit 5016ee3c95


@@ -20,23 +20,23 @@
 import argparse
 import logging
 import math
-import sentencepiece as spm
 from typing import List
 
 import k2
 import kaldifeat
+import sentencepiece as spm
 import torch
 import torchaudio
 from conformer import Conformer
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.lexicon import Lexicon
 from icefall.decode import (
     get_lattice,
     one_best_decoding,
     rescore_with_attention_decoder,
     rescore_with_whole_lattice,
 )
+from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, get_texts
@@ -58,7 +58,7 @@ def get_parser():
         "--lang-dir",
         type=str,
         required=True,
-        help="Path to lang bpe dir.",
+        help="Path to lang dir.",
     )
 
     parser.add_argument(
@ -142,7 +142,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--sos-id", "--sos-id",
type=float, type=int,
default=1, default=1,
help=""" help="""
Used only when method is attention-decoder. Used only when method is attention-decoder.
@@ -152,7 +152,7 @@ def get_parser():
     parser.add_argument(
         "--eos-id",
-        type=float,
+        type=int,
         default=1,
         help="""
         Used only when method is attention-decoder.
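Aside from the unsupported-method change below, the two hunks above switch `--sos-id` and `--eos-id` from `type=float` to `type=int`: the start/end-of-sentence symbols are token IDs, i.e. indices into the vocabulary, so they must parse as integers. A minimal standalone sketch of the corrected declaration (the parser below is illustrative, not the file's actual `get_parser()`):

    import argparse

    # Illustrative parser; only the two fixed flags are shown.
    parser = argparse.ArgumentParser()
    parser.add_argument("--sos-id", type=int, default=1)
    parser.add_argument("--eos-id", type=int, default=1)

    args = parser.parse_args(["--sos-id", "1", "--eos-id", "1"])
    # Token IDs index into the vocabulary, so they must be ints;
    # with type=float these would have parsed as 1.0.
    assert isinstance(args.sos_id, int) and isinstance(args.eos_id, int)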
@@ -285,128 +285,121 @@ def main():
         dtype=torch.int32,
     )
 
-    try:
-        if params.method == "ctc-decoding":
-            logging.info("Building CTC topology")
-            lexicon = Lexicon(params.lang_dir)
-            max_token_id = max(lexicon.tokens)
-            H = k2.ctc_topo(
-                max_token=max_token_id,
-                modified=False,
-                device=device,
-            )
-            logging.info("Loading BPE model")
-            bpe_model = spm.SentencePieceProcessor()
-            bpe_model.load(params.lang_dir + "/bpe.model")
-
-            lattice = get_lattice(
-                nnet_output=nnet_output,
-                decoding_graph=H,
-                supervision_segments=supervision_segments,
-                search_beam=params.search_beam,
-                output_beam=params.output_beam,
-                min_active_states=params.min_active_states,
-                max_active_states=params.max_active_states,
-                subsampling_factor=params.subsampling_factor,
-            )
-            logging.info("Use CTC decoding")
-            best_path = one_best_decoding(
-                lattice=lattice, use_double_scores=params.use_double_scores
-            )
-            token_ids = get_texts(best_path)
-            hyps = bpe_model.decode(token_ids)
-            hyps = [s.split() for s in hyps]
-
-        if params.method in [
-            "1best",
-            "whole-lattice-rescoring",
-            "attention-decoder",
-        ]:
-            logging.info(f"Loading HLG from {params.lang_dir}/HLG.pt")
-            HLG = k2.Fsa.from_dict(
-                torch.load(params.lang_dir + "/HLG.pt", map_location="cpu")
-            )
-            HLG = HLG.to(device)
-            if not hasattr(HLG, "lm_scores"):
-                # For whole-lattice-rescoring and attention-decoder
-                HLG.lm_scores = HLG.scores.clone()
-
-            if params.method in [
-                "whole-lattice-rescoring",
-                "attention-decoder",
-            ]:
-                logging.info(f"Loading G from {params.G}")
-                G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
-                # Add epsilon self-loops to G as we will compose
-                # it with the whole lattice later
-                G = G.to(device)
-                G = k2.add_epsilon_self_loops(G)
-                G = k2.arc_sort(G)
-                G.lm_scores = G.scores.clone()
-
-            lattice = get_lattice(
-                nnet_output=nnet_output,
-                decoding_graph=HLG,
-                supervision_segments=supervision_segments,
-                search_beam=params.search_beam,
-                output_beam=params.output_beam,
-                min_active_states=params.min_active_states,
-                max_active_states=params.max_active_states,
-                subsampling_factor=params.subsampling_factor,
-            )
-
-            if params.method == "1best":
-                logging.info("Use HLG decoding")
-                best_path = one_best_decoding(
-                    lattice=lattice, use_double_scores=params.use_double_scores
-                )
-            elif params.method == "whole-lattice-rescoring":
-                logging.info("Use HLG decoding + LM rescoring")
-                best_path_dict = rescore_with_whole_lattice(
-                    lattice=lattice,
-                    G_with_epsilon_loops=G,
-                    lm_scale_list=[params.ngram_lm_scale],
-                )
-                best_path = next(iter(best_path_dict.values()))
-            elif params.method == "attention-decoder":
-                logging.info(
-                    "Use HLG + LM rescoring + attention decoder rescoring"
-                )
-                rescored_lattice = rescore_with_whole_lattice(
-                    lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
-                )
-                best_path_dict = rescore_with_attention_decoder(
-                    lattice=rescored_lattice,
-                    num_paths=params.num_paths,
-                    model=model,
-                    memory=memory,
-                    memory_key_padding_mask=memory_key_padding_mask,
-                    sos_id=params.sos_id,
-                    eos_id=params.eos_id,
-                    nbest_scale=params.nbest_scale,
-                    ngram_lm_scale=params.ngram_lm_scale,
-                    attention_scale=params.attention_decoder_scale,
-                )
-                best_path = next(iter(best_path_dict.values()))
-
-            hyps = get_texts(best_path)
-            word_sym_table = k2.SymbolTable.from_file(
-                params.lang_dir + "/words.txt"
-            )
-            hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
-
-        s = "\n"
-        for filename, hyp in zip(params.sound_files, hyps):
-            words = " ".join(hyp)
-            s += f"{filename}:\n{words}\n\n"
-        logging.info(s)
-
-        logging.info("Decoding Done")
-    except Exception:
-        raise ValueError("Please use a supported decoding method.")
+    if params.method == "ctc-decoding":
+        logging.info("Use CTC decoding")
+        lexicon = Lexicon(params.lang_dir)
+        max_token_id = max(lexicon.tokens)
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(params.lang_dir + "/bpe.model")
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=H,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        token_ids = get_texts(best_path)
+        hyps = bpe_model.decode(token_ids)
+        hyps = [s.split() for s in hyps]
+    elif params.method in [
+        "1best",
+        "whole-lattice-rescoring",
+        "attention-decoder",
+    ]:
+        logging.info(f"Loading HLG from {params.lang_dir}/HLG.pt")
+        HLG = k2.Fsa.from_dict(
+            torch.load(params.lang_dir + "/HLG.pt", map_location="cpu")
+        )
+        HLG = HLG.to(device)
+        if not hasattr(HLG, "lm_scores"):
+            # For whole-lattice-rescoring and attention-decoder
+            HLG.lm_scores = HLG.scores.clone()
+
+        if params.method in [
+            "whole-lattice-rescoring",
+            "attention-decoder",
+        ]:
+            logging.info(f"Loading G from {params.G}")
+            G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later
+            G = G.to(device)
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G.lm_scores = G.scores.clone()
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            decoding_graph=HLG,
+            supervision_segments=supervision_segments,
+            search_beam=params.search_beam,
+            output_beam=params.output_beam,
+            min_active_states=params.min_active_states,
+            max_active_states=params.max_active_states,
+            subsampling_factor=params.subsampling_factor,
+        )
+
+        if params.method == "1best":
+            logging.info("Use HLG decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        elif params.method == "whole-lattice-rescoring":
+            logging.info("Use HLG decoding + LM rescoring")
+            best_path_dict = rescore_with_whole_lattice(
+                lattice=lattice,
+                G_with_epsilon_loops=G,
+                lm_scale_list=[params.ngram_lm_scale],
+            )
+            best_path = next(iter(best_path_dict.values()))
+        elif params.method == "attention-decoder":
+            logging.info("Use HLG + LM rescoring + attention decoder rescoring")
+            rescored_lattice = rescore_with_whole_lattice(
+                lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+            )
+            best_path_dict = rescore_with_attention_decoder(
+                lattice=rescored_lattice,
+                num_paths=params.num_paths,
+                model=model,
+                memory=memory,
+                memory_key_padding_mask=memory_key_padding_mask,
+                sos_id=params.sos_id,
+                eos_id=params.eos_id,
+                nbest_scale=params.nbest_scale,
+                ngram_lm_scale=params.ngram_lm_scale,
+                attention_scale=params.attention_decoder_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))
+
+        hyps = get_texts(best_path)
+        word_sym_table = k2.SymbolTable.from_file(
+            params.lang_dir + "/words.txt"
+        )
+        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
+    else:
+        raise ValueError(f"Unsupported decoding method: {params.method}")
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
 
 
 if __name__ == "__main__":
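The hunk above is the heart of the commit: the old code wrapped all of decoding in a blanket `try/except Exception` that remapped every failure, whatever its cause, to "Please use a supported decoding method."; the new code drops that wrapper and adds an explicit `else` branch that names the offending value. A minimal sketch of the resulting dispatch pattern (the function and its branch bodies are placeholders; the method strings and the error message are taken from the diff):

    def run_decoding(method: str) -> None:
        # Placeholder dispatcher mirroring the control flow added above.
        if method == "ctc-decoding":
            pass  # CTC-topology decoding
        elif method in ["1best", "whole-lattice-rescoring", "attention-decoder"]:
            pass  # HLG-based decoding, optionally with LM/attention rescoring
        else:
            # Fail fast and name the bad value instead of masking
            # unrelated exceptions behind a generic message.
            raise ValueError(f"Unsupported decoding method: {method}")

    run_decoding("ctc-decoding")  # runs
    run_decoding("banana")        # ValueError: Unsupported decoding method: banana

With this change, an unrelated failure (a missing file, a CUDA error) surfaces with its own traceback instead of being misreported as an unsupported decoding method.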