Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-14 12:32:20 +00:00)
Fixes after refactoring.

parent 8623983bb7
commit e062c1b5cf
@@ -30,20 +30,14 @@ from conformer import Conformer
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.decode import get_lattice
 from icefall.decode import (
-    one_best_decoding,  # done
-    rescore_with_attention_decoder,  # done
-    rescore_with_n_best_list,  # done
-    rescore_with_whole_lattice,  # done
-    nbest_oracle,  # done
-)
-from icefall.decode2 import (
     get_lattice,
     nbest_decoding,
-    nbest_oracle as nbest_oracle2,
-    rescore_with_n_best_list as rescore_with_n_best_list2,
-    rescore_with_whole_lattice as rescore_with_whole_lattice2,
-    rescore_with_attention_decoder as rescore_with_attention_decoder2,
+    nbest_oracle,
+    one_best_decoding,
+    rescore_with_attention_decoder,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
 )
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -250,29 +244,19 @@ def decode_one_batch(
         # Note: You can also pass rescored lattices to it.
         # We choose the HLG decoded lattice for speed reasons
         # as HLG decoding is faster and the oracle WER
-        # is slightly worse than that of rescored lattices.
-        if True:
-            # TODO: delete the `else` branch
-            best_path = nbest_oracle2(
-                lattice=lattice,
-                num_paths=params.num_paths,
-                ref_texts=supervisions["text"],
-                word_table=word_table,
-                lattice_score_scale=params.lattice_score_scale,
-                oov="<UNK>",
-            )
-            hyps = get_texts(best_path)
-            hyps = [[word_table[i] for i in ids] for ids in hyps]
-            key = f"oracle_{num_paths}_lattice_score_scale_{params.lattice_score_scale}"  # noqa
-            return {key: hyps}
-        else:
-            return nbest_oracle(
-                lattice=lattice,
-                num_paths=params.num_paths,
-                ref_texts=supervisions["text"],
-                word_table=word_table,
-                scale=params.lattice_score_scale,
-            )
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            word_table=word_table,
+            lattice_score_scale=params.lattice_score_scale,
+            oov="<UNK>",
+        )
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        key = f"oracle_{num_paths}_lattice_score_scale_{params.lattice_score_scale}"  # noqa
+        return {key: hyps}

     if params.method in ["1best", "nbest"]:
         if params.method == "1best":
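The comments above motivate the nbest-oracle path: among the sampled paths, nbest_oracle picks the one closest to the reference, so the word error rate of the returned hypotheses is the oracle WER of the lattice (restricted to num_paths paths). As a point of reference, here is a minimal, self-contained sketch of how such hypotheses (lists of words, as produced by the word_table lookup above) could be scored against the reference transcripts with a standard word-level edit distance. It is illustrative only and is not the scoring code icefall itself uses.

# Illustrative only: a plain dynamic-programming edit distance, not icefall's
# own scoring utilities.
from typing import List


def edit_distance(ref: List[str], hyp: List[str]) -> int:
    """Word-level Levenshtein distance between a reference and a hypothesis."""
    m, n = len(ref), len(hyp)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,  # deletion
                dp[i][j - 1] + 1,  # insertion
                dp[i - 1][j - 1] + cost,  # substitution
            )
    return dp[m][n]


def oracle_wer(ref_texts: List[str], hyps: List[List[str]]) -> float:
    """WER of the oracle hypotheses against the reference transcripts."""
    errors = sum(
        edit_distance(ref.split(), hyp) for ref, hyp in zip(ref_texts, hyps)
    )
    total = sum(len(ref.split()) for ref in ref_texts)
    return errors / total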
@@ -304,65 +288,39 @@ def decode_one_batch(
         lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]

     if params.method == "nbest-rescoring":
-        if True:
-            # TODO: remove the "else" branch
-            best_path_dict = rescore_with_n_best_list2(
-                lattice=lattice,
-                G=G,
-                num_paths=params.num_paths,
-                lm_scale_list=lm_scale_list,
-                lattice_score_scale=params.lattice_score_scale,
-            )
-        else:
-            best_path_dict = rescore_with_n_best_list(
-                lattice=lattice,
-                G=G,
-                num_paths=params.num_paths,
-                lm_scale_list=lm_scale_list,
-                scale=params.lattice_score_scale,
-            )
+        best_path_dict = rescore_with_n_best_list(
+            lattice=lattice,
+            G=G,
+            num_paths=params.num_paths,
+            lm_scale_list=lm_scale_list,
+            lattice_score_scale=params.lattice_score_scale,
+        )
     elif params.method == "whole-lattice-rescoring":
-        if True:
-            # TODO: remove "else" branch
-            best_path_dict = rescore_with_whole_lattice2(
-                lattice=lattice,
-                G_with_epsilon_loops=G,
-                lm_scale_list=lm_scale_list,
-            )
-        else:
-            best_path_dict = rescore_with_whole_lattice(
-                lattice=lattice,
-                G_with_epsilon_loops=G,
-                lm_scale_list=lm_scale_list,
-            )
+        best_path_dict = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
+        )
     elif params.method == "attention-decoder":
         # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
         rescored_lattice = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=None,
         )
         # TODO: pass `lattice` instead of `rescored_lattice` to
         # `rescore_with_attention_decoder`

-        if True:
-            best_path_dict = rescore_with_attention_decoder2(
-                lattice=rescored_lattice,
-                num_paths=params.num_paths,
-                model=model,
-                memory=memory,
-                memory_key_padding_mask=memory_key_padding_mask,
-                sos_id=sos_id,
-                eos_id=eos_id,
-                lattice_score_scale=params.lattice_score_scale,
-            )
-        else:
-            best_path_dict = rescore_with_attention_decoder(
-                lattice=rescored_lattice,
-                num_paths=params.num_paths,
-                model=model,
-                memory=memory,
-                memory_key_padding_mask=memory_key_padding_mask,
-                sos_id=sos_id,
-                eos_id=eos_id,
-                scale=params.lattice_score_scale,
-            )
+        best_path_dict = rescore_with_attention_decoder(
+            lattice=rescored_lattice,
+            num_paths=params.num_paths,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
+            lattice_score_scale=params.lattice_score_scale,
+        )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
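In the attention-decoder branch above, the lattice (built with a 3-gram LM) is first rescored with the 4-gram LM G via rescore_with_whole_lattice; passing lm_scale_list=None makes it return the rescored lattice itself, which is then fed to rescore_with_attention_decoder. The sketch below shows one way the 4-gram G that is passed as G_with_epsilon_loops might be prepared; the file path and the exact sequence of steps are assumptions based on common k2 usage, not something this commit specifies.

# Hypothetical preparation of a 4-gram LM for whole-lattice rescoring.
# The path below and the attribute handling are illustrative assumptions.
import torch
import k2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

d = torch.load("data/lm/G_4_gram.pt")  # assumed: an FSA saved as a dict
G = k2.Fsa.from_dict(d).to(device)

# G is composed with the whole lattice during rescoring, so epsilon
# self-loops are added and the arcs are sorted first.
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)

# Keep the LM scores in a separate attribute so that different values
# from lm_scale_list can be applied later.
G.lm_scores = G.scores.clone()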
@@ -336,7 +336,7 @@ def main():
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=params.sos_id,
             eos_id=params.eos_id,
-            scale=params.lattice_score_scale,
+            lattice_score_scale=params.lattice_score_scale,
             ngram_lm_scale=params.ngram_lm_scale,
             attention_scale=params.attention_decoder_scale,
         )
@@ -229,6 +229,7 @@ def decode_one_batch(
             lattice=lattice,
             num_paths=params.num_paths,
             use_double_scores=params.use_double_scores,
+            lattice_score_scale=params.lattice_score_scale,
         )
         key = f"no_rescore-{params.num_paths}"
         hyps = get_texts(best_path)
@@ -247,10 +248,13 @@ def decode_one_batch(
             G=G,
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
+            lattice_score_scale=params.lattice_score_scale,
         )
     else:
         best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
         )

     ans = dict()
@@ -64,6 +64,68 @@ def _intersect_device(
     return k2.cat(ans)


+def get_lattice(
+    nnet_output: torch.Tensor,
+    HLG: k2.Fsa,
+    supervision_segments: torch.Tensor,
+    search_beam: float,
+    output_beam: float,
+    min_active_states: int,
+    max_active_states: int,
+    subsampling_factor: int = 1,
+) -> k2.Fsa:
+    """Get the decoding lattice from a decoding graph and neural
+    network output.
+    Args:
+      nnet_output:
+        It is the output of a neural model of shape `[N, T, C]`.
+      HLG:
+        An Fsa, the decoding graph. See also `compile_HLG.py`.
+      supervision_segments:
+        A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns.
+        Each row contains information for a supervision segment. Column 0
+        is the `sequence_index` indicating which sequence this segment
+        comes from; column 1 specifies the `start_frame` of this segment
+        within the sequence; column 2 contains the `duration` of this
+        segment.
+      search_beam:
+        Decoding beam, e.g. 20. Smaller is faster, larger is more exact
+        (less pruning). This is the default value; it may be modified by
+        `min_active_states` and `max_active_states`.
+      output_beam:
+        Beam to prune output, similar to lattice-beam in Kaldi. Relative
+        to best path of output.
+      min_active_states:
+        Minimum number of FSA states that are allowed to be active on any given
+        frame for any given intersection/composition task. This is advisory,
+        in that it will try not to have fewer than this number active.
+        Set it to zero if there is no constraint.
+      max_active_states:
+        Maximum number of FSA states that are allowed to be active on any given
+        frame for any given intersection/composition task. This is advisory,
+        in that it will try not to exceed that but may not always succeed.
+        You can use a very large number if no constraint is needed.
+      subsampling_factor:
+        The subsampling factor of the model.
+    Returns:
+      A lattice containing the decoding result.
+    """
+    dense_fsa_vec = k2.DenseFsaVec(
+        nnet_output, supervision_segments, allow_truncate=subsampling_factor - 1
+    )
+
+    lattice = k2.intersect_dense_pruned(
+        HLG,
+        dense_fsa_vec,
+        search_beam=search_beam,
+        output_beam=output_beam,
+        min_active_states=min_active_states,
+        max_active_states=max_active_states,
+    )
+
+    return lattice
+
+
 # TODO(fangjun): Use Kangwei's C++ implementation that also
 # supports List[List[int]]
 def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
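The docstring above describes the inputs that get_lattice expects. Below is a minimal usage sketch; the tensor shapes, beam values, subsampling factor, and HLG path are illustrative assumptions rather than values taken from this commit.

# Hypothetical call of get_lattice(); shapes, paths and beam sizes are examples.
import torch
import k2
from icefall.decode import get_lattice

# Log-probabilities from the acoustic model, shape [N, T, C].
nnet_output = torch.randn(2, 100, 500).log_softmax(dim=-1)

# One row per supervision: (sequence_index, start_frame, duration).
# A 2-D CPU tensor of dtype torch.int32, ordered here by decreasing duration
# as the recipes do.
supervision_segments = torch.tensor(
    [[0, 0, 100], [1, 0, 80]], dtype=torch.int32
)

# Assumed to be a decoding graph produced by compile_HLG.py.
HLG = k2.Fsa.from_dict(torch.load("data/lang_bpe/HLG.pt"))

lattice = get_lattice(
    nnet_output=nnet_output,
    HLG=HLG,
    supervision_segments=supervision_segments,
    search_beam=20,
    output_beam=8,
    min_active_states=30,
    max_active_states=10000,
    subsampling_factor=4,
)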