removed extra decoding_methods and params in ctc_decode

hhzzff 2025-07-07 11:28:48 +08:00
parent bbc163901a
commit 889d5e5dbb


@@ -24,8 +24,8 @@ Usage:
 (1) ctc-greedy-search (with cr-ctc)
 ./zipformer/ctc_decode.py \
-    --epoch 50 \
-    --avg 24 \
+    --epoch 60 \
+    --avg 28 \
     --exp-dir ./zipformer/exp \
     --use-cr-ctc 1 \
     --use-ctc 1 \
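
For comparison, an invocation of the newly supported method would follow the same pattern (a sketch, not taken from the file; flag values are illustrative, and --max-duration is assumed to come from the shared asr_datamodule):

./zipformer/ctc_decode.py \
    --epoch 60 \
    --avg 28 \
    --exp-dir ./zipformer/exp \
    --use-ctc 1 \
    --max-duration 600 \
    --decoding-method ctc-prefix-beam-search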
@@ -47,40 +47,18 @@ import k2
 import torch
 import torch.nn as nn
 from asr_datamodule import AishellAsrDataModule
-from beam_search import (
-    beam_search,
-    fast_beam_search_nbest,
-    fast_beam_search_nbest_LG,
-    fast_beam_search_nbest_oracle,
-    fast_beam_search_one_best,
-    greedy_search,
-    greedy_search_batch,
-    modified_beam_search,
-)
 from lhotse.cut import Cut
 from train import add_model_arguments, get_model, get_params

-from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.context_graph import ContextGraph, ContextState
 from icefall.decode import (
     ctc_greedy_search,
     ctc_prefix_beam_search,
-    ctc_prefix_beam_search_attention_decoder_rescoring,
-    ctc_prefix_beam_search_shallow_fussion,
-    get_lattice,
-    nbest_decoding,
-    nbest_oracle,
-    one_best_decoding,
-    rescore_with_attention_decoder_no_ngram,
-    rescore_with_attention_decoder_with_ngram,
-    rescore_with_n_best_list,
-    rescore_with_whole_lattice,
 )
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -162,69 +140,11 @@ def get_parser():
         - (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece
           model, i.e., lang_dir/bpe.model, to convert word pieces to words.
           It needs neither a lexicon nor an n-gram LM.
+        - (2) ctc-prefix-beam-search. Extract n paths with the given beam, the best
+          path of the n paths is the decoding result.
         """,
     )

-    parser.add_argument(
-        "--beam-size",
-        type=int,
-        default=4,
-        help="""An integer indicating how many candidates we will keep for each
-        frame. Used only when --decoding-method is beam_search or
-        modified_beam_search.""",
-    )
-
-    parser.add_argument(
-        "--beam",
-        type=float,
-        default=20.0,
-        help="""A floating point value to calculate the cutoff score during beam
-        search (i.e., `cutoff = max-score - beam`), which is the same as the
-        `beam` in Kaldi.
-        Used only when --decoding-method is fast_beam_search,
-        fast_beam_search, fast_beam_search_LG,
-        and fast_beam_search_nbest_oracle
-        """,
-    )
-
-    parser.add_argument(
-        "--ngram-lm-scale",
-        type=float,
-        default=0.01,
-        help="""
-        Used only when --decoding_method is fast_beam_search_LG.
-        It specifies the scale for n-gram LM scores.
-        """,
-    )
-
-    parser.add_argument(
-        "--ilme-scale",
-        type=float,
-        default=0.2,
-        help="""
-        Used only when --decoding_method is fast_beam_search_LG.
-        It specifies the scale for the internal language model estimation.
-        """,
-    )
-
-    parser.add_argument(
-        "--max-contexts",
-        type=int,
-        default=8,
-        help="""Used only when --decoding-method is
-        fast_beam_search, fast_beam_search, fast_beam_search_LG,
-        and fast_beam_search_nbest_oracle""",
-    )
-
-    parser.add_argument(
-        "--max-states",
-        type=int,
-        default=64,
-        help="""Used only when --decoding-method is
-        fast_beam_search, fast_beam_search, fast_beam_search_LG,
-        and fast_beam_search_nbest_oracle""",
-    )
-
     parser.add_argument(
         "--context-size",
         type=int,
@@ -232,42 +152,6 @@ def get_parser():
         help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )

-    parser.add_argument(
-        "--max-sym-per-frame",
-        type=int,
-        default=1,
-        help="""Maximum number of symbols per frame.
-        Used only when --decoding_method is greedy_search""",
-    )
-
-    parser.add_argument(
-        "--num-paths",
-        type=int,
-        default=200,
-        help="""Number of paths for nbest decoding.
-        Used only when the decoding method is fast_beam_search_nbest_oracle""",
-    )
-
-    parser.add_argument(
-        "--nbest-scale",
-        type=float,
-        default=0.5,
-        help="""Scale applied to lattice scores when computing nbest paths.
-        Used only when the decoding method is and fast_beam_search_nbest_oracle""",
-    )
-
-    parser.add_argument(
-        "--blank-penalty",
-        type=float,
-        default=0.0,
-        help="""
-        The penalty applied on blank symbol during decoding.
-        Note: It is a positive value that would be applied to logits like
-        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
-        [batch_size, vocab] and blank id is 0).
-        """,
-    )
-
     add_model_arguments(parser)

     return parser
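
As the removed --blank-penalty help text above notes, the penalty is a positive value subtracted from the blank logits. A minimal sketch of that adjustment (shapes and the penalty value are illustrative; blank id 0 as in the help text):

import torch

# ctc_output: (batch, time, vocab) log-probabilities with blank id 0.
ctc_output = torch.randn(2, 10, 500).log_softmax(dim=-1)
blank_penalty = 0.5  # illustrative; the removed flag defaulted to 0.0

# Subtracting a positive penalty from the blank column discourages blank
# emissions, trading insertion errors against deletion errors.
ctc_output[:, :, 0] -= blank_penalty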
@@ -276,9 +160,7 @@ def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
     lexicon: Lexicon,
-    graph_compiler: CharCtcTrainingGraphCompiler,
     batch: dict,
-    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, Tuple[List[List[str]], List[List[Tuple[float, float]]]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -299,10 +181,6 @@ def decode_one_batch(
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
-      decoding_graph:
-        The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
-        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
-        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
@@ -340,11 +218,16 @@ def decode_one_batch(
     hyp_tokens = []
     hyps = []
-    if params.decoding_method == "ctc-greedy-search" and params.max_sym_per_frame == 1:
+    if params.decoding_method == "ctc-greedy-search":
         hyp_tokens = ctc_greedy_search(
             ctc_output=ctc_output,
             encoder_out_lens=encoder_out_lens,
         )
+    elif params.decoding_method == "ctc-prefix-beam-search":
+        hyp_tokens = ctc_prefix_beam_search(
+            ctc_output=ctc_output,
+            encoder_out_lens=encoder_out_lens,
+        )
     else:
         raise ValueError(
             f"Unsupported decoding method: {params.decoding_method}"
         )
@@ -356,20 +239,10 @@ def decode_one_batch(
     key = f"blank_penalty_{params.blank_penalty}"
     if params.decoding_method == "ctc-greedy-search":
         return {"ctc-greedy-search_" + key: hyps}
-    elif "fast_beam_search" in params.decoding_method:
-        key += f"_beam_{params.beam}_"
-        key += f"max_contexts_{params.max_contexts}_"
-        key += f"max_states_{params.max_states}"
-        if "nbest" in params.decoding_method:
-            key += f"_num_paths_{params.num_paths}_"
-            key += f"nbest_scale_{params.nbest_scale}"
-        if "LG" in params.decoding_method:
-            key += f"_ilme_scale_{params.ilme_scale}"
-        key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
-        return {key: hyps}
+    elif params.decoding_method == "ctc-prefix-beam-search":
+        return {"ctc-prefix-beam-search_" + key: hyps}
     else:
-        return {f"beam_size_{params.beam_size}_" + key: hyps}
+        assert False, f"Unsupported decoding method: {params.decoding_method}"


 def decode_dataset(
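
Under the simplified key scheme above, the returned dict therefore looks like the following (the blank-penalty value is illustrative):

{"ctc-greedy-search_blank_penalty_0.0": hyps}
# or
{"ctc-prefix-beam-search_blank_penalty_0.0": hyps}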
@@ -377,8 +250,6 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     lexicon: Lexicon,
-    graph_compiler: CharCtcTrainingGraphCompiler,
-    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
@@ -389,10 +260,6 @@ def decode_dataset(
         It is returned by :func:`get_params`.
       model:
         The neural model.
-      decoding_graph:
-        The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
-        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
-        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
@@ -406,10 +273,7 @@ def decode_dataset(
     except TypeError:
         num_batches = "?"

-    if params.decoding_method == "ctc-greedy-search":
-        log_interval = 50
-    else:
-        log_interval = 20
+    log_interval = 20

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -421,8 +285,6 @@ def decode_dataset(
             params=params,
             model=model,
             lexicon=lexicon,
-            graph_compiler=graph_compiler,
-            decoding_graph=decoding_graph,
             batch=batch,
         )

         for name, hyps in hyps_dict.items():
@@ -504,7 +366,8 @@ def main():
     assert params.decoding_method in (
         "ctc-greedy-search",
-    )  # only support ctc-greedy-search
+        "ctc-prefix-beam-search",
+    )  # support ctc-greedy-search and ctc-prefix-beam-search

     params.res_dir = params.exp_dir / params.decoding_method

     if params.iter > 0:
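
For reference, ctc-prefix-beam-search keeps the `beam` most probable label prefixes per frame, tracking for each prefix the probability of paths ending in blank vs. in its last symbol. A textbook probability-space sketch for one utterance (not icefall's ctc_prefix_beam_search, which operates on batched log-probs):

from collections import defaultdict

def prefix_beam_search(probs, beam: int = 4, blank: int = 0):
    # probs: (T, vocab) per-frame posteriors for a single utterance.
    # Each prefix maps to (p_blank, p_non_blank).
    beams = {(): (1.0, 0.0)}
    for frame in probs:
        nxt = defaultdict(lambda: (0.0, 0.0))
        for prefix, (p_b, p_nb) in beams.items():
            for s, p in enumerate(frame):
                if s == blank:
                    # Blank extends every path without changing the prefix.
                    b, nb = nxt[prefix]
                    nxt[prefix] = (b + (p_b + p_nb) * p, nb)
                elif prefix and s == prefix[-1]:
                    # Repeated symbol: same prefix only via paths that
                    # already ended in that symbol...
                    b, nb = nxt[prefix]
                    nxt[prefix] = (b, nb + p_nb * p)
                    # ...while paths ending in blank start a new copy.
                    ext = prefix + (s,)
                    b, nb = nxt[ext]
                    nxt[ext] = (b, nb + p_b * p)
                else:
                    ext = prefix + (s,)
                    b, nb = nxt[ext]
                    nxt[ext] = (b, nb + (p_b + p_nb) * p)
        # Prune to the `beam` most probable prefixes.
        beams = dict(
            sorted(nxt.items(), key=lambda kv: sum(kv[1]), reverse=True)[:beam]
        )
    return list(max(beams.items(), key=lambda kv: sum(kv[1]))[0])

For example, prefix_beam_search([[0.6, 0.4], [0.5, 0.5]], beam=2) returns [1]: the prefix (1,) accumulates probability 0.7 against 0.3 for the empty prefix.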
@@ -522,22 +385,9 @@ def main():
         params.suffix += f"-chunk-{params.chunk_size}"
         params.suffix += f"-left-context-{params.left_context_frames}"

-    if "fast_beam_search" in params.decoding_method:
-        params.suffix += f"-beam-{params.beam}"
-        params.suffix += f"-max-contexts-{params.max_contexts}"
-        params.suffix += f"-max-states-{params.max_states}"
-        if "nbest" in params.decoding_method:
-            params.suffix += f"-nbest-scale-{params.nbest_scale}"
-            params.suffix += f"-num-paths-{params.num_paths}"
-        if "LG" in params.decoding_method:
-            params.suffix += f"_ilme_scale_{params.ilme_scale}"
-            params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
-    elif "beam_search" in params.decoding_method:
-        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
-    else:
-        params.suffix += f"-context-{params.context_size}"
-        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
-        params.suffix += f"-blank-penalty-{params.blank_penalty}"
+    if "prefix-beam-search" in params.decoding_method:
+        params.suffix += f"_beam-{params.beam}"
+    params.suffix += f"-context-{params.context_size}"

     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
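
With this change, a results-file suffix for prefix beam search would look roughly like the following (values are hypothetical; the leading epoch/avg part is built outside this hunk):

epoch-60-avg-28_beam-20.0-context-2-use-averaged-model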
@@ -551,18 +401,12 @@ def main():
     params.device = device

     logging.info(f"Device: {device}")
-    logging.info(params)

     lexicon = Lexicon(params.lang_dir)
     params.blank_id = lexicon.token_table["<blk>"]
     params.vocab_size = max(lexicon.tokens) + 1

-    graph_compiler = CharCtcTrainingGraphCompiler(
-        lexicon=lexicon,
-        device=device,
-    )
+    logging.info(params)

     logging.info("About to create model")
@@ -648,20 +492,6 @@ def main():
     model.to(device)
     model.eval()

-    if "fast_beam_search" in params.decoding_method:
-        if "LG" in params.decoding_method:
-            lexicon = Lexicon(params.lang_dir)
-            lg_filename = params.lang_dir / "LG.pt"
-            logging.info(f"Loading {lg_filename}")
-            decoding_graph = k2.Fsa.from_dict(
-                torch.load(lg_filename, map_location=device)
-            )
-            decoding_graph.scores *= params.ngram_lm_scale
-        else:
-            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
-    else:
-        decoding_graph = None
-
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -694,8 +524,6 @@ def main():
             params=params,
             model=model,
             lexicon=lexicon,
-            graph_compiler=graph_compiler,
-            decoding_graph=decoding_graph,
         )

         save_results(