Add context biasing for zipformer recipe (#1204)

* Add context biasing for zipformer recipe

* support context biasing in modified_beam_search_LODR

* fix context graph

* Minor fixes
Wei Kang 2023-08-28 19:37:32 +08:00 committed by GitHub
parent fc2df07841
commit 4d7f73ce65
3 changed files with 122 additions and 42 deletions

View File

@@ -2389,6 +2389,7 @@ def modified_beam_search_LODR(
LODR_lm_scale: float,
LM: LmScorer,
beam: int = 4,
context_graph: Optional[ContextGraph] = None,
) -> List[List[int]]:
"""This function implements LODR (https://arxiv.org/abs/2203.16776) with
`modified_beam_search`. It uses a bi-gram language model as the estimate
@@ -2457,6 +2458,7 @@ def modified_beam_search_LODR(
state_cost=NgramLmStateCost(
LODR_lm
), # state of the source domain ngram
context_state=None if context_graph is None else context_graph.root,
)
)
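To support this, each Hypothesis has to carry its current position in the biasing graph. A minimal sketch of the assumed field addition (the Hypothesis dataclass itself is not shown in this diff):

    from dataclasses import dataclass
    from typing import List, Optional

    import torch

    @dataclass
    class Hypothesis:
        ys: List[int]           # predicted token ids so far
        log_prob: torch.Tensor  # accumulated log probability of ys
        # ... other existing fields (state, lm_score, state_cost) elided ...
        # Assumed addition in this commit: the current node in the
        # ContextGraph trie; None when context biasing is disabled.
        context_state: Optional["ContextState"] = None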
@@ -2602,8 +2604,17 @@ def modified_beam_search_LODR(
hyp_log_prob = topk_log_probs[k] # get score of current hyp
new_token = topk_token_indexes[k]
context_score = 0
new_context_state = None if context_graph is None else hyp.context_state
if new_token not in (blank_id, unk_id):
if context_graph is not None:
(
context_score,
new_context_state,
) = context_graph.forward_one_step(hyp.context_state, new_token)
ys.append(new_token)
state_cost = hyp.state_cost.forward_one_step(new_token)
@@ -2619,6 +2630,7 @@ def modified_beam_search_LODR(
hyp_log_prob += (
lm_score[new_token] * lm_scale
+ LODR_lm_scale * current_ngram_score
+ context_score
) # add the lm score
lm_score = scores[count]
@@ -2637,10 +2649,31 @@ def modified_beam_search_LODR(
state=state,
lm_score=lm_score,
state_cost=state_cost,
context_state=new_context_state,
)
B[i].add(new_hyp)
B = B + finalized_B
# Finalize context_state: if a matched context does not reach a final state,
# we need to add the score of the corresponding backoff arc
if context_graph is not None:
finalized_B = [HypothesisList() for _ in range(len(B))]
for i, hyps in enumerate(B):
for hyp in list(hyps):
context_score, new_context_state = context_graph.finalize(
hyp.context_state
)
finalized_B[i].add(
Hypothesis(
ys=hyp.ys,
log_prob=hyp.log_prob + context_score,
timestamp=hyp.timestamp,
context_state=new_context_state,
)
)
B = finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
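To make the contract of the two ContextGraph calls used above concrete, here is a toy, self-contained sketch; the token ids are made up, and only ContextGraph, forward_one_step, and finalize come from this commit:

    from icefall import ContextGraph

    context_graph = ContextGraph(context_score=1.0)
    # Two biasing phrases given as token-id lists, e.g. produced by sp.encode().
    context_graph.build([[3, 5], [3, 5, 8]])

    state = context_graph.root
    bonus = 0.0
    for token in [3, 5, 9]:  # a hypothetical decoded sequence
        score, state = context_graph.forward_one_step(state, token)
        bonus += score
    # At the end of the utterance, take back the bonus of unfinished matches.
    score, state = context_graph.finalize(state)
    bonus += score
    # Net bonus: 2.0, i.e. the full score of the one fully matched phrase [3, 5].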

View File

@@ -97,6 +97,7 @@ Usage:
import argparse
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@@ -122,7 +123,7 @@ from beam_search import (
)
from train import add_model_arguments, get_model, get_params
- from icefall import LmScorer, NgramLm
+ from icefall import ContextGraph, LmScorer, NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@@ -215,6 +216,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- modified_beam_search_LODR
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
@@ -251,7 +253,7 @@ def get_parser():
type=float,
default=0.01,
help="""
- Used only when --decoding_method is fast_beam_search_nbest_LG.
+ Used only when --decoding-method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
@@ -285,7 +287,7 @@ def get_parser():
type=int,
default=1,
help="""Maximum number of symbols per frame.
- Used only when --decoding_method is greedy_search""",
+ Used only when --decoding-method is greedy_search""",
)
parser.add_argument(
@@ -347,6 +349,27 @@ def get_parser():
help="ID of the backoff symbol in the ngram LM",
)
parser.add_argument(
"--context-score",
type=float,
default=2,
help="""
The bonus score for each token of the context-biasing words/phrases.
Used only when --decoding-method is modified_beam_search or
modified_beam_search_LODR.
""",
)
parser.add_argument(
"--context-file",
type=str,
default="",
help="""
The path to the context-biasing list; one word/phrase per line.
Used only when --decoding-method is modified_beam_search or
modified_beam_search_LODR.
""",
)
add_model_arguments(parser)
return parser
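A usage sketch (the context file name, epoch, and paths below are illustrative, not from this diff):

    # contexts.txt contains one biasing word/phrase per line, e.g.:
    #   HELLO WORLD
    #   ZIPFORMER
    ./zipformer/decode.py \
      --epoch 30 \
      --avg 15 \
      --exp-dir ./zipformer/exp \
      --decoding-method modified_beam_search \
      --context-score 2 \
      --context-file contexts.txt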
@@ -359,6 +382,7 @@ def decode_one_batch(
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
LM: Optional[LmScorer] = None,
ngram_lm=None,
ngram_lm_scale: float = 0.0,
@@ -388,7 +412,7 @@ def decode_one_batch(
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
- only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
LM:
A neural network language model.
@@ -493,6 +517,7 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
context_graph=context_graph,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@@ -515,6 +540,7 @@ def decode_one_batch(
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
LM=LM,
context_graph=context_graph,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@@ -578,16 +604,22 @@ def decode_one_batch(
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
- elif params.decoding_method in (
- "modified_beam_search_lm_rescore",
- "modified_beam_search_lm_rescore_LODR",
- ):
- ans = dict()
- assert ans_dict is not None
- for key, hyps in ans_dict.items():
- hyps = [sp.decode(hyp).split() for hyp in hyps]
- ans[f"beam_size_{params.beam_size}_{key}"] = hyps
- return ans
+ elif "modified_beam_search" in params.decoding_method:
+ prefix = f"beam_size_{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ ):
+ ans = dict()
+ assert ans_dict is not None
+ for key, hyps in ans_dict.items():
+ hyps = [sp.decode(hyp).split() for hyp in hyps]
+ ans[f"{prefix}_{key}"] = hyps
+ return ans
+ else:
+ if params.has_contexts:
+ prefix += f"-context-score-{params.context_score}"
+ return {prefix: hyps}
- else:
- return {f"beam_size_{params.beam_size}": hyps}
@@ -599,6 +631,7 @@ def decode_dataset(
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
LM: Optional[LmScorer] = None,
ngram_lm=None,
ngram_lm_scale: float = 0.0,
@@ -618,7 +651,7 @@ def decode_dataset(
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
- only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+ only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
@@ -649,6 +682,7 @@ def decode_dataset(
model=model,
sp=sp,
decoding_graph=decoding_graph,
context_graph=context_graph,
word_table=word_table,
batch=batch,
LM=LM,
@@ -744,6 +778,11 @@ def main():
)
params.res_dir = params.exp_dir / params.decoding_method
if os.path.exists(params.context_file):
params.has_contexts = True
else:
params.has_contexts = False
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
@@ -770,6 +809,12 @@ def main():
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
if params.decoding_method in (
"modified_beam_search",
"modified_beam_search_LODR",
):
if params.has_contexts:
params.suffix += f"-context-score-{params.context_score}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -952,6 +997,18 @@ def main():
decoding_graph = None
word_table = None
if "modified_beam_search" in params.decoding_method:
if os.path.exists(params.context_file):
contexts = []
for line in open(params.context_file).readlines():
contexts.append(line.strip())
context_graph = ContextGraph(params.context_score)
context_graph.build(sp.encode(contexts))
else:
context_graph = None
else:
context_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@@ -976,6 +1033,7 @@ def main():
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
context_graph=context_graph,
LM=LM,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
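In summary, main() builds the graph once and threads it down to the search functions. A condensed sketch of that path (the BPE model and context file paths are placeholders):

    import sentencepiece as spm
    from icefall import ContextGraph

    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bpe_500/bpe.model")  # hypothetical path

    # One biasing word/phrase per line, encoded into BPE token ids.
    contexts = [line.strip() for line in open("contexts.txt")]
    context_graph = ContextGraph(context_score=2.0)
    context_graph.build(sp.encode(contexts))

    # decode_dataset() -> decode_one_batch() -> modified_beam_search()
    # or modified_beam_search_LODR(), each forwarding context_graph.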

View File

@@ -29,7 +29,7 @@ class ContextState:
token: int,
token_score: float,
node_score: float,
- local_node_score: float,
+ output_score: float,
is_end: bool,
):
"""Create a ContextState.
@@ -40,16 +40,15 @@ class ContextState:
The id of the root node is always 0.
token:
The token id.
- score:
+ token_score:
The bonus for each token during decoding, which will hopefully
boost the token enough to survive beam search.
node_score:
The accumulated bonus from the root of the graph to the current node;
it is used to calculate the score of the fail arc.
- local_node_score:
- The accumulated bonus from last ``end_node``(node with is_end true)
- to current_node, it will be used to calculate the score for fail arc.
- Node: The local_node_score of a ``end_node`` is 0.
+ output_score:
+ The total score of matched phrases, i.e. the sum of the node_score of
+ all output nodes for the current node.
is_end:
True if current token is the end of a context.
"""
@@ -57,7 +56,7 @@
self.token = token
self.token_score = token_score
self.node_score = node_score
- self.local_node_score = local_node_score
+ self.output_score = output_score
self.is_end = is_end
self.next = {}
self.fail = None
@@ -93,7 +92,7 @@ class ContextGraph:
token=-1,
token_score=0,
node_score=0,
- local_node_score=0,
+ output_score=0,
is_end=False,
)
self.root.fail = self.root
@@ -131,6 +130,7 @@ class ContextGraph:
output = None
break
node.output = output
node.output_score += 0 if output is None else output.output_score
queue.append(node)
def build(self, token_ids: List[List[int]]):
@@ -153,14 +153,13 @@ class ContextGraph:
if token not in node.next:
self.num_nodes += 1
is_end = i == len(tokens) - 1
+ node_score = node.node_score + self.context_score
node.next[token] = ContextState(
id=self.num_nodes,
token=token,
token_score=self.context_score,
- node_score=node.node_score + self.context_score,
- local_node_score=0
- if is_end
- else (node.local_node_score + self.context_score),
+ node_score=node_score,
+ output_score=node_score if is_end else 0,
is_end=is_end,
)
node = node.next[token]
@@ -186,8 +185,6 @@ class ContextGraph:
if token in state.next:
node = state.next[token]
score = node.token_score
- if state.is_end:
- score += state.node_score
else:
# token not matched
# We will trace along the fail arc until it matches the token or reaching
@@ -202,14 +199,9 @@ class ContextGraph:
node = node.next[token]
# The score of the fail path
- score = node.node_score - state.local_node_score
+ score = node.node_score - state.node_score
assert node is not None
- matched_score = 0
- output = node.output
- while output is not None:
- matched_score += output.node_score
- output = output.output
- return (score + matched_score, node)
+ return (score + node.output_score, node)
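A concrete trace, assuming context_score = 1 and the two phrases "HE" and "SHE": consuming S, H, E follows match arcs scoring +1, +1, and +1 + output_score("SHE") = +6, where output_score("SHE") = 3 (the phrase "SHE") + 2 (its suffix match "HE"); a subsequent finalize() then returns -3 (the node_score of "SHE"), leaving a net bonus of 5 = 3 + 2, exactly the summed scores of the two matched phrases.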
def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
"""When reaching the end of the decoded sequence, we need to finalize
@@ -227,8 +219,6 @@ class ContextGraph:
"""
# The score of the fail arc
score = -state.node_score
- if state.is_end:
- score = 0
return (score, self.root)
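Note that with the is_end special case removed above, even a hypothesis that ends exactly on a final state returns its node_score here: completed phrases have already been paid out through output_score in forward_one_step, so keeping the old score = 0 branch would count them twice.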
def draw(
@@ -307,10 +297,8 @@ class ContextGraph:
for token, node in current_node.next.items():
if node.id not in seen:
node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".")
- local_node_score = f"{node.local_node_score:.2f}".rstrip(
- "0"
- ).rstrip(".")
- label = f"{node.id}/({node_score},{local_node_score})"
+ output_score = f"{node.output_score:.2f}".rstrip("0").rstrip(".")
+ label = f"{node.id}/({node_score}, {output_score})"
if node.is_end:
dot.node(str(node.id), label=label, **final_state_attr)
else:
@@ -391,6 +379,7 @@ if __name__ == "__main__":
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
"HISHE": 9, # "HIS", "S", "SHE", "HE"
"SHED": 6, # "S", "SHE", "HE"
"SHELF": 6, # "S", "SHE", "HE"
"HELL": 2, # "HE"
"HELLO": 7, # "HE", "HELLO"
"DHRHISQ": 4, # "HIS", "S"