Update ctc_decode.py

jinzr 2023-10-17 16:40:41 +08:00
parent 9be6782c28
commit 76c311caf4


@@ -54,14 +54,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.decode import (
-    get_lattice,
-    nbest_decoding,
-    nbest_oracle,
-    one_best_decoding,
-    rescore_with_n_best_list,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -204,12 +197,9 @@ def get_decoding_params() -> AttributeDict:
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    HLG: Optional[k2.Fsa],
     H: Optional[k2.Fsa],
     bpe_model: Optional[spm.SentencePieceProcessor],
     batch: dict,
-    word_table: k2.SymbolTable,
-    G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -234,8 +224,6 @@ def decode_one_batch(
       model:
         The neural model.
-      HLG:
-        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
       H:
         The ctc topo. Used only when params.decoding_method is ctc-decoding.
       bpe_model:
@@ -254,10 +242,7 @@ def decode_one_batch(
       Return the decoding result. See above description for the format of
       the returned dict. Note: If it decodes to nothing, then return None.
     """
-    if HLG is not None:
-        device = HLG.device
-    else:
-        device = H.device
+    device = H.device
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
@@ -296,13 +281,8 @@ def decode_one_batch(
         1,
     ).to(torch.int32)
 
-    if H is None:
-        assert HLG is not None
-        decoding_graph = HLG
-    else:
-        assert HLG is None
-        assert bpe_model is not None
-        decoding_graph = H
+    assert bpe_model is not None
+    decoding_graph = H
 
     lattice = get_lattice(
         nnet_output=ctc_output,
@@ -333,79 +313,6 @@ def decode_one_batch(
         key = "ctc-decoding"
         return {key: hyps}
 
-    if params.decoding_method == "nbest-oracle":
-        # Note: You can also pass rescored lattices to it.
-        # We choose the HLG decoded lattice for speed reasons
-        # as HLG decoding is faster and the oracle WER
-        # is only slightly worse than that of rescored lattices.
-        best_path = nbest_oracle(
-            lattice=lattice,
-            num_paths=params.num_paths,
-            ref_texts=supervisions["text"],
-            word_table=word_table,
-            nbest_scale=params.nbest_scale,
-            oov="<UNK>",
-        )
-        hyps = get_texts(best_path)
-        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
-        return {key: hyps}
-
-    if params.decoding_method in ["1best", "nbest"]:
-        if params.decoding_method == "1best":
-            best_path = one_best_decoding(
-                lattice=lattice, use_double_scores=params.use_double_scores
-            )
-            key = "no_rescore"
-        else:
-            best_path = nbest_decoding(
-                lattice=lattice,
-                num_paths=params.num_paths,
-                use_double_scores=params.use_double_scores,
-                nbest_scale=params.nbest_scale,
-            )
-            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
-
-        hyps = get_texts(best_path)
-        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        return {key: hyps}
-
-    assert params.decoding_method in [
-        "nbest-rescoring",
-        "whole-lattice-rescoring",
-    ]
-
-    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
-    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
-    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
-
-    if params.decoding_method == "nbest-rescoring":
-        best_path_dict = rescore_with_n_best_list(
-            lattice=lattice,
-            G=G,
-            num_paths=params.num_paths,
-            lm_scale_list=lm_scale_list,
-            nbest_scale=params.nbest_scale,
-        )
-    elif params.decoding_method == "whole-lattice-rescoring":
-        best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice,
-            G_with_epsilon_loops=G,
-            lm_scale_list=lm_scale_list,
-        )
-    else:
-        assert False, f"Unsupported decoding method: {params.decoding_method}"
-
-    ans = dict()
-    if best_path_dict is not None:
-        for lm_scale_str, best_path in best_path_dict.items():
-            hyps = get_texts(best_path)
-            hyps = [[word_table[i] for i in ids] for ids in hyps]
-            ans[lm_scale_str] = hyps
-    else:
-        ans = None
-    return ans
-
 
 def decode_dataset(
     dl: torch.utils.data.DataLoader,
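For reference, the ctc-decoding path that decode_one_batch keeps reduces to: intersect the model's CTC output with the H topology via get_lattice, take the one-best path, and map the resulting BPE token IDs back to words with the sentencepiece model. Below is a minimal, self-contained sketch of that flow; the wrapper name ctc_decode_one_batch and the beam/active-state values are illustrative assumptions, not taken from this file.

# Minimal sketch of the retained ctc-decoding path (illustrative only; the
# wrapper name and the beam/active-state values are assumptions, not values
# read from this file, which takes them from params).
from typing import List

import k2
import sentencepiece as spm
import torch

from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import get_texts


def ctc_decode_one_batch(
    ctc_output: torch.Tensor,  # (N, T, vocab_size) CTC log-probs from the model
    supervision_segments: torch.Tensor,  # (num_utts, 3) int32: [seq_idx, start_frame, num_frames]
    H: k2.Fsa,  # CTC topology over BPE tokens
    bpe_model: spm.SentencePieceProcessor,
) -> List[List[str]]:
    # Intersect the CTC output with H to build a decoding lattice.
    lattice = get_lattice(
        nnet_output=ctc_output,
        decoding_graph=H,
        supervision_segments=supervision_segments,
        search_beam=20,  # illustrative values
        output_beam=8,
        min_active_states=30,
        max_active_states=10000,
    )
    # Take the best path through the lattice and map token IDs back to words.
    best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
    token_ids = get_texts(best_path)  # List[List[int]] of BPE token IDs
    hyps = bpe_model.decode(token_ids)  # one string per utterance
    return [s.split() for s in hyps]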