Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)

Add context biasing for zipformer recipe (#1204)
* Add context biasing for zipformer recipe
* support context biasing in modified_beam_search_LODR
* fix context graph
* Minor fixes
This commit is contained in:
  parent fc2df07841
  commit 4d7f73ce65

@@ -2389,6 +2389,7 @@ def modified_beam_search_LODR(
     LODR_lm_scale: float,
     LM: LmScorer,
     beam: int = 4,
+    context_graph: Optional[ContextGraph] = None,
 ) -> List[List[int]]:
     """This function implements LODR (https://arxiv.org/abs/2203.16776) with
     `modified_beam_search`. It uses a bi-gram language model as the estimate

@@ -2457,6 +2458,7 @@ def modified_beam_search_LODR(
                 state_cost=NgramLmStateCost(
                     LODR_lm
                 ),  # state of the source domain ngram
+                context_state=None if context_graph is None else context_graph.root,
             )
         )

@@ -2602,8 +2604,17 @@ def modified_beam_search_LODR(

             hyp_log_prob = topk_log_probs[k]  # get score of current hyp
             new_token = topk_token_indexes[k]

+            context_score = 0
+            new_context_state = None if context_graph is None else hyp.context_state
             if new_token not in (blank_id, unk_id):
+
+                if context_graph is not None:
+                    (
+                        context_score,
+                        new_context_state,
+                    ) = context_graph.forward_one_step(hyp.context_state, new_token)
+
                 ys.append(new_token)
                 state_cost = hyp.state_cost.forward_one_step(new_token)

@@ -2619,6 +2630,7 @@ def modified_beam_search_LODR(
                 hyp_log_prob += (
                     lm_score[new_token] * lm_scale
                     + LODR_lm_scale * current_ngram_score
+                    + context_score
                 )  # add the lm score

                 lm_score = scores[count]

@@ -2637,10 +2649,31 @@ def modified_beam_search_LODR(
                 state=state,
                 lm_score=lm_score,
                 state_cost=state_cost,
+                context_state=new_context_state,
             )
             B[i].add(new_hyp)

     B = B + finalized_B
+
+    # finalize context_state, if the matched contexts do not reach final state
+    # we need to add the score on the corresponding backoff arc
+    if context_graph is not None:
+        finalized_B = [HypothesisList() for _ in range(len(B))]
+        for i, hyps in enumerate(B):
+            for hyp in list(hyps):
+                context_score, new_context_state = context_graph.finalize(
+                    hyp.context_state
+                )
+                finalized_B[i].add(
+                    Hypothesis(
+                        ys=hyp.ys,
+                        log_prob=hyp.log_prob + context_score,
+                        timestamp=hyp.timestamp,
+                        context_state=new_context_state,
+                    )
+                )
+        B = finalized_B

     best_hyps = [b.get_most_probable(length_norm=True) for b in B]

     sorted_ans = [h.ys[context_size:] for h in best_hyps]

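Taken together, the hunks above thread a per-hypothesis context_state through the search: every non-blank token advances the state with forward_one_step() and collects a bonus that is folded into the hypothesis log-probability, and finalize() settles any phrase left half-matched once decoding ends. Stripped of the beam-search bookkeeping, the pattern is roughly the sketch below (names and the blank_id/unk_id defaults are placeholders, not taken from this commit):

from icefall import ContextGraph


def context_bonus(graph: ContextGraph, tokens, blank_id=0, unk_id=2):
    # Total biasing bonus one finished hypothesis would collect (illustrative).
    total, state = 0.0, graph.root
    for t in tokens:
        if t in (blank_id, unk_id):
            continue  # blanks/unknowns never advance the context graph
        bonus, state = graph.forward_one_step(state, t)
        total += bonus
    # Revoke the boost of a phrase still unfinished at the end of the sequence.
    bonus, _ = graph.finalize(state)
    return total + bonus
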
@@ -97,6 +97,7 @@ Usage:
 import argparse
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

@@ -122,7 +123,7 @@ from beam_search import (
 )
 from train import add_model_arguments, get_model, get_params

-from icefall import LmScorer, NgramLm
+from icefall import ContextGraph, LmScorer, NgramLm
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,

@@ -215,6 +216,7 @@ def get_parser():
           - greedy_search
           - beam_search
           - modified_beam_search
+          - modified_beam_search_LODR
           - fast_beam_search
           - fast_beam_search_nbest
           - fast_beam_search_nbest_oracle

@@ -251,7 +253,7 @@ def get_parser():
         type=float,
         default=0.01,
         help="""
-        Used only when --decoding_method is fast_beam_search_nbest_LG.
+        Used only when --decoding-method is fast_beam_search_nbest_LG.
         It specifies the scale for n-gram LM scores.
         """,
     )

@@ -285,7 +287,7 @@ def get_parser():
         type=int,
         default=1,
         help="""Maximum number of symbols per frame.
-        Used only when --decoding_method is greedy_search""",
+        Used only when --decoding-method is greedy_search""",
     )

     parser.add_argument(

@@ -347,6 +349,27 @@ def get_parser():
         help="ID of the backoff symbol in the ngram LM",
     )

+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="""
+        The bonus score of each token for the context biasing words/phrases.
+        Used only when --decoding-method is modified_beam_search and
+        modified_beam_search_LODR.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing lists, one word/phrase each line
+        Used only when --decoding-method is modified_beam_search and
+        modified_beam_search_LODR.
+        """,
+    )
+
     add_model_arguments(parser)

     return parser

@@ -359,6 +382,7 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
     LM: Optional[LmScorer] = None,
     ngram_lm=None,
     ngram_lm_scale: float = 0.0,

@@ -388,7 +412,7 @@ def decode_one_batch(
         The word symbol table.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
-        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
         fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
       LM:
         A neural network language model.

@@ -493,6 +517,7 @@ def decode_one_batch(
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())

@@ -515,6 +540,7 @@ def decode_one_batch(
             LODR_lm=ngram_lm,
             LODR_lm_scale=ngram_lm_scale,
             LM=LM,
+            context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())

@@ -578,16 +604,22 @@ def decode_one_batch(
         key += f"_ngram_lm_scale_{params.ngram_lm_scale}"

         return {key: hyps}
-    elif params.decoding_method in (
-        "modified_beam_search_lm_rescore",
-        "modified_beam_search_lm_rescore_LODR",
-    ):
-        ans = dict()
-        assert ans_dict is not None
-        for key, hyps in ans_dict.items():
-            hyps = [sp.decode(hyp).split() for hyp in hyps]
-            ans[f"beam_size_{params.beam_size}_{key}"] = hyps
-        return ans
+    elif "modified_beam_search" in params.decoding_method:
+        prefix = f"beam_size_{params.beam_size}"
+        if params.decoding_method in (
+            "modified_beam_search_lm_rescore",
+            "modified_beam_search_lm_rescore_LODR",
+        ):
+            ans = dict()
+            assert ans_dict is not None
+            for key, hyps in ans_dict.items():
+                hyps = [sp.decode(hyp).split() for hyp in hyps]
+                ans[f"{prefix}_{key}"] = hyps
+            return ans
+        else:
+            if params.has_contexts:
+                prefix += f"-context-score-{params.context_score}"
+            return {prefix: hyps}
     else:
         return {f"beam_size_{params.beam_size}": hyps}

@@ -599,6 +631,7 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
     LM: Optional[LmScorer] = None,
     ngram_lm=None,
     ngram_lm_scale: float = 0.0,

@@ -618,7 +651,7 @@ def decode_dataset(
         The word symbol table.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
-        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
         fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search

@@ -649,6 +682,7 @@ def decode_dataset(
             model=model,
             sp=sp,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
             word_table=word_table,
             batch=batch,
             LM=LM,

@@ -744,6 +778,11 @@ def main():
     )
     params.res_dir = params.exp_dir / params.decoding_method

+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
     if params.iter > 0:
         params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:

@@ -770,6 +809,12 @@ def main():
         params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        if params.decoding_method in (
+            "modified_beam_search",
+            "modified_beam_search_LODR",
+        ):
+            if params.has_contexts:
+                params.suffix += f"-context-score-{params.context_score}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

@@ -952,6 +997,18 @@ def main():
         decoding_graph = None
         word_table = None

+    if "modified_beam_search" in params.decoding_method:
+        if os.path.exists(params.context_file):
+            contexts = []
+            for line in open(params.context_file).readlines():
+                contexts.append(line.strip())
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build(sp.encode(contexts))
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

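For reference, the graph construction added above can be exercised on its own. A minimal sketch, assuming a BPE model path and a context file name that are placeholders rather than paths from this commit:

import sentencepiece as spm

from icefall import ContextGraph

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # placeholder path

# One biasing word/phrase per line, e.g. hypothetical entries "LOIS LANE", "KRYPTON".
with open("contexts.txt") as f:  # placeholder file name
    contexts = [line.strip() for line in f if line.strip()]

context_graph = ContextGraph(2.0)  # per-token bonus, cf. the --context-score default
context_graph.build(sp.encode(contexts))  # sp.encode(List[str]) -> List[List[int]]
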
@@ -976,6 +1033,7 @@ def main():
             sp=sp,
             word_table=word_table,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
             LM=LM,
             ngram_lm=ngram_lm,
             ngram_lm_scale=ngram_lm_scale,

@@ -29,7 +29,7 @@ class ContextState:
         token: int,
         token_score: float,
         node_score: float,
-        local_node_score: float,
+        output_score: float,
         is_end: bool,
     ):
         """Create a ContextState.

@@ -40,16 +40,15 @@ class ContextState:
             The id of the root node is always 0.
           token:
             The token id.
-          score:
+          token_score:
             The bonus for each token during decoding, which will hopefully
             boost the token up to survive beam search.
           node_score:
             The accumulated bonus from root of graph to current node, it will be
             used to calculate the score for fail arc.
-          local_node_score:
-            The accumulated bonus from last ``end_node``(node with is_end true)
-            to current_node, it will be used to calculate the score for fail arc.
-            Node: The local_node_score of a ``end_node`` is 0.
+          output_score:
+            The total scores of matched phrases, sum of the node_score of all
+            the output node for current node.
           is_end:
             True if current token is the end of a context.
         """

@@ -57,7 +56,7 @@ class ContextState:
         self.token = token
         self.token_score = token_score
         self.node_score = node_score
-        self.local_node_score = local_node_score
+        self.output_score = output_score
         self.is_end = is_end
         self.next = {}
         self.fail = None

@@ -93,7 +92,7 @@ class ContextGraph:
             token=-1,
             token_score=0,
             node_score=0,
-            local_node_score=0,
+            output_score=0,
             is_end=False,
         )
         self.root.fail = self.root

@@ -131,6 +130,7 @@ class ContextGraph:
                     output = None
                     break
             node.output = output
+            node.output_score += 0 if output is None else output.output_score
             queue.append(node)

     def build(self, token_ids: List[List[int]]):

@@ -153,14 +153,13 @@ class ContextGraph:
                 if token not in node.next:
                     self.num_nodes += 1
                     is_end = i == len(tokens) - 1
+                    node_score = node.node_score + self.context_score
                     node.next[token] = ContextState(
                         id=self.num_nodes,
                         token=token,
                         token_score=self.context_score,
-                        node_score=node.node_score + self.context_score,
-                        local_node_score=0
-                        if is_end
-                        else (node.local_node_score + self.context_score),
+                        node_score=node_score,
+                        output_score=node_score if is_end else 0,
                         is_end=is_end,
                     )
                 node = node.next[token]

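To make the node_score/output_score bookkeeping concrete, here is a small hand-checked illustration (not part of the commit). It uses character codes as token ids and a context_score of 1; the breadth-first fail/output pass shown in an earlier hunk lets the end node of "SHE" also pay off its suffix "HE":

from icefall import ContextGraph

graph = ContextGraph(1.0)  # bonus of 1 per matched token
graph.build([[ord(c) for c in "HE"], [ord(c) for c in "SHE"]])

she = graph.root.next[ord("S")].next[ord("H")].next[ord("E")]
assert she.node_score == 3  # three tokens on the path from the root
# build() seeds output_score with node_score (3); the fail/output pass then
# adds the output_score of the suffix match "HE" (2).
assert she.output_score == 5
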
@@ -186,8 +185,6 @@ class ContextGraph:
         if token in state.next:
             node = state.next[token]
             score = node.token_score
-            if state.is_end:
-                score += state.node_score
         else:
             # token not matched
             # We will trace along the fail arc until it matches the token or reaching

@@ -202,14 +199,9 @@ class ContextGraph:
                 node = node.next[token]

             # The score of the fail path
-            score = node.node_score - state.local_node_score
+            score = node.node_score - state.node_score
         assert node is not None
-        matched_score = 0
-        output = node.output
-        while output is not None:
-            matched_score += output.node_score
-            output = output.output
-        return (score + matched_score, node)
+        return (score + node.output_score, node)

     def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
         """When reaching the end of the decoded sequence, we need to finalize

@@ -227,8 +219,6 @@ class ContextGraph:
         """
         # The score of the fail arc
         score = -state.node_score
-        if state.is_end:
-            score = 0
         return (score, self.root)

     def draw(

@@ -307,10 +297,8 @@ class ContextGraph:
             for token, node in current_node.next.items():
                 if node.id not in seen:
                     node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".")
-                    local_node_score = f"{node.local_node_score:.2f}".rstrip(
-                        "0"
-                    ).rstrip(".")
-                    label = f"{node.id}/({node_score},{local_node_score})"
+                    output_score = f"{node.output_score:.2f}".rstrip("0").rstrip(".")
+                    label = f"{node.id}/({node_score}, {output_score})"
                     if node.is_end:
                         dot.node(str(node.id), label=label, **final_state_attr)
                     else:

@@ -391,6 +379,7 @@ if __name__ == "__main__":
         "HERSHE": 12,  # "HE", "HERS", "S", "SHE", "HE"
         "HISHE": 9,  # "HIS", "S", "SHE", "HE"
         "SHED": 6,  # "S", "SHE", "HE"
+        "SHELF": 6,  # "S", "SHE", "HE"
         "HELL": 2,  # "HE"
         "HELLO": 7,  # "HE", "HELLO"
         "DHRHISQ": 4,  # "HIS", "S"

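The expected totals above can be reproduced with the public API alone. The sketch below is hand-derived rather than copied from the commit: the context set {"S", "HE", "HIS", "SHE", "HERS", "HELLO"} and the context_score of 1 are inferred from the inline comments, and character codes stand in for token ids. Per-token bonuses accumulate through forward_one_step(), and finalize() revokes the bonus of whatever phrase is still unfinished when the sequence ends:

from icefall import ContextGraph

contexts = ["S", "HE", "HIS", "SHE", "HERS", "HELLO"]  # inferred, see comments above

graph = ContextGraph(1.0)
graph.build([[ord(c) for c in phrase] for phrase in contexts])


def total_score(query: str) -> float:
    score, state = 0.0, graph.root
    for c in query:
        bonus, state = graph.forward_one_step(state, ord(c))
        score += bonus
    bonus, _ = graph.finalize(state)  # -node_score for a dangling partial match
    return score + bonus


assert total_score("HERSHE") == 12  # "HE" + "HERS" + "S" + "SHE" + "HE"
assert total_score("HELL") == 2  # the dangling "HELL" prefix is revoked; "HE" remains
assert total_score("HELLO") == 7  # "HE" (2) + "HELLO" (5)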