include previously added decoding method

marcoyang 2022-11-02 18:03:56 +08:00
parent 6c8d1f9ef5
commit 9a01b9098d


@@ -131,11 +131,13 @@ from beam_search import (
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
+    modified_beam_search_ngram_rescoring,
     modified_beam_search_rnnlm_shallow_fusion,
 )
 from librispeech import LibriSpeech
 from train import add_model_arguments, get_params, get_transducer_model

+from icefall import NgramLm
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -232,6 +234,7 @@ def get_parser():
           - fast_beam_search_nbest
           - fast_beam_search_nbest_oracle
           - fast_beam_search_nbest_LG
+          - modified_beam_search_ngram_rescoring
           - modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion
         If you use fast_beam_search_nbest_LG, you have to specify
         `--lang-dir`, which should contain `LG.pt`.
@@ -386,7 +389,23 @@ def get_parser():
             last output linear layer
             """,
     )

+    parser.add_argument("--ilm-scale", type=float, default=-0.1)
+    parser.add_argument(
+        "--tokens-ngram",
+        type=int,
+        default=3,
+        help="""Token Ngram used for rescoring.
+        Used only when the decoding method is modified_beam_search_ngram_rescoring""",
+    )
+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="""ID of the backoff symbol.
+        Used only when the decoding method is modified_beam_search_ngram_rescoring""",
+    )
+
     add_model_arguments(parser)

     return parser
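For readers less familiar with argparse, a standalone sketch (not part of the commit) of how the new options surface: argparse maps `--tokens-ngram` to `args.tokens_ngram`, and so on.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--ilm-scale", type=float, default=-0.1)
parser.add_argument("--tokens-ngram", type=int, default=3)
parser.add_argument("--backoff-id", type=int, default=500)

# e.g. decode.py --tokens-ngram 4 on the command line
args = parser.parse_args(["--tokens-ngram", "4"])
print(args.tokens_ngram, args.backoff_id, args.ilm_scale)  # 4 500 -0.1
```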
@@ -399,6 +418,8 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    ngram_lm: Optional[NgramLm] = None,
+    ngram_lm_scale: float = 1.0,
     rnnlm: Optional[RnnLmModel] = None,
     rnnlm_scale: float = 1.0,
 ) -> Dict[str, List[List[str]]]:
@@ -534,6 +555,17 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search_ngram_rescoring":
+        hyp_tokens = modified_beam_search_ngram_rescoring(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
         hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
             model=model,
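Conceptually, each beam hypothesis in this method carries a running n-gram LM score alongside the transducer score, and candidates are ranked by the interpolated sum. A hypothetical sketch of that ranking only; the real `modified_beam_search_ngram_rescoring` in beam_search.py additionally tracks the LM's FST state per hypothesis:

```python
# Hypothetical sketch, NOT the beam_search.py implementation: it only
# illustrates how the n-gram score folds into hypothesis ranking.
from dataclasses import dataclass, field
from typing import List


@dataclass
class Hypothesis:
    ys: List[int] = field(default_factory=list)  # emitted token ids
    log_prob: float = 0.0      # accumulated transducer log-probability
    lm_log_prob: float = 0.0   # accumulated n-gram LM log-probability


def ranking_score(hyp: Hypothesis, ngram_lm_scale: float) -> float:
    # Hypotheses are compared on the interpolated score; ngram_lm_scale
    # plays the role of the ngram_lm_scale argument above.
    return hyp.log_prob + ngram_lm_scale * hyp.lm_log_prob


beam = [
    Hypothesis(ys=[5, 9], log_prob=-3.2, lm_log_prob=-4.0),
    Hypothesis(ys=[5, 7], log_prob=-3.5, lm_log_prob=-1.5),
]
best = max(beam, key=lambda h: ranking_score(h, ngram_lm_scale=0.4))
print(best.ys)  # [5, 7]: the LM rescues the slightly worse transducer score
```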
@@ -595,9 +627,11 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    ngram_lm: Optional[NgramLm] = None,
+    ngram_lm_scale: float = 1.0,
     rnnlm: Optional[RnnLmModel] = None,
     rnnlm_scale: float = 1.0,
-) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.

     Args:
@@ -638,13 +672,6 @@ def decode_dataset(
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
-        total_duration = sum(
-            [cut.duration for cut in batch["supervisions"]["cut"]]
-        )
-        logging.info(
-            f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, "
-            f"total duration is {total_duration}"
-        )

         hyps_dict = decode_one_batch(
             params=params,
@@ -653,6 +680,8 @@ def decode_dataset(
             decoding_graph=decoding_graph,
             word_table=word_table,
             batch=batch,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
             rnnlm=rnnlm,
             rnnlm_scale=rnnlm_scale,
         )
@@ -680,7 +709,7 @@ def decode_dataset(
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
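The results tuples now carry the cut id alongside the reference and hypothesis words, so consumers unpack three fields. A minimal sketch, assuming entries of the form `(cut_id, ref_words, hyp_words)` with illustrative values:

```python
from typing import List, Tuple

# Toy entry in the post-change shape: (cut_id, ref_words, hyp_words).
results: List[Tuple[str, List[str], List[str]]] = [
    ("lib-001", "hello world".split(), "hello word".split()),
]

for cut_id, ref_words, hyp_words in results:
    # Position-wise mismatches; a toy stand-in for the edit-distance-based
    # WER that the real save_results computes.
    mismatches = sum(r != h for r, h in zip(ref_words, hyp_words))
    print(cut_id, mismatches)  # lib-001 1
```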
@@ -740,6 +769,7 @@ def main():
         "fast_beam_search_nbest_LG",
         "fast_beam_search_nbest_oracle",
         "modified_beam_search",
+        "modified_beam_search_ngram_rescoring",
         "modified_beam_search_rnnlm_shallow_fusion",
     )
     params.res_dir = params.exp_dir / params.decoding_method
@@ -765,13 +795,10 @@ def main():
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+    params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"

     if "rnnlm" in params.decoding_method:
         params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
-
-    if "ILME" in params.decoding_method:
-        params.suffix += f"-ILME-scale={params.ilm_scale}"

     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
@@ -884,6 +911,14 @@ def main():
     model.to(device)
     model.eval()

+    lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+    logging.info(f"lm filename: {lm_filename}")
+    ngram_lm = NgramLm(
+        str(params.lang_dir / lm_filename),
+        backoff_id=params.backoff_id,
+        is_binary=False,
+    )
+    logging.info(f"num states: {ngram_lm.lm.num_states}")
     # only load rnnlm if used
     if "rnnlm" in params.decoding_method:
         rnn_lm_scale = params.rnn_lm_scale
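In isolation, loading the token-level n-gram LM uses only the `NgramLm` API visible in this diff; the path below is illustrative, and the snippet naturally requires icefall to be installed and the FST file to exist.

```python
from pathlib import Path

from icefall import NgramLm

lang_dir = Path("data/lang_bpe_500")  # illustrative path, not from the commit
tokens_ngram = 3   # --tokens-ngram
backoff_id = 500   # --backoff-id

lm_filename = f"{tokens_ngram}gram.fst.txt"
ngram_lm = NgramLm(
    str(lang_dir / lm_filename),
    backoff_id=backoff_id,
    is_binary=False,  # the LM is stored as an FST in text form
)
print(ngram_lm.lm.num_states)
```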
@@ -951,6 +986,8 @@
             sp=sp,
             word_table=word_table,
             decoding_graph=decoding_graph,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=params.ngram_lm_scale,
             rnnlm=rnn_lm_model,
             rnnlm_scale=rnn_lm_scale,
         )