mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
removed extra decoding_methods and params in ctc_decode
This commit is contained in:
parent
bbc163901a
commit
889d5e5dbb
@ -24,8 +24,8 @@ Usage:
|
|||||||
|
|
||||||
(1) ctc-greedy-search (with cr-ctc)
|
(1) ctc-greedy-search (with cr-ctc)
|
||||||
./zipformer/ctc_decode.py \
|
./zipformer/ctc_decode.py \
|
||||||
--epoch 50 \
|
--epoch 60 \
|
||||||
--avg 24 \
|
--avg 28 \
|
||||||
--exp-dir ./zipformer/exp \
|
--exp-dir ./zipformer/exp \
|
||||||
--use-cr-ctc 1 \
|
--use-cr-ctc 1 \
|
||||||
--use-ctc 1 \
|
--use-ctc 1 \
|
||||||
@ -47,40 +47,18 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import AishellAsrDataModule
|
from asr_datamodule import AishellAsrDataModule
|
||||||
from beam_search import (
|
|
||||||
beam_search,
|
|
||||||
fast_beam_search_nbest,
|
|
||||||
fast_beam_search_nbest_LG,
|
|
||||||
fast_beam_search_nbest_oracle,
|
|
||||||
fast_beam_search_one_best,
|
|
||||||
greedy_search,
|
|
||||||
greedy_search_batch,
|
|
||||||
modified_beam_search,
|
|
||||||
)
|
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from train import add_model_arguments, get_model, get_params
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.context_graph import ContextGraph, ContextState
|
|
||||||
from icefall.decode import (
|
from icefall.decode import (
|
||||||
ctc_greedy_search,
|
ctc_greedy_search,
|
||||||
ctc_prefix_beam_search,
|
ctc_prefix_beam_search,
|
||||||
ctc_prefix_beam_search_attention_decoder_rescoring,
|
|
||||||
ctc_prefix_beam_search_shallow_fussion,
|
|
||||||
get_lattice,
|
|
||||||
nbest_decoding,
|
|
||||||
nbest_oracle,
|
|
||||||
one_best_decoding,
|
|
||||||
rescore_with_attention_decoder_no_ngram,
|
|
||||||
rescore_with_attention_decoder_with_ngram,
|
|
||||||
rescore_with_n_best_list,
|
|
||||||
rescore_with_whole_lattice,
|
|
||||||
)
|
)
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
@ -162,69 +140,11 @@ def get_parser():
|
|||||||
- (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece
|
- (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece
|
||||||
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
|
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
|
||||||
It needs neither a lexicon nor an n-gram LM.
|
It needs neither a lexicon nor an n-gram LM.
|
||||||
|
(2) ctc-prefix-beam-search. Extract n paths with the given beam, the best
|
||||||
|
path of the n paths is the decoding result.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--beam-size",
|
|
||||||
type=int,
|
|
||||||
default=4,
|
|
||||||
help="""An integer indicating how many candidates we will keep for each
|
|
||||||
frame. Used only when --decoding-method is beam_search or
|
|
||||||
modified_beam_search.""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--beam",
|
|
||||||
type=float,
|
|
||||||
default=20.0,
|
|
||||||
help="""A floating point value to calculate the cutoff score during beam
|
|
||||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
|
||||||
`beam` in Kaldi.
|
|
||||||
Used only when --decoding-method is fast_beam_search,
|
|
||||||
fast_beam_search, fast_beam_search_LG,
|
|
||||||
and fast_beam_search_nbest_oracle
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--ngram-lm-scale",
|
|
||||||
type=float,
|
|
||||||
default=0.01,
|
|
||||||
help="""
|
|
||||||
Used only when --decoding_method is fast_beam_search_LG.
|
|
||||||
It specifies the scale for n-gram LM scores.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--ilme-scale",
|
|
||||||
type=float,
|
|
||||||
default=0.2,
|
|
||||||
help="""
|
|
||||||
Used only when --decoding_method is fast_beam_search_LG.
|
|
||||||
It specifies the scale for the internal language model estimation.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-contexts",
|
|
||||||
type=int,
|
|
||||||
default=8,
|
|
||||||
help="""Used only when --decoding-method is
|
|
||||||
fast_beam_search, fast_beam_search, fast_beam_search_LG,
|
|
||||||
and fast_beam_search_nbest_oracle""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-states",
|
|
||||||
type=int,
|
|
||||||
default=64,
|
|
||||||
help="""Used only when --decoding-method is
|
|
||||||
fast_beam_search, fast_beam_search, fast_beam_search_LG,
|
|
||||||
and fast_beam_search_nbest_oracle""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--context-size",
|
"--context-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -232,42 +152,6 @@ def get_parser():
|
|||||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-sym-per-frame",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="""Maximum number of symbols per frame.
|
|
||||||
Used only when --decoding_method is greedy_search""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--num-paths",
|
|
||||||
type=int,
|
|
||||||
default=200,
|
|
||||||
help="""Number of paths for nbest decoding.
|
|
||||||
Used only when the decoding method is fast_beam_search_nbest_oracle""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--nbest-scale",
|
|
||||||
type=float,
|
|
||||||
default=0.5,
|
|
||||||
help="""Scale applied to lattice scores when computing nbest paths.
|
|
||||||
Used only when the decoding method is and fast_beam_search_nbest_oracle""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--blank-penalty",
|
|
||||||
type=float,
|
|
||||||
default=0.0,
|
|
||||||
help="""
|
|
||||||
The penalty applied on blank symbol during decoding.
|
|
||||||
Note: It is a positive value that would be applied to logits like
|
|
||||||
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
|
||||||
[batch_size, vocab] and blank id is 0).
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -276,9 +160,7 @@ def decode_one_batch(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
lexicon: Lexicon,
|
lexicon: Lexicon,
|
||||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
|
||||||
batch: dict,
|
batch: dict,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
|
||||||
) -> Dict[str, Tuple[List[List[str]], List[List[Tuple[float, float]]]]]:
|
) -> Dict[str, Tuple[List[List[str]], List[List[Tuple[float, float]]]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -299,10 +181,6 @@ def decode_one_batch(
|
|||||||
It is the return value from iterating
|
It is the return value from iterating
|
||||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
for the format of the `batch`.
|
for the format of the `batch`.
|
||||||
decoding_graph:
|
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
|
||||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
|
||||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
the returned dict.
|
the returned dict.
|
||||||
@ -340,11 +218,16 @@ def decode_one_batch(
|
|||||||
hyp_tokens = []
|
hyp_tokens = []
|
||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
if params.decoding_method == "ctc-greedy-search" and params.max_sym_per_frame == 1:
|
if params.decoding_method == "ctc-greedy-search":
|
||||||
hyp_tokens = ctc_greedy_search(
|
hyp_tokens = ctc_greedy_search(
|
||||||
ctc_output=ctc_output,
|
ctc_output=ctc_output,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
|
elif params.decoding_method == "ctc-prefix-beam-search":
|
||||||
|
hyp_tokens = ctc_prefix_beam_search(
|
||||||
|
ctc_output=ctc_output,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
@ -356,20 +239,10 @@ def decode_one_batch(
|
|||||||
key = f"blank_penalty_{params.blank_penalty}"
|
key = f"blank_penalty_{params.blank_penalty}"
|
||||||
if params.decoding_method == "ctc-greedy-search":
|
if params.decoding_method == "ctc-greedy-search":
|
||||||
return {"ctc-greedy-search_" + key: hyps}
|
return {"ctc-greedy-search_" + key: hyps}
|
||||||
elif "fast_beam_search" in params.decoding_method:
|
elif params.decoding_method == "ctc-prefix-beam-search":
|
||||||
key += f"_beam_{params.beam}_"
|
return {"ctc-prefix-beam-search_" + key: hyps}
|
||||||
key += f"max_contexts_{params.max_contexts}_"
|
|
||||||
key += f"max_states_{params.max_states}"
|
|
||||||
if "nbest" in params.decoding_method:
|
|
||||||
key += f"_num_paths_{params.num_paths}_"
|
|
||||||
key += f"nbest_scale_{params.nbest_scale}"
|
|
||||||
if "LG" in params.decoding_method:
|
|
||||||
key += f"_ilme_scale_{params.ilme_scale}"
|
|
||||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
|
||||||
|
|
||||||
return {key: hyps}
|
|
||||||
else:
|
else:
|
||||||
return {f"beam_size_{params.beam_size}_" + key: hyps}
|
assert False, f"Unsupported decoding method: {params.decoding_method}"
|
||||||
|
|
||||||
|
|
||||||
def decode_dataset(
|
def decode_dataset(
|
||||||
@ -377,8 +250,6 @@ def decode_dataset(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
lexicon: Lexicon,
|
lexicon: Lexicon,
|
||||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
|
||||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -389,10 +260,6 @@ def decode_dataset(
|
|||||||
It is returned by :func:`get_params`.
|
It is returned by :func:`get_params`.
|
||||||
model:
|
model:
|
||||||
The neural model.
|
The neural model.
|
||||||
decoding_graph:
|
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
|
||||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
|
||||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
@ -406,10 +273,7 @@ def decode_dataset(
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
num_batches = "?"
|
num_batches = "?"
|
||||||
|
|
||||||
if params.decoding_method == "ctc-greedy-search":
|
log_interval = 20
|
||||||
log_interval = 50
|
|
||||||
else:
|
|
||||||
log_interval = 20
|
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
@ -421,8 +285,6 @@ def decode_dataset(
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
lexicon=lexicon,
|
lexicon=lexicon,
|
||||||
graph_compiler=graph_compiler,
|
|
||||||
decoding_graph=decoding_graph,
|
|
||||||
batch=batch,
|
batch=batch,
|
||||||
)
|
)
|
||||||
for name, hyps in hyps_dict.items():
|
for name, hyps in hyps_dict.items():
|
||||||
@ -504,7 +366,8 @@ def main():
|
|||||||
|
|
||||||
assert params.decoding_method in (
|
assert params.decoding_method in (
|
||||||
"ctc-greedy-search",
|
"ctc-greedy-search",
|
||||||
) # only support ctc-greedy-search
|
"ctc-prefix-beam-search",
|
||||||
|
) # support ctc-greedy-search and ctc-prefix-beam-search
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
@ -522,22 +385,9 @@ def main():
|
|||||||
params.suffix += f"-chunk-{params.chunk_size}"
|
params.suffix += f"-chunk-{params.chunk_size}"
|
||||||
params.suffix += f"-left-context-{params.left_context_frames}"
|
params.suffix += f"-left-context-{params.left_context_frames}"
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "prefix-beam-search" in params.decoding_method:
|
||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"_beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-states-{params.max_states}"
|
|
||||||
if "nbest" in params.decoding_method:
|
|
||||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
|
||||||
params.suffix += f"-num-paths-{params.num_paths}"
|
|
||||||
if "LG" in params.decoding_method:
|
|
||||||
params.suffix += f"_ilme_scale_{params.ilme_scale}"
|
|
||||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
|
||||||
elif "beam_search" in params.decoding_method:
|
|
||||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
|
||||||
else:
|
|
||||||
params.suffix += f"-context-{params.context_size}"
|
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
|
||||||
params.suffix += f"-blank-penalty-{params.blank_penalty}"
|
|
||||||
|
|
||||||
if params.use_averaged_model:
|
if params.use_averaged_model:
|
||||||
params.suffix += "-use-averaged-model"
|
params.suffix += "-use-averaged-model"
|
||||||
@ -551,18 +401,12 @@ def main():
|
|||||||
params.device = device
|
params.device = device
|
||||||
|
|
||||||
logging.info(f"Device: {device}")
|
logging.info(f"Device: {device}")
|
||||||
logging.info(params)
|
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
|
||||||
params.blank_id = lexicon.token_table["<blk>"]
|
params.blank_id = lexicon.token_table["<blk>"]
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
|
||||||
lexicon=lexicon,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
@ -648,20 +492,6 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
|
||||||
if "LG" in params.decoding_method:
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
|
||||||
lg_filename = params.lang_dir / "LG.pt"
|
|
||||||
logging.info(f"Loading {lg_filename}")
|
|
||||||
decoding_graph = k2.Fsa.from_dict(
|
|
||||||
torch.load(lg_filename, map_location=device)
|
|
||||||
)
|
|
||||||
decoding_graph.scores *= params.ngram_lm_scale
|
|
||||||
else:
|
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
|
||||||
else:
|
|
||||||
decoding_graph = None
|
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
@ -694,8 +524,6 @@ def main():
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
lexicon=lexicon,
|
lexicon=lexicon,
|
||||||
graph_compiler=graph_compiler,
|
|
||||||
decoding_graph=decoding_graph,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user