mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
gigaspeech shallow fusion
This commit is contained in:
parent
02e00ff504
commit
30deee2aca
@ -127,6 +127,10 @@ from gigaspeech_scoring import asr_text_post_processing
|
|||||||
from lhotse import set_caching_enabled
|
from lhotse import set_caching_enabled
|
||||||
from train_cr_aed import add_model_arguments, get_model, get_params
|
from train_cr_aed import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
|
from icefall.context_graph import ContextGraph, ContextState
|
||||||
|
from icefall.ngram_lm import NgramLm, NgramLmStateCost
|
||||||
|
from icefall.lm_wrapper import LmScorer
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
@ -137,6 +141,7 @@ from icefall.decode import (
|
|||||||
ctc_greedy_search,
|
ctc_greedy_search,
|
||||||
ctc_prefix_beam_search,
|
ctc_prefix_beam_search,
|
||||||
ctc_prefix_beam_search_attention_decoder_rescoring,
|
ctc_prefix_beam_search_attention_decoder_rescoring,
|
||||||
|
ctc_prefix_beam_search_shallow_fussion,
|
||||||
get_lattice,
|
get_lattice,
|
||||||
nbest_decoding,
|
nbest_decoding,
|
||||||
nbest_oracle,
|
nbest_oracle,
|
||||||
@ -286,6 +291,23 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-type",
|
||||||
|
type=str,
|
||||||
|
default="rnn",
|
||||||
|
help="Type of NN lm",
|
||||||
|
choices=["rnn", "transformer"],
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.3,
|
||||||
|
help="""The scale of the neural network LM
|
||||||
|
Used only when `--use-shallow-fusion` is set to True.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hlg-scale",
|
"--hlg-scale",
|
||||||
type=float,
|
type=float,
|
||||||
@ -320,8 +342,9 @@ def get_decoding_params() -> AttributeDict:
|
|||||||
params = AttributeDict(
|
params = AttributeDict(
|
||||||
{
|
{
|
||||||
"frame_shift_ms": 10,
|
"frame_shift_ms": 10,
|
||||||
"search_beam": 20,
|
"search_beam": 20, # for k2 fsa composition
|
||||||
"output_beam": 8,
|
"output_beam": 8, # for k2 fsa composition
|
||||||
|
"beam": 4, # for prefix-beam-search
|
||||||
"min_active_states": 30,
|
"min_active_states": 30,
|
||||||
"max_active_states": 10000,
|
"max_active_states": 10000,
|
||||||
"use_double_scores": True,
|
"use_double_scores": True,
|
||||||
@ -350,6 +373,7 @@ def decode_one_batch(
|
|||||||
batch: dict,
|
batch: dict,
|
||||||
word_table: k2.SymbolTable,
|
word_table: k2.SymbolTable,
|
||||||
G: Optional[k2.Fsa] = None,
|
G: Optional[k2.Fsa] = None,
|
||||||
|
LM: Optional[LmScorer] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -453,6 +477,20 @@ def decode_one_batch(
|
|||||||
ans[a_scale_str] = hyps
|
ans[a_scale_str] = hyps
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
|
||||||
|
token_ids = ctc_prefix_beam_search_shallow_fussion(
|
||||||
|
ctc_output=ctc_output,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
LM=LM,
|
||||||
|
)
|
||||||
|
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||||
|
hyps = bpe_model.decode(token_ids)
|
||||||
|
|
||||||
|
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
||||||
|
hyps = [s.split() for s in hyps]
|
||||||
|
key = "prefix-beam-search-shallow-fussion"
|
||||||
|
return {key: hyps}
|
||||||
|
|
||||||
supervision_segments = torch.stack(
|
supervision_segments = torch.stack(
|
||||||
(
|
(
|
||||||
supervisions["sequence_idx"],
|
supervisions["sequence_idx"],
|
||||||
@ -626,6 +664,7 @@ def decode_dataset(
|
|||||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||||
word_table: k2.SymbolTable,
|
word_table: k2.SymbolTable,
|
||||||
G: Optional[k2.Fsa] = None,
|
G: Optional[k2.Fsa] = None,
|
||||||
|
LM: Optional[LmScorer] = None,
|
||||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -676,6 +715,7 @@ def decode_dataset(
|
|||||||
batch=batch,
|
batch=batch,
|
||||||
word_table=word_table,
|
word_table=word_table,
|
||||||
G=G,
|
G=G,
|
||||||
|
LM=LM,
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, hyps in hyps_dict.items():
|
for name, hyps in hyps_dict.items():
|
||||||
@ -779,6 +819,7 @@ def main():
|
|||||||
"ctc-greedy-search",
|
"ctc-greedy-search",
|
||||||
"prefix-beam-search",
|
"prefix-beam-search",
|
||||||
"ctc-prefix-beam-search-attention-decoder-rescoring",
|
"ctc-prefix-beam-search-attention-decoder-rescoring",
|
||||||
|
"ctc-prefix-beam-search-shallow-fussion",
|
||||||
"ctc-decoding",
|
"ctc-decoding",
|
||||||
"1best",
|
"1best",
|
||||||
"nbest",
|
"nbest",
|
||||||
@ -805,6 +846,11 @@ def main():
|
|||||||
params.suffix += f"_chunk-{params.chunk_size}"
|
params.suffix += f"_chunk-{params.chunk_size}"
|
||||||
params.suffix += f"_left-context-{params.left_context_frames}"
|
params.suffix += f"_left-context-{params.left_context_frames}"
|
||||||
|
|
||||||
|
if "prefix-beam-search" in params.decoding_method:
|
||||||
|
params.suffix += f"_beam-{params.beam}"
|
||||||
|
if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
|
||||||
|
params.suffix += f"_lm-scale-{params.lm_scale}"
|
||||||
|
|
||||||
if params.use_averaged_model:
|
if params.use_averaged_model:
|
||||||
params.suffix += "_use-averaged-model"
|
params.suffix += "_use-averaged-model"
|
||||||
|
|
||||||
@ -834,6 +880,7 @@ def main():
|
|||||||
"ctc-decoding",
|
"ctc-decoding",
|
||||||
"prefix-beam-search",
|
"prefix-beam-search",
|
||||||
"ctc-prefix-beam-search-attention-decoder-rescoring",
|
"ctc-prefix-beam-search-attention-decoder-rescoring",
|
||||||
|
"ctc-prefix-beam-search-shallow-fussion",
|
||||||
"attention-decoder-rescoring-no-ngram",
|
"attention-decoder-rescoring-no-ngram",
|
||||||
]:
|
]:
|
||||||
HLG = None
|
HLG = None
|
||||||
@ -912,6 +959,19 @@ def main():
|
|||||||
else:
|
else:
|
||||||
G = None
|
G = None
|
||||||
|
|
||||||
|
# only load the neural network LM if required
|
||||||
|
if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
|
||||||
|
LM = LmScorer(
|
||||||
|
lm_type=params.lm_type,
|
||||||
|
params=params,
|
||||||
|
device=device,
|
||||||
|
lm_scale=params.lm_scale,
|
||||||
|
)
|
||||||
|
LM.to(device)
|
||||||
|
LM.eval()
|
||||||
|
else:
|
||||||
|
LM = None
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
model = get_model(params)
|
model = get_model(params)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user