mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-11 10:04:21 +00:00
Add context biasing for zipformer recipe
This commit is contained in:
parent
1ee251c8b3
commit
88067f7566
@ -97,6 +97,7 @@ Usage:
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
@ -122,7 +123,7 @@ from beam_search import (
|
||||
)
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall import LmScorer, NgramLm
|
||||
from icefall import ContextGraph, LmScorer, NgramLm
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
@ -347,6 +348,25 @@ def get_parser():
|
||||
help="ID of the backoff symbol in the ngram LM",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-score",
|
||||
type=float,
|
||||
default=2,
|
||||
help="""
|
||||
The bonus score of each token for the context biasing words/phrases.
|
||||
Used only when --decoding_method is modified_beam_search.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-file",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The path of the context biasing lists, one word/phrase each line
|
||||
Used only when --decoding_method is modified_beam_search.
|
||||
""",
|
||||
)
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -359,6 +379,7 @@ def decode_one_batch(
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
LM: Optional[LmScorer] = None,
|
||||
ngram_lm=None,
|
||||
ngram_lm_scale: float = 0.0,
|
||||
@ -493,6 +514,7 @@ def decode_one_batch(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
@ -503,6 +525,7 @@ def decode_one_batch(
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
LM=LM,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
@ -515,6 +538,7 @@ def decode_one_batch(
|
||||
LODR_lm=ngram_lm,
|
||||
LODR_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
@ -527,6 +551,7 @@ def decode_one_batch(
|
||||
beam=params.beam_size,
|
||||
LM=LM,
|
||||
lm_scale_list=lm_scale_list,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
|
||||
lm_scale_list = [0.02 * i for i in range(2, 30)]
|
||||
@ -539,6 +564,7 @@ def decode_one_batch(
|
||||
LODR_lm=ngram_lm,
|
||||
sp=sp,
|
||||
lm_scale_list=lm_scale_list,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
@ -578,16 +604,24 @@ def decode_one_batch(
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
elif params.decoding_method in (
|
||||
"modified_beam_search_lm_rescore",
|
||||
"modified_beam_search_lm_rescore_LODR",
|
||||
):
|
||||
ans = dict()
|
||||
assert ans_dict is not None
|
||||
for key, hyps in ans_dict.items():
|
||||
hyps = [sp.decode(hyp).split() for hyp in hyps]
|
||||
ans[f"beam_size_{params.beam_size}_{key}"] = hyps
|
||||
return ans
|
||||
elif "modified_beam_search" in params.decoding_method:
|
||||
prefix = f"beam_size_{params.beam_size}"
|
||||
if params.has_contexts:
|
||||
prefix += f"-context-score-{params.context_score}"
|
||||
else:
|
||||
prefix += "-no-context-words"
|
||||
if params.decoding_method in (
|
||||
"modified_beam_search_lm_rescore",
|
||||
"modified_beam_search_lm_rescore_LODR",
|
||||
):
|
||||
ans = dict()
|
||||
assert ans_dict is not None
|
||||
for key, hyps in ans_dict.items():
|
||||
hyps = [sp.decode(hyp).split() for hyp in hyps]
|
||||
ans[f"{prefix}_{key}"] = hyps
|
||||
return ans
|
||||
else:
|
||||
return {prefix: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
@ -599,6 +633,7 @@ def decode_dataset(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
LM: Optional[LmScorer] = None,
|
||||
ngram_lm=None,
|
||||
ngram_lm_scale: float = 0.0,
|
||||
@ -649,6 +684,7 @@ def decode_dataset(
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
context_graph=context_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
LM=LM,
|
||||
@ -744,6 +780,11 @@ def main():
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if os.path.exists(params.context_file):
|
||||
params.has_contexts = True
|
||||
else:
|
||||
params.has_contexts = False
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
@ -770,6 +811,10 @@ def main():
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
if params.has_contexts:
|
||||
params.suffix += f"-context-score-{params.context_score}"
|
||||
else:
|
||||
params.suffix += "-no-context-words"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -952,6 +997,18 @@ def main():
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
if "modified_beam_search" in params.decoding_method:
|
||||
if os.path.exists(params.context_file):
|
||||
contexts = []
|
||||
for line in open(params.context_file).readlines():
|
||||
contexts.append(line.strip())
|
||||
context_graph = ContextGraph(params.context_score)
|
||||
context_graph.build(sp.encode(contexts))
|
||||
else:
|
||||
context_graph = None
|
||||
else:
|
||||
context_graph = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
@ -976,6 +1033,7 @@ def main():
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
context_graph=context_graph,
|
||||
LM=LM,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
|
Loading…
x
Reference in New Issue
Block a user