Add fast_beam_search_with_nbest_rescoring in decode

Erwan 2022-07-07 15:20:01 +02:00
parent 2456307acb
commit fd4deeab95

@@ -111,6 +111,7 @@ from beam_search import (
     fast_beam_search_nbest_LG,
     fast_beam_search_nbest_oracle,
     fast_beam_search_one_best,
+    fast_beam_search_with_nbest_rescoring,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -312,6 +313,35 @@ def get_parser():
         help="left context can be seen during decoding (in frames after subsampling)",
     )

+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=1.0,
+        help="""Softmax temperature.
+        The output of the model is (logits / temperature).log_softmax().
+        """,
+    )
+
+    parser.add_argument(
+        "--lm-dir",
+        type=Path,
+        default=Path("./data/lm"),
+        help="""Used only when --decoding-method is
+        fast_beam_search_with_nbest_rescoring.
+        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+        """,
+    )
+
+    parser.add_argument(
+        "--words-txt",
+        type=Path,
+        default=Path("./data/lang_bpe_500/words.txt"),
+        help="""Used only when --decoding-method is
+        fast_beam_search_with_nbest_rescoring.
+        It is the word table.
+        """,
+    )
+
     add_model_arguments(parser)

     return parser
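
The new `--temperature` flag divides the logits before the log-softmax: values above 1.0 flatten the output distribution, values below 1.0 sharpen it. A minimal, self-contained sketch of the computation the help string describes (plain PyTorch; the logit values are made up):

```python
import torch

# Made-up logits for one frame over a 5-token vocabulary.
logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -3.0])

for temperature in (0.5, 1.0, 2.0):
    # Mirrors the help text: (logits / temperature).log_softmax().
    log_probs = (logits / temperature).log_softmax(dim=-1)
    print(f"T={temperature}: {log_probs.exp().tolist()}")
```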
@@ -324,6 +354,7 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -352,6 +383,11 @@ def decode_one_batch(
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
         only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
         fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+      G:
+        Optional. Used only when decoding method is fast_beam_search,
+        fast_beam_search_nbest, fast_beam_search_nbest_oracle,
+        or fast_beam_search_with_nbest_rescoring.
+        It's an FsaVec containing an acceptor.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
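
The docstring says `G` is an FsaVec containing an acceptor. For readers new to k2, here is a tiny standalone sketch of what such an object looks like (the arcs are arbitrary; the real G is the 4-gram LM loaded by `load_ngram_LM` further down):

```python
import k2

# A small acceptor in k2's text format: "src dst label score" per arc,
# with the final state on a line of its own. Labels are word IDs;
# label -1 marks the arc entering the final state.
s = """
0 1 10 0.1
1 2 20 0.2
2 3 -1 0.3
3
"""
G = k2.arc_sort(k2.Fsa.from_str(s))
G_vec = k2.create_fsa_vec([G])  # an FsaVec holding a single acceptor
print(G_vec.shape)              # (1, None, None): one FSA in the vector
```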
@@ -397,6 +433,7 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
+            temperature=params.temperature,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -411,6 +448,7 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            temperature=params.temperature,
         )
         for hyp in hyp_tokens:
             hyps.append([word_table[i] for i in hyp])
@@ -425,6 +463,7 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            temperature=params.temperature,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -440,6 +479,7 @@ def decode_one_batch(
             num_paths=params.num_paths,
             ref_texts=sp.encode(supervisions["text"]),
             nbest_scale=params.nbest_scale,
+            temperature=params.temperature,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -460,9 +500,32 @@ def decode_one_batch(
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            temperature=params.temperature,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
+        ngram_lm_scale_list = [-0.5, -0.2, -0.1, -0.05, -0.02, 0]
+        ngram_lm_scale_list += [0.01, 0.02, 0.05]
+        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.8]
+        ngram_lm_scale_list += [1.0, 1.5, 2.5, 3]
+        hyp_tokens = fast_beam_search_with_nbest_rescoring(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_states=params.max_states,
+            max_contexts=params.max_contexts,
+            ngram_lm_scale_list=ngram_lm_scale_list,
+            num_paths=params.num_paths,
+            G=G,
+            sp=sp,
+            word_table=word_table,
+            use_double_scores=True,
+            nbest_scale=params.nbest_scale,
+            temperature=params.temperature,
+        )
     else:
         batch_size = encoder_out.size(0)
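
The branch above sweeps a whole list of n-gram LM scales in a single decoding pass: `fast_beam_search_with_nbest_rescoring` (defined in beam_search.py, not shown in this diff) returns one hypothesis set per scale, keyed by the scale value. Schematically, each n-best path's score is combined as `am_score + ngram_lm_scale * lm_score`. A simplified sketch of that selection logic (a hypothetical helper, not the actual implementation):

```python
from typing import Dict, List

def pick_best_per_scale(
    am_scores: List[float],   # lattice/AM score of each n-best hypothesis
    lm_scores: List[float],   # 4-gram LM score of each n-best hypothesis
    ngram_lm_scale_list: List[float],
) -> Dict[str, int]:
    """For every LM scale, return the index of the hypothesis that
    maximizes am_score + scale * lm_score. The key format mirrors the
    per-scale keys assumed in decode_one_batch below."""
    ans: Dict[str, int] = {}
    for scale in ngram_lm_scale_list:
        totals = [am + scale * lm for am, lm in zip(am_scores, lm_scores)]
        ans[f"ngram_lm_scale_{scale}"] = max(
            range(len(totals)), key=totals.__getitem__
        )
    return ans

# With a negative scale the LM is discounted rather than rewarded,
# which is why the sweep above starts at -0.5.
print(pick_best_per_scale([10.0, 9.5], [-3.0, -1.0], [-0.5, 0.5]))
```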
@ -496,6 +559,7 @@ def decode_one_batch(
f"beam_{params.beam}_" f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_" f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}" f"max_states_{params.max_states}"
f"temperature_{params.temperature}"
): hyps ): hyps
} }
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
@@ -504,8 +568,23 @@ def decode_one_batch(
                 f"beam_{params.beam}_"
                 f"max_contexts_{params.max_contexts}_"
                 f"max_states_{params.max_states}"
+                f"_temperature_{params.temperature}"
             ): hyps
         }
+    elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
+        prefix = (
+            f"beam_{params.beam}_"
+            f"max_contexts_{params.max_contexts}_"
+            f"max_states_{params.max_states}_"
+            f"num_paths_{params.num_paths}_"
+            f"nbest_scale_{params.nbest_scale}_"
+            f"temperature_{params.temperature}_"
+        )
+        ans: Dict[str, List[List[str]]] = {}
+        for key, hyp in hyp_tokens.items():
+            t: List[str] = sp.decode(hyp)
+            ans[prefix + key] = [s.split() for s in t]
+        return ans
     elif "fast_beam_search" in params.decoding_method:
         key = f"beam_{params.beam}_"
         key += f"max_contexts_{params.max_contexts}_"
@@ -518,7 +597,12 @@ def decode_one_batch(
         return {key: hyps}
     else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        return {
+            (
+                f"beam_size_{params.beam_size}_"
+                f"temperature_{params.temperature}"
+            ): hyps
+        }


 def decode_dataset(
@@ -528,6 +612,7 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -546,6 +631,11 @@ def decode_dataset(
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
         only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
         fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+      G:
+        Optional. Used only when decoding method is fast_beam_search,
+        fast_beam_search_nbest, fast_beam_search_nbest_oracle,
+        or fast_beam_search_with_nbest_rescoring.
+        It's an FsaVec containing an acceptor.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
@@ -576,6 +666,7 @@ def decode_dataset(
             word_table=word_table,
             decoding_graph=decoding_graph,
             batch=batch,
+            G=G,
         )

         for name, hyps in hyps_dict.items():
@@ -642,6 +733,71 @@ def save_results(
     logging.info(s)


+def load_ngram_LM(
+    lm_dir: Path, word_table: k2.SymbolTable, device: torch.device
+) -> k2.Fsa:
+    """Read an n-gram model from the given directory.
+
+    Args:
+      lm_dir:
+        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+      word_table:
+        The word table mapping words to IDs and vice versa.
+      device:
+        The resulting FSA will be moved to this device.
+    Returns:
+      Return an FsaVec containing a single acceptor.
+    """
+    lm_dir = Path(lm_dir)
+    assert lm_dir.is_dir(), f"{lm_dir} does not exist"
+
+    pt_file = lm_dir / "G_4_gram.pt"
+    if pt_file.is_file():
+        logging.info(f"Loading pre-compiled {pt_file}")
+        d = torch.load(pt_file, map_location=device)
+        G = k2.Fsa.from_dict(d)
+        G = k2.add_epsilon_self_loops(G)
+        G = k2.arc_sort(G)
+        return G
+
+    txt_file = lm_dir / "G_4_gram.fst.txt"
+    assert txt_file.is_file(), f"{txt_file} does not exist"
+    logging.info(f"Loading {txt_file}")
+    logging.warning("It may take 8 minutes (will be cached for later use).")
+    with open(txt_file) as f:
+        G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+
+    # G.aux_labels is not needed in later computations, so
+    # remove it here.
+    del G.aux_labels
+    # Now G is an acceptor.
+
+    first_word_disambig_id = word_table["#0"]
+    # CAUTION: The following line is crucial.
+    # Arcs entering the back-off state have label equal to #0.
+    # We have to change it to 0 here.
+    G.labels[G.labels >= first_word_disambig_id] = 0
+
+    # See https://github.com/k2-fsa/k2/issues/874
+    # for why we need to set G.properties to None.
+    G.__dict__["_properties"] = None
+
+    G = k2.Fsa.from_fsas([G]).to(device)
+
+    # Save a dummy value so that it can be loaded in C++.
+    # See https://github.com/pytorch/pytorch/issues/67902
+    # for why we need to do this.
+    G.dummy = 1
+
+    logging.info(f"Saving to {pt_file} for later use")
+    torch.save(G.as_dict(), pt_file)
+
+    G = k2.add_epsilon_self_loops(G)
+    G = k2.arc_sort(G)
+    return G


 @torch.no_grad()
 def main():
     parser = get_parser()
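
Two details in `load_ngram_LM` deserve emphasis. It caches the compiled FSA as G_4_gram.pt, so the slow `from_openfst` parse happens only once. And the relabeling line maps every disambiguation symbol (#0 and anything at or above it in words.txt) to 0, which k2 treats as epsilon; left as-is, the back-off arcs labeled #0 would never match anything during intersection. A tiny standalone illustration of that relabeling, with made-up IDs:

```python
import torch

# Labels of a few hypothetical arcs. IDs below first_word_disambig_id
# are real words; IDs at or above it are disambiguation symbols.
labels = torch.tensor([3, 7, 200004, 5, 200005])
first_word_disambig_id = 200004  # e.g. word_table["#0"]

labels[labels >= first_word_disambig_id] = 0  # disambig arcs become epsilon
print(labels)  # tensor([3, 7, 0, 5, 0])
```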
@@ -660,6 +816,7 @@ def main():
         "fast_beam_search_nbest_LG",
         "fast_beam_search_nbest_oracle",
         "modified_beam_search",
+        "fast_beam_search_with_nbest_rescoring",
     )
     params.res_dir = params.exp_dir / params.decoding_method
@ -760,6 +917,19 @@ def main():
torch.load(lg_filename, map_location=device) torch.load(lg_filename, map_location=device)
) )
decoding_graph.scores *= params.ngram_lm_scale decoding_graph.scores *= params.ngram_lm_scale
elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
logging.info(f"Loading word symbol table from {params.words_txt}")
word_table = k2.SymbolTable.from_file(params.words_txt)
G = load_ngram_LM(
lm_dir=params.lm_dir,
word_table=word_table,
device=device,
)
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
logging.info(f"G properties_str: {G.properties_str}")
else: else:
word_table = None word_table = None
decoding_graph = k2.trivial_graph( decoding_graph = k2.trivial_graph(
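
For this decoding method the decoding graph itself stays trivial; only the n-best rescoring consults G. `k2.trivial_graph(max_token_id)` builds an acceptor whose single non-final state carries a self-loop for every token ID from 1 to max_token_id, so it accepts any token sequence, and `params.vocab_size - 1` is the largest BPE token ID. A quick sketch:

```python
import k2

vocab_size = 500  # e.g. a 500-token BPE model
graph = k2.trivial_graph(vocab_size - 1)
# Self-loops for token IDs 1..499 plus the arc to the final state.
print(graph.num_arcs)
```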
@@ -792,6 +962,7 @@ def main():
             sp=sp,
             word_table=word_table,
             decoding_graph=decoding_graph,
+            G=G,
         )

         save_results(