Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-14 20:42:22 +00:00)
Fixes after refactoring.

commit e062c1b5cf
parent 8623983bb7
@@ -30,20 +30,14 @@ from conformer import Conformer

 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.decode import get_lattice
 from icefall.decode import (
-    one_best_decoding, # done
-    rescore_with_attention_decoder, # done
-    rescore_with_n_best_list, # done
-    rescore_with_whole_lattice, # done
-    nbest_oracle, # done
-)
-from icefall.decode2 import (
+    get_lattice,
     nbest_decoding,
-    nbest_oracle as nbest_oracle2,
-    rescore_with_n_best_list as rescore_with_n_best_list2,
-    rescore_with_whole_lattice as rescore_with_whole_lattice2,
-    rescore_with_attention_decoder as rescore_with_attention_decoder2,
+    nbest_oracle,
+    one_best_decoding,
+    rescore_with_attention_decoder,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
 )
 from icefall.lexicon import Lexicon
 from icefall.utils import (

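With this hunk, all decoding and rescoring helpers come from the single `icefall.decode` module; the temporary `icefall.decode2` module and its `*2` aliases are dropped. A minimal sketch of using the consolidated API (the helper name `lattice_to_word_hyps` is illustrative, not part of icefall; `lattice` is assumed to come from `get_lattice` and `word_table` from the `Lexicon`):

import k2

from icefall.decode import one_best_decoding
from icefall.utils import get_texts


def lattice_to_word_hyps(lattice: k2.Fsa, word_table: k2.SymbolTable) -> list:
    # Pick the single best path from the decoding lattice, then map the
    # word IDs on that path back to word strings via the symbol table.
    best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
    word_ids = get_texts(best_path)  # one List[int] per utterance
    return [[word_table[i] for i in ids] for ids in word_ids]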
@@ -250,29 +244,19 @@ def decode_one_batch(

         # Note: You can also pass rescored lattices to it.
         # We choose the HLG decoded lattice for speed reasons
         # as HLG decoding is faster and the oracle WER
-        # is slightly worse than that of rescored lattices.
-        if True:
-            # TODO: delete the `else` branch
-            best_path = nbest_oracle2(
-                lattice=lattice,
-                num_paths=params.num_paths,
-                ref_texts=supervisions["text"],
-                word_table=word_table,
-                lattice_score_scale=params.lattice_score_scale,
-                oov="<UNK>",
-            )
-            hyps = get_texts(best_path)
-            hyps = [[word_table[i] for i in ids] for ids in hyps]
-            key = f"oracle_{num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa
-            return {key: hyps}
-        else:
-            return nbest_oracle(
-                lattice=lattice,
-                num_paths=params.num_paths,
-                ref_texts=supervisions["text"],
-                word_table=word_table,
-                scale=params.lattice_score_scale,
-            )
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            word_table=word_table,
+            lattice_score_scale=params.lattice_score_scale,
+            oov="<UNK>",
+        )
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        key = f"oracle_{num_paths}_lattice_score_scale_{params.lattice_score_scale}" # noqa
+        return {key: hyps}

     if params.method in ["1best", "nbest"]:
         if params.method == "1best":

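The oracle branch now also takes `lattice_score_scale`. In icefall this scale is applied to the lattice scores before the n-best paths are sampled, so smaller values flatten the arc-score distribution and give a more diverse n-best list. A rough sketch of that idea (illustrative only, not the library code; the helper name is made up):

import k2


def sample_nbest_paths(lattice, num_paths, lattice_score_scale):
    # Temporarily scale the lattice scores: with a small scale the score
    # differences between competing arcs shrink, so k2.random_paths()
    # samples a more diverse set of paths.
    saved_scores = lattice.scores.clone()
    lattice.scores = saved_scores * lattice_score_scale
    paths = k2.random_paths(
        lattice, use_double_scores=True, num_paths=num_paths
    )
    lattice.scores = saved_scores  # restore the original scores
    return paths  # ragged tensor of arc indices, one sub-list per path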
@@ -304,65 +288,39 @@ def decode_one_batch(

     lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]

     if params.method == "nbest-rescoring":
-        if True:
-            # TODO: remove the "else" branch
-            best_path_dict = rescore_with_n_best_list2(
-                lattice=lattice,
-                G=G,
-                num_paths=params.num_paths,
-                lm_scale_list=lm_scale_list,
-                lattice_score_scale=params.lattice_score_scale,
-            )
-        else:
-            best_path_dict = rescore_with_n_best_list(
-                lattice=lattice,
-                G=G,
-                num_paths=params.num_paths,
-                lm_scale_list=lm_scale_list,
-                scale=params.lattice_score_scale,
-            )
+        best_path_dict = rescore_with_n_best_list(
+            lattice=lattice,
+            G=G,
+            num_paths=params.num_paths,
+            lm_scale_list=lm_scale_list,
+            lattice_score_scale=params.lattice_score_scale,
+        )
     elif params.method == "whole-lattice-rescoring":
-        if True:
-            # TODO: remove "else" branch
-            best_path_dict = rescore_with_whole_lattice2(
-                lattice=lattice,
-                G_with_epsilon_loops=G,
-                lm_scale_list=lm_scale_list,
-            )
-        else:
-            best_path_dict = rescore_with_whole_lattice(
-                lattice=lattice,
-                G_with_epsilon_loops=G,
-                lm_scale_list=lm_scale_list,
-            )
+        best_path_dict = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
+        )
     elif params.method == "attention-decoder":
         # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
         rescored_lattice = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=None,
         )
+        # TODO: pass `lattice` instead of `rescored_lattice` to
+        # `rescore_with_attention_decoder`

-        if True:
-            best_path_dict = rescore_with_attention_decoder2(
-                lattice=rescored_lattice,
-                num_paths=params.num_paths,
-                model=model,
-                memory=memory,
-                memory_key_padding_mask=memory_key_padding_mask,
-                sos_id=sos_id,
-                eos_id=eos_id,
-                lattice_score_scale=params.lattice_score_scale,
-            )
-        else:
-            best_path_dict = rescore_with_attention_decoder(
-                lattice=rescored_lattice,
-                num_paths=params.num_paths,
-                model=model,
-                memory=memory,
-                memory_key_padding_mask=memory_key_padding_mask,
-                sos_id=sos_id,
-                eos_id=eos_id,
-                scale=params.lattice_score_scale,
-            )
+        best_path_dict = rescore_with_attention_decoder(
+            lattice=rescored_lattice,
+            num_paths=params.num_paths,
+            model=model,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
+            lattice_score_scale=params.lattice_score_scale,
+        )
     else:
         assert False, f"Unsupported decoding method: {params.method}"

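All three rescoring branches return a dict that maps an LM-scale key to a best-path FSA. A sketch of how such a dict is typically turned into word hypotheses downstream (the helper name is illustrative; `word_table` is assumed to be the lexicon's word symbol table):

from typing import Dict, List

import k2

from icefall.utils import get_texts


def best_paths_to_hyps(
    best_path_dict: Dict[str, k2.Fsa], word_table: k2.SymbolTable
) -> Dict[str, List[List[str]]]:
    # One entry per LM scale: convert each best path to word IDs and then
    # to word strings, so WERs can be reported per scale.
    ans = dict()
    for lm_scale_str, best_path in best_path_dict.items():
        word_ids = get_texts(best_path)
        ans[lm_scale_str] = [[word_table[i] for i in ids] for ids in word_ids]
    return ans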
@@ -336,7 +336,7 @@ def main():

             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=params.sos_id,
             eos_id=params.eos_id,
-            scale=params.lattice_score_scale,
+            lattice_score_scale=params.lattice_score_scale,
             ngram_lm_scale=params.ngram_lm_scale,
             attention_scale=params.attention_decoder_scale,
         )

@@ -229,6 +229,7 @@ def decode_one_batch(

                 lattice=lattice,
                 num_paths=params.num_paths,
                 use_double_scores=params.use_double_scores,
+                lattice_score_scale=params.lattice_score_scale,
             )
             key = f"no_rescore-{params.num_paths}"
         hyps = get_texts(best_path)

@@ -247,10 +248,13 @@ def decode_one_batch(

             G=G,
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
+            lattice_score_scale=params.lattice_score_scale,
         )
     else:
         best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
         )

     ans = dict()

@@ -64,6 +64,68 @@ def _intersect_device(

     return k2.cat(ans)

+
+def get_lattice(
+    nnet_output: torch.Tensor,
+    HLG: k2.Fsa,
+    supervision_segments: torch.Tensor,
+    search_beam: float,
+    output_beam: float,
+    min_active_states: int,
+    max_active_states: int,
+    subsampling_factor: int = 1,
+) -> k2.Fsa:
+    """Get the decoding lattice from a decoding graph and neural
+    network output.
+
+    Args:
+      nnet_output:
+        It is the output of a neural model of shape `[N, T, C]`.
+      HLG:
+        An Fsa, the decoding graph. See also `compile_HLG.py`.
+      supervision_segments:
+        A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns.
+        Each row contains information for a supervision segment. Column 0
+        is the `sequence_index` indicating which sequence this segment
+        comes from; column 1 specifies the `start_frame` of this segment
+        within the sequence; column 2 contains the `duration` of this
+        segment.
+      search_beam:
+        Decoding beam, e.g. 20. Smaller is faster, larger is more exact
+        (less pruning). This is the default value; it may be modified by
+        `min_active_states` and `max_active_states`.
+      output_beam:
+        Beam to prune output, similar to lattice-beam in Kaldi. Relative
+        to best path of output.
+      min_active_states:
+        Minimum number of FSA states that are allowed to be active on any given
+        frame for any given intersection/composition task. This is advisory,
+        in that it will try not to have fewer than this number active.
+        Set it to zero if there is no constraint.
+      max_active_states:
+        Maximum number of FSA states that are allowed to be active on any given
+        frame for any given intersection/composition task. This is advisory,
+        in that it will try not to exceed that but may not always succeed.
+        You can use a very large number if no constraint is needed.
+      subsampling_factor:
+        The subsampling factor of the model.
+    Returns:
+      A lattice containing the decoding result.
+    """
+    dense_fsa_vec = k2.DenseFsaVec(
+        nnet_output, supervision_segments, allow_truncate=subsampling_factor - 1
+    )
+
+    lattice = k2.intersect_dense_pruned(
+        HLG,
+        dense_fsa_vec,
+        search_beam=search_beam,
+        output_beam=output_beam,
+        min_active_states=min_active_states,
+        max_active_states=max_active_states,
+    )
+
+    return lattice
+
+
 # TODO(fangjun): Use Kangwei's C++ implementation that also
 # supports List[List[int]]
 def levenshtein_graph(symbol_ids: List[int]) -> k2.Fsa:
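For reference, a hedged sketch of how a recipe calls the relocated `get_lattice`; the beam values are typical icefall defaults, and `nnet_output`, `HLG`, and `supervisions` are assumed to come from the model, the compiled decoding graph, and the dataloader respectively:

import torch

from icefall.decode import get_lattice

subsampling_factor = 4  # illustrative; must match the acoustic model
# Frame indices are divided by the subsampling factor because nnet_output
# is already subsampled relative to the input features.
supervision_segments = torch.stack(
    (
        supervisions["sequence_idx"],
        supervisions["start_frame"] // subsampling_factor,
        supervisions["num_frames"] // subsampling_factor,
    ),
    1,
).to(torch.int32)

lattice = get_lattice(
    nnet_output=nnet_output,  # (N, T, C) log-probs from the model
    HLG=HLG,                  # decoding graph built by compile_HLG.py
    supervision_segments=supervision_segments,
    search_beam=20,
    output_beam=8,
    min_active_states=30,
    max_active_states=10000,
    subsampling_factor=subsampling_factor,
)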