mirror of https://github.com/k2-fsa/icefall.git
Add fast_beam_search_with_nbest_rescoring in decode
This commit is contained in:
parent
2456307acb
commit
fd4deeab95
@@ -111,6 +111,7 @@ from beam_search import (
    fast_beam_search_nbest_LG,
    fast_beam_search_nbest_oracle,
    fast_beam_search_one_best,
    fast_beam_search_with_nbest_rescoring,
    greedy_search,
    greedy_search_batch,
    modified_beam_search,
@@ -312,6 +313,35 @@ def get_parser():
        help="left context can be seen during decoding (in frames after subsampling)",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="""Softmax temperature.
        The output of the model is (logits / temperature).log_softmax().
        """,
    )

    parser.add_argument(
        "--lm-dir",
        type=Path,
        default=Path("./data/lm"),
        help="""Used only when --decoding-method is
        fast_beam_search_with_nbest_rescoring.
        It should contain either G_4_gram.pt or G_4_gram.fst.txt
        """,
    )

    parser.add_argument(
        "--words-txt",
        type=Path,
        default=Path("./data/lang_bpe_500/words.txt"),
        help="""Used only when --decoding-method is
        fast_beam_search_with_nbest_rescoring.
        It is the word table.
        """,
    )

    add_model_arguments(parser)

    return parser
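The new --temperature flag rescales the logits before the log-softmax, as the help text above states. A minimal sketch of the effect, using plain PyTorch and made-up logits (not part of this commit):

import torch

logits = torch.tensor([2.0, 1.0, 0.1])  # hypothetical model outputs for one frame
for temperature in (0.5, 1.0, 2.0):
    log_probs = (logits / temperature).log_softmax(dim=-1)
    # A temperature below 1 sharpens the distribution; above 1 flattens it.
    print(temperature, log_probs.exp())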
@@ -324,6 +354,7 @@ def decode_one_batch(
    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:
@@ -352,6 +383,11 @@ def decode_one_batch(
        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
      G:
        Optional. Used only when the decoding method is fast_beam_search,
        fast_beam_search_nbest, fast_beam_search_nbest_oracle,
        or fast_beam_search_with_nbest_rescoring.
        It's an FsaVec containing an acceptor.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
@@ -397,6 +433,7 @@ def decode_one_batch(
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            temperature=params.temperature,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
@@ -411,6 +448,7 @@ def decode_one_batch(
            max_states=params.max_states,
            num_paths=params.num_paths,
            nbest_scale=params.nbest_scale,
            temperature=params.temperature,
        )
        for hyp in hyp_tokens:
            hyps.append([word_table[i] for i in hyp])
@@ -425,6 +463,7 @@ def decode_one_batch(
            max_states=params.max_states,
            num_paths=params.num_paths,
            nbest_scale=params.nbest_scale,
            temperature=params.temperature,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
@@ -440,6 +479,7 @@ def decode_one_batch(
            num_paths=params.num_paths,
            ref_texts=sp.encode(supervisions["text"]),
            nbest_scale=params.nbest_scale,
            temperature=params.temperature,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
@@ -460,9 +500,32 @@ def decode_one_batch(
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
            temperature=params.temperature,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
        ngram_lm_scale_list = [-0.5, -0.2, -0.1, -0.05, -0.02, 0]
        ngram_lm_scale_list += [0.01, 0.02, 0.05]
        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.8]
        ngram_lm_scale_list += [1.0, 1.5, 2.5, 3]
        hyp_tokens = fast_beam_search_with_nbest_rescoring(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_states=params.max_states,
            max_contexts=params.max_contexts,
            ngram_lm_scale_list=ngram_lm_scale_list,
            num_paths=params.num_paths,
            G=G,
            sp=sp,
            word_table=word_table,
            use_double_scores=True,
            nbest_scale=params.nbest_scale,
            temperature=params.temperature,
        )
    else:
        batch_size = encoder_out.size(0)
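Judging by the key handling further down in this diff, fast_beam_search_with_nbest_rescoring returns one hypothesis set per value in ngram_lm_scale_list. The underlying idea is a log-linear combination of each n-best path's lattice score with its 4-gram LM score. A rough, illustrative sketch of that combination with made-up scores; the names and details here are assumptions, not the function's actual internals:

import torch

am_scores = torch.tensor([-12.3, -13.1, -12.9])  # hypothetical per-path scores from the lattice
lm_scores = torch.tensor([-20.4, -18.7, -19.5])  # hypothetical per-path 4-gram LM scores

best_path_per_scale = {}
for scale in [-0.5, -0.2, 0.0, 0.1, 0.5, 1.0, 2.5]:
    tot_scores = am_scores + scale * lm_scores
    # Keep the index of the best-scoring path for this LM scale.
    best_path_per_scale[f"ngram_lm_scale_{scale}"] = int(tot_scores.argmax())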
@@ -496,6 +559,7 @@ def decode_one_batch(
                f"beam_{params.beam}_"
                f"max_contexts_{params.max_contexts}_"
                f"max_states_{params.max_states}"
                f"temperature_{params.temperature}"
            ): hyps
        }
    elif params.decoding_method == "fast_beam_search":
@@ -504,8 +568,23 @@ def decode_one_batch(
                f"beam_{params.beam}_"
                f"max_contexts_{params.max_contexts}_"
                f"max_states_{params.max_states}"
                f"temperature_{params.temperature}"
            ): hyps
        }
    elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
        prefix = (
            f"beam_{params.beam}_"
            f"max_contexts_{params.max_contexts}_"
            f"max_states_{params.max_states}_"
            f"num_paths_{params.num_paths}_"
            f"nbest_scale_{params.nbest_scale}_"
            f"temperature_{params.temperature}_"
        )
        ans: Dict[str, List[List[str]]] = {}
        for key, hyp in hyp_tokens.items():
            t: List[str] = sp.decode(hyp)
            ans[prefix + key] = [s.split() for s in t]
        return ans
    elif "fast_beam_search" in params.decoding_method:
        key = f"beam_{params.beam}_"
        key += f"max_contexts_{params.max_contexts}_"
@@ -518,7 +597,12 @@ def decode_one_batch(

        return {key: hyps}
    else:
        return {f"beam_size_{params.beam_size}": hyps}
        return {
            (
                f"beam_size_{params.beam_size}_"
                f"temperature_{params.temperature}"
            ): hyps
        }


def decode_dataset(
@@ -528,6 +612,7 @@ def decode_dataset(
    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode dataset.

@@ -546,6 +631,11 @@ def decode_dataset(
        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
      G:
        Optional. Used only when the decoding method is fast_beam_search,
        fast_beam_search_nbest, fast_beam_search_nbest_oracle,
        or fast_beam_search_with_nbest_rescoring.
        It's an FsaVec containing an acceptor.
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_7" if beam size of 7 is used.
@@ -576,6 +666,7 @@ def decode_dataset(
            word_table=word_table,
            decoding_graph=decoding_graph,
            batch=batch,
            G=G,
        )

        for name, hyps in hyps_dict.items():
@@ -642,6 +733,71 @@ def save_results(
    logging.info(s)


def load_ngram_LM(
    lm_dir: Path, word_table: k2.SymbolTable, device: torch.device
) -> k2.Fsa:
    """Read an n-gram model from the given directory.
    Args:
      lm_dir:
        It should contain either G_4_gram.pt or G_4_gram.fst.txt
      word_table:
        The word table mapping words to IDs and vice versa.
      device:
        The resulting FSA will be moved to this device.
    Returns:
      Return an FsaVec containing a single acceptor.
    """
    lm_dir = Path(lm_dir)
    assert lm_dir.is_dir(), f"{lm_dir} does not exist"

    pt_file = lm_dir / "G_4_gram.pt"

    if pt_file.is_file():
        logging.info(f"Loading pre-compiled {pt_file}")
        d = torch.load(pt_file, map_location=device)
        G = k2.Fsa.from_dict(d)
        G = k2.add_epsilon_self_loops(G)
        G = k2.arc_sort(G)
        return G

    txt_file = lm_dir / "G_4_gram.fst.txt"

    assert txt_file.is_file(), f"{txt_file} does not exist"
    logging.info(f"Loading {txt_file}")
    logging.warning("It may take 8 minutes (will be cached for later use).")
    with open(txt_file) as f:
        G = k2.Fsa.from_openfst(f.read(), acceptor=False)

    # G.aux_labels is not needed in later computations, so
    # remove it here.
    del G.aux_labels
    # Now G is an acceptor

    first_word_disambig_id = word_table["#0"]
    # CAUTION: The following line is crucial.
    # Arcs entering the back-off state have label equal to #0.
    # We have to change it to 0 here.
    G.labels[G.labels >= first_word_disambig_id] = 0

    # See https://github.com/k2-fsa/k2/issues/874
    # for why we need to set G.properties to None
    G.__dict__["_properties"] = None

    G = k2.Fsa.from_fsas([G]).to(device)

    # Save a dummy value so that it can be loaded in C++.
    # See https://github.com/pytorch/pytorch/issues/67902
    # for why we need to do this.
    G.dummy = 1

    logging.info(f"Saving to {pt_file} for later use")
    torch.save(G.as_dict(), pt_file)

    G = k2.add_epsilon_self_loops(G)
    G = k2.arc_sort(G)
    return G


@torch.no_grad()
def main():
    parser = get_parser()
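A note on the relabeling step in load_ngram_LM above: arcs entering the back-off state carry the disambiguation symbol #0, which has to become epsilon (label 0) before the LM can be used for rescoring, as the CAUTION comment says. Since G.labels behaves as a torch tensor here, the operation is ordinary boolean-mask assignment. A toy illustration with made-up IDs (not taken from the commit):

import torch

first_word_disambig_id = 5  # hypothetical ID of "#0" in words.txt
labels = torch.tensor([1, 3, 5, 6, 2])  # hypothetical arc labels of G
labels[labels >= first_word_disambig_id] = 0  # #0 and any later symbols become epsilon (0)
print(labels)  # tensor([1, 3, 0, 0, 2])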
@@ -660,6 +816,7 @@ def main():
        "fast_beam_search_nbest_LG",
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
        "fast_beam_search_with_nbest_rescoring",
    )
    params.res_dir = params.exp_dir / params.decoding_method

@@ -760,6 +917,19 @@ def main():
                torch.load(lg_filename, map_location=device)
            )
            decoding_graph.scores *= params.ngram_lm_scale
        elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
            logging.info(f"Loading word symbol table from {params.words_txt}")
            word_table = k2.SymbolTable.from_file(params.words_txt)

            G = load_ngram_LM(
                lm_dir=params.lm_dir,
                word_table=word_table,
                device=device,
            )
            decoding_graph = k2.trivial_graph(
                params.vocab_size - 1, device=device
            )
            logging.info(f"G properties_str: {G.properties_str}")
        else:
            word_table = None
            decoding_graph = k2.trivial_graph(
@@ -792,6 +962,7 @@ def main():
            sp=sp,
            word_table=word_table,
            decoding_graph=decoding_graph,
            G=G,
        )

        save_results(