Fixes after refactoring.

Fangjun Kuang 2021-09-18 16:32:35 +08:00
parent 8623983bb7
commit e062c1b5cf
4 changed files with 114 additions and 90 deletions

View File

@@ -30,20 +30,14 @@ from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import get_lattice
from icefall.decode import (
one_best_decoding, # done
rescore_with_attention_decoder, # done
rescore_with_n_best_list, # done
rescore_with_whole_lattice, # done
nbest_oracle, # done
)
from icefall.decode2 import (
get_lattice,
nbest_decoding,
nbest_oracle as nbest_oracle2,
rescore_with_n_best_list as rescore_with_n_best_list2,
rescore_with_whole_lattice as rescore_with_whole_lattice2,
rescore_with_attention_decoder as rescore_with_attention_decoder2,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
@@ -250,10 +244,8 @@ def decode_one_batch(
# Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons
# as HLG decoding is faster and the oracle WER
# is slightly worse than that of rescored lattices.
if True:
# TODO: delete the `else` branch
best_path = nbest_oracle2(
# is only slightly worse than that of rescored lattices.
best_path = nbest_oracle(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
@@ -265,14 +257,6 @@ def decode_one_batch(
hyps = [[word_table[i] for i in ids] for ids in hyps]
key = f"oracle_{num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa
return {key: hyps}
else:
return nbest_oracle(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
word_table=word_table,
scale=params.lattice_score_scale,
)
if params.method in ["1best", "nbest"]:
if params.method == "1best":
@@ -304,32 +288,14 @@ def decode_one_batch(
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if params.method == "nbest-rescoring":
if True:
# TODO: remove the "else" branch
best_path_dict = rescore_with_n_best_list2(
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
lattice_score_scale=params.lattice_score_scale,
)
else:
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
scale=params.lattice_score_scale,
)
elif params.method == "whole-lattice-rescoring":
if True:
# TODO: remove "else" branch
best_path_dict = rescore_with_whole_lattice2(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=lm_scale_list,
)
else:
best_path_dict = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
@@ -338,21 +304,13 @@ def decode_one_batch(
elif params.method == "attention-decoder":
# lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=None,
)
# TODO: pass `lattice` instead of `rescored_lattice` to
# `rescore_with_attention_decoder`
if True:
best_path_dict = rescore_with_attention_decoder2(
lattice=rescored_lattice,
num_paths=params.num_paths,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
lattice_score_scale=params.lattice_score_scale,
)
else:
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
num_paths=params.num_paths,
@@ -361,7 +319,7 @@ def decode_one_batch(
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
scale=params.lattice_score_scale,
lattice_score_scale=params.lattice_score_scale,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
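
The net effect of the hunks above: the temporary `if True:` / `else:` scaffolding is gone, the `icefall.decode2` aliases (`nbest_oracle2`, `rescore_with_n_best_list2`, and friends) are dropped in favour of the plain names from `icefall.decode`, and the keyword argument `scale` is renamed to `lattice_score_scale`. Below is a minimal sketch of the resulting call pattern for attention-decoder rescoring; `lattice`, `G`, `model`, `memory`, `memory_key_padding_mask`, `params`, `sos_id` and `eos_id` are assumed to be set up as elsewhere in decode.py.

# Sketch only: all names other than the two imported functions are assumed
# to exist, as in decode_one_batch() above.
from icefall.decode import (
    rescore_with_attention_decoder,
    rescore_with_whole_lattice,
)

# First rescore the HLG lattice (built with a 3-gram LM) with the 4-gram LM G.
rescored_lattice = rescore_with_whole_lattice(
    lattice=lattice,
    G_with_epsilon_loops=G,
    lm_scale_list=None,
)

# Then rescore n-best paths with the attention decoder. Note the keyword
# `lattice_score_scale`, which replaces the old `scale` argument.
best_path_dict = rescore_with_attention_decoder(
    lattice=rescored_lattice,
    num_paths=params.num_paths,
    model=model,
    memory=memory,
    memory_key_padding_mask=memory_key_padding_mask,
    sos_id=sos_id,
    eos_id=eos_id,
    lattice_score_scale=params.lattice_score_scale,
)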

View File

@@ -336,7 +336,7 @@ def main():
memory_key_padding_mask=memory_key_padding_mask,
sos_id=params.sos_id,
eos_id=params.eos_id,
scale=params.lattice_score_scale,
lattice_score_scale=params.lattice_score_scale,
ngram_lm_scale=params.ngram_lm_scale,
attention_scale=params.attention_decoder_scale,
)

View File

@@ -229,6 +229,7 @@ def decode_one_batch(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
lattice_score_scale=params.lattice_score_scale,
)
key = f"no_rescore-{params.num_paths}"
hyps = get_texts(best_path)
@@ -247,10 +248,13 @@ def decode_one_batch(
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
lattice_score_scale=params.lattice_score_scale,
)
else:
best_path_dict = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=lm_scale_list,
)
ans = dict()

View File

@@ -64,6 +64,68 @@ def _intersect_device(
return k2.cat(ans)
def get_lattice(
nnet_output: torch.Tensor,
HLG: k2.Fsa,
supervision_segments: torch.Tensor,
search_beam: float,
output_beam: float,
min_active_states: int,
max_active_states: int,
subsampling_factor: int = 1,
) -> k2.Fsa:
"""Get the decoding lattice from a decoding graph and neural
network output.
Args:
nnet_output:
It is the output of a neural model of shape `[N, T, C]`.
HLG:
An Fsa, the decoding graph. See also `compile_HLG.py`.
supervision_segments:
A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns.
Each row contains information for a supervision segment. Column 0
is the `sequence_index` indicating which sequence this segment
comes from; column 1 specifies the `start_frame` of this segment
within the sequence; column 2 contains the `duration` of this
segment.
search_beam:
Decoding beam, e.g. 20. Smaller is faster, larger is more exact
(less pruning). This is the default value; it may be modified by
`min_active_states` and `max_active_states`.
output_beam:
Beam to prune output, similar to lattice-beam in Kaldi. Relative
to best path of output.
min_active_states:
Minimum number of FSA states that are allowed to be active on any given
frame for any given intersection/composition task. This is advisory,
in that it will try not to have fewer than this number active.
Set it to zero if there is no constraint.
max_active_states:
Maximum number of FSA states that are allowed to be active on any given
frame for any given intersection/composition task. This is advisory,
in that it will try not to exceed that but may not always succeed.
You can use a very large number if no constraint is needed.
subsampling_factor:
The subsampling factor of the model.
Returns:
A lattice containing the decoding result.
"""
dense_fsa_vec = k2.DenseFsaVec(
nnet_output, supervision_segments, allow_truncate=subsampling_factor - 1
)
lattice = k2.intersect_dense_pruned(
HLG,
dense_fsa_vec,
search_beam=search_beam,
output_beam=output_beam,
min_active_states=min_active_states,
max_active_states=max_active_states,
)
return lattice
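
For context, here is a minimal, self-contained sketch of how the new get_lattice can be exercised. It substitutes a plain CTC topology (k2.ctc_topo) for the HLG graph and random scores for real network output, so the result is meaningless; the beam settings are typical values (the docstring suggests a search_beam around 20) and are purely illustrative.

import k2
import torch

from icefall.decode import get_lattice

N, T, C = 2, 100, 50  # batch size, frames, output units (arbitrary toy sizes)
nnet_output = torch.randn(N, T, C).log_softmax(dim=-1)

# One supervision per sequence: (sequence_index, start_frame, duration),
# a 2-D CPU tensor of dtype torch.int32, as described in the docstring above.
supervision_segments = torch.tensor(
    [[0, 0, 100], [1, 0, 80]], dtype=torch.int32
)

# Stand-in decoding graph; a real setup would load the HLG produced by
# compile_HLG.py instead.
HLG = k2.ctc_topo(max_token=C - 1)

lattice = get_lattice(
    nnet_output=nnet_output,
    HLG=HLG,
    supervision_segments=supervision_segments,
    search_beam=20,
    output_beam=8,
    min_active_states=30,
    max_active_states=10000,
    subsampling_factor=1,
)
print(lattice.shape)  # one FSA per supervision segment
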
# TODO(fangjun): Use Kangwei's C++ implementation that also
# supports List[List[int]]
def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa: