gigaspeech shallow fusion (commit title as recorded spells it "fussion"; the identifiers in the diff below keep that spelling)

This commit is contained in:
pkufool 2024-09-27 19:41:17 +08:00
parent 02e00ff504
commit 30deee2aca

View File

@ -127,6 +127,10 @@ from gigaspeech_scoring import asr_text_post_processing
from lhotse import set_caching_enabled from lhotse import set_caching_enabled
from train_cr_aed import add_model_arguments, get_model, get_params from train_cr_aed import add_model_arguments, get_model, get_params
from icefall.context_graph import ContextGraph, ContextState
from icefall.ngram_lm import NgramLm, NgramLmStateCost
from icefall.lm_wrapper import LmScorer
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,
average_checkpoints_with_averaged_model, average_checkpoints_with_averaged_model,
@ -137,6 +141,7 @@ from icefall.decode import (
ctc_greedy_search, ctc_greedy_search,
ctc_prefix_beam_search, ctc_prefix_beam_search,
ctc_prefix_beam_search_attention_decoder_rescoring, ctc_prefix_beam_search_attention_decoder_rescoring,
ctc_prefix_beam_search_shallow_fussion,
get_lattice, get_lattice,
nbest_decoding, nbest_decoding,
nbest_oracle, nbest_oracle,
@ -286,6 +291,23 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--lm-type",
type=str,
default="rnn",
help="Type of NN lm",
choices=["rnn", "transformer"],
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.3,
help="""The scale of the neural network LM
Used only when `--use-shallow-fusion` is set to True.
""",
)
parser.add_argument( parser.add_argument(
"--hlg-scale", "--hlg-scale",
type=float, type=float,
@ -320,8 +342,9 @@ def get_decoding_params() -> AttributeDict:
params = AttributeDict( params = AttributeDict(
{ {
"frame_shift_ms": 10, "frame_shift_ms": 10,
"search_beam": 20, "search_beam": 20, # for k2 fsa composition
"output_beam": 8, "output_beam": 8, # for k2 fsa composition
"beam": 4, # for prefix-beam-search
"min_active_states": 30, "min_active_states": 30,
"max_active_states": 10000, "max_active_states": 10000,
"use_double_scores": True, "use_double_scores": True,
@ -350,6 +373,7 @@ def decode_one_batch(
batch: dict, batch: dict,
word_table: k2.SymbolTable, word_table: k2.SymbolTable,
G: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None,
LM: Optional[LmScorer] = None,
) -> Dict[str, List[List[str]]]: ) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
following format: following format:
@ -453,6 +477,20 @@ def decode_one_batch(
ans[a_scale_str] = hyps ans[a_scale_str] = hyps
return ans return ans
if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
token_ids = ctc_prefix_beam_search_shallow_fussion(
ctc_output=ctc_output,
encoder_out_lens=encoder_out_lens,
LM=LM,
)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
key = "prefix-beam-search-shallow-fussion"
return {key: hyps}
supervision_segments = torch.stack( supervision_segments = torch.stack(
( (
supervisions["sequence_idx"], supervisions["sequence_idx"],
@ -626,6 +664,7 @@ def decode_dataset(
bpe_model: Optional[spm.SentencePieceProcessor], bpe_model: Optional[spm.SentencePieceProcessor],
word_table: k2.SymbolTable, word_table: k2.SymbolTable,
G: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None,
LM: Optional[LmScorer] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset. """Decode dataset.
@ -676,6 +715,7 @@ def decode_dataset(
batch=batch, batch=batch,
word_table=word_table, word_table=word_table,
G=G, G=G,
LM=LM,
) )
for name, hyps in hyps_dict.items(): for name, hyps in hyps_dict.items():
@ -779,6 +819,7 @@ def main():
"ctc-greedy-search", "ctc-greedy-search",
"prefix-beam-search", "prefix-beam-search",
"ctc-prefix-beam-search-attention-decoder-rescoring", "ctc-prefix-beam-search-attention-decoder-rescoring",
"ctc-prefix-beam-search-shallow-fussion",
"ctc-decoding", "ctc-decoding",
"1best", "1best",
"nbest", "nbest",
@ -805,6 +846,11 @@ def main():
params.suffix += f"_chunk-{params.chunk_size}" params.suffix += f"_chunk-{params.chunk_size}"
params.suffix += f"_left-context-{params.left_context_frames}" params.suffix += f"_left-context-{params.left_context_frames}"
if "prefix-beam-search" in params.decoding_method:
params.suffix += f"_beam-{params.beam}"
if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
params.suffix += f"_lm-scale-{params.lm_scale}"
if params.use_averaged_model: if params.use_averaged_model:
params.suffix += "_use-averaged-model" params.suffix += "_use-averaged-model"
@ -834,6 +880,7 @@ def main():
"ctc-decoding", "ctc-decoding",
"prefix-beam-search", "prefix-beam-search",
"ctc-prefix-beam-search-attention-decoder-rescoring", "ctc-prefix-beam-search-attention-decoder-rescoring",
"ctc-prefix-beam-search-shallow-fussion",
"attention-decoder-rescoring-no-ngram", "attention-decoder-rescoring-no-ngram",
]: ]:
HLG = None HLG = None
@ -912,6 +959,19 @@ def main():
else: else:
G = None G = None
# only load the neural network LM if required
if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
LM = LmScorer(
lm_type=params.lm_type,
params=params,
device=device,
lm_scale=params.lm_scale,
)
LM.to(device)
LM.eval()
else:
LM = None
logging.info("About to create model") logging.info("About to create model")
model = get_model(params) model = get_model(params)