mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
include previous added decoding method
This commit is contained in:
parent
6c8d1f9ef5
commit
9a01b9098d
@ -131,11 +131,13 @@ from beam_search import (
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
modified_beam_search_ngram_rescoring,
|
||||
modified_beam_search_rnnlm_shallow_fusion,
|
||||
)
|
||||
from librispeech import LibriSpeech
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall import NgramLm
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
@ -232,6 +234,7 @@ def get_parser():
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
- modified_beam_search_ngram_rescoring
|
||||
- modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
@ -386,7 +389,23 @@ def get_parser():
|
||||
last output linear layer
|
||||
""",
|
||||
)
|
||||
parser.add_argument("--ilm-scale", type=float, default=-0.1)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens-ngram",
|
||||
type=int,
|
||||
default=3,
|
||||
help="""Token Ngram used for rescoring.
|
||||
Used only when the decoding method is modified_beam_search_ngram_rescoring""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--backoff-id",
|
||||
type=int,
|
||||
default=500,
|
||||
help="""ID of the backoff symbol.
|
||||
Used only when the decoding method is modified_beam_search_ngram_rescoring""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -399,6 +418,8 @@ def decode_one_batch(
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
ngram_lm: Optional[NgramLm] = None,
|
||||
ngram_lm_scale: float = 1.0,
|
||||
rnnlm: Optional[RnnLmModel] = None,
|
||||
rnnlm_scale: float = 1.0,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
@ -534,6 +555,17 @@ def decode_one_batch(
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search_ngram_rescoring":
|
||||
hyp_tokens = modified_beam_search_ngram_rescoring(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
|
||||
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
|
||||
model=model,
|
||||
@ -595,9 +627,11 @@ def decode_dataset(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
ngram_lm: Optional[NgramLm] = None,
|
||||
ngram_lm_scale: float = 1.0,
|
||||
rnnlm: Optional[RnnLmModel] = None,
|
||||
rnnlm_scale: float = 1.0,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
@ -638,13 +672,6 @@ def decode_dataset(
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
total_duration = sum(
|
||||
[cut.duration for cut in batch["supervisions"]["cut"]]
|
||||
)
|
||||
|
||||
logging.info(
|
||||
f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}"
|
||||
)
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
@ -653,6 +680,8 @@ def decode_dataset(
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
rnnlm=rnnlm,
|
||||
rnnlm_scale=rnnlm_scale,
|
||||
)
|
||||
@ -680,7 +709,7 @@ def decode_dataset(
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
@ -740,6 +769,7 @@ def main():
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
"modified_beam_search_ngram_rescoring",
|
||||
"modified_beam_search_rnnlm_shallow_fusion",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
@ -765,13 +795,10 @@ def main():
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
if "rnnlm" in params.decoding_method:
|
||||
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
|
||||
|
||||
if "ILME" in params.decoding_method:
|
||||
params.suffix += f"-ILME-scale={params.ilm_scale}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
@ -884,6 +911,14 @@ def main():
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
|
||||
logging.info(f"lm filename: {lm_filename}")
|
||||
ngram_lm = NgramLm(
|
||||
str(params.lang_dir / lm_filename),
|
||||
backoff_id=params.backoff_id,
|
||||
is_binary=False,
|
||||
)
|
||||
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
||||
# only load rnnlm if used
|
||||
if "rnnlm" in params.decoding_method:
|
||||
rnn_lm_scale = params.rnn_lm_scale
|
||||
@ -951,6 +986,8 @@ def main():
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=params.ngram_lm_scale,
|
||||
rnnlm=rnn_lm_model,
|
||||
rnnlm_scale=rnn_lm_scale,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user