mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
removed extra decoding_methods and params in ctc_decode
This commit is contained in:
parent
bbc163901a
commit
889d5e5dbb
@ -24,8 +24,8 @@ Usage:
|
||||
|
||||
(1) ctc-greedy-search (with cr-ctc)
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 50 \
|
||||
--avg 24 \
|
||||
--epoch 60 \
|
||||
--avg 28 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-cr-ctc 1 \
|
||||
--use-ctc 1 \
|
||||
@ -47,40 +47,18 @@ import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AishellAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from lhotse.cut import Cut
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.context_graph import ContextGraph, ContextState
|
||||
from icefall.decode import (
|
||||
ctc_greedy_search,
|
||||
ctc_prefix_beam_search,
|
||||
ctc_prefix_beam_search_attention_decoder_rescoring,
|
||||
ctc_prefix_beam_search_shallow_fussion,
|
||||
get_lattice,
|
||||
nbest_decoding,
|
||||
nbest_oracle,
|
||||
one_best_decoding,
|
||||
rescore_with_attention_decoder_no_ngram,
|
||||
rescore_with_attention_decoder_with_ngram,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
@ -162,69 +140,11 @@ def get_parser():
|
||||
- (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece
|
||||
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
|
||||
It needs neither a lexicon nor an n-gram LM.
|
||||
(2) ctc-prefix-beam-search. Extract n paths with the given beam, the best
|
||||
path of the n paths is the decoding result.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
fast_beam_search, fast_beam_search_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ilme-scale",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_LG.
|
||||
It specifies the scale for the internal language model estimation.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search, fast_beam_search_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search, fast_beam_search_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
@ -232,42 +152,6 @@ def get_parser():
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--blank-penalty",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="""
|
||||
The penalty applied on blank symbol during decoding.
|
||||
Note: It is a positive value that would be applied to logits like
|
||||
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||
[batch_size, vocab] and blank id is 0).
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -276,9 +160,7 @@ def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
lexicon: Lexicon,
|
||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||
batch: dict,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, Tuple[List[List[str]], List[List[Tuple[float, float]]]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
@ -299,10 +181,6 @@ def decode_one_batch(
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
@ -340,11 +218,16 @@ def decode_one_batch(
|
||||
hyp_tokens = []
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "ctc-greedy-search" and params.max_sym_per_frame == 1:
|
||||
if params.decoding_method == "ctc-greedy-search":
|
||||
hyp_tokens = ctc_greedy_search(
|
||||
ctc_output=ctc_output,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
elif params.decoding_method == "ctc-prefix-beam-search":
|
||||
hyp_tokens = ctc_prefix_beam_search(
|
||||
ctc_output=ctc_output,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
@ -356,20 +239,10 @@ def decode_one_batch(
|
||||
key = f"blank_penalty_{params.blank_penalty}"
|
||||
if params.decoding_method == "ctc-greedy-search":
|
||||
return {"ctc-greedy-search_" + key: hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key += f"_beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ilme_scale_{params.ilme_scale}"
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
elif params.decoding_method == "ctc-prefix-beam-search":
|
||||
return {"ctc-prefix-beam-search_" + key: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}_" + key: hyps}
|
||||
assert False, f"Unsupported decoding method: {params.decoding_method}"
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
@ -377,8 +250,6 @@ def decode_dataset(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
lexicon: Lexicon,
|
||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
@ -389,10 +260,6 @@ def decode_dataset(
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
@ -406,10 +273,7 @@ def decode_dataset(
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "ctc-greedy-search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
@ -421,8 +285,6 @@ def decode_dataset(
|
||||
params=params,
|
||||
model=model,
|
||||
lexicon=lexicon,
|
||||
graph_compiler=graph_compiler,
|
||||
decoding_graph=decoding_graph,
|
||||
batch=batch,
|
||||
)
|
||||
for name, hyps in hyps_dict.items():
|
||||
@ -504,7 +366,8 @@ def main():
|
||||
|
||||
assert params.decoding_method in (
|
||||
"ctc-greedy-search",
|
||||
) # only support ctc-greedy-search
|
||||
"ctc-prefix-beam-search",
|
||||
) # support ctc-greedy-search and ctc-prefix-beam-search
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
@ -522,22 +385,9 @@ def main():
|
||||
params.suffix += f"-chunk-{params.chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context_frames}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"_ilme_scale_{params.ilme_scale}"
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
params.suffix += f"-blank-penalty-{params.blank_penalty}"
|
||||
if "prefix-beam-search" in params.decoding_method:
|
||||
params.suffix += f"_beam-{params.beam}"
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
@ -551,18 +401,12 @@ def main():
|
||||
params.device = device
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
logging.info(params)
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
|
||||
params.blank_id = lexicon.token_table["<blk>"]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
)
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
@ -648,20 +492,6 @@ def main():
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if "LG" in params.decoding_method:
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
@ -694,8 +524,6 @@ def main():
|
||||
params=params,
|
||||
model=model,
|
||||
lexicon=lexicon,
|
||||
graph_compiler=graph_compiler,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
|
Loading…
x
Reference in New Issue
Block a user