Add context biasing for zipformer recipe (#1204)

* Add context biasing for zipformer recipe

* Support context biasing in modified_beam_search_LODR

* Fix context graph

* Minor fixes
Wei Kang 2023-08-28 19:37:32 +08:00 committed by GitHub
parent fc2df07841
commit 4d7f73ce65
3 changed files with 122 additions and 42 deletions

File: beam_search.py

@@ -2389,6 +2389,7 @@ def modified_beam_search_LODR(
     LODR_lm_scale: float,
     LM: LmScorer,
     beam: int = 4,
+    context_graph: Optional[ContextGraph] = None,
 ) -> List[List[int]]:
     """This function implements LODR (https://arxiv.org/abs/2203.16776) with
     `modified_beam_search`. It uses a bi-gram language model as the estimate
@@ -2457,6 +2458,7 @@ def modified_beam_search_LODR(
                 state_cost=NgramLmStateCost(
                     LODR_lm
                 ),  # state of the source domain ngram
+                context_state=None if context_graph is None else context_graph.root,
             )
         )
@@ -2602,8 +2604,17 @@ def modified_beam_search_LODR(
                 hyp_log_prob = topk_log_probs[k]  # get score of current hyp
                 new_token = topk_token_indexes[k]
+
+                context_score = 0
+                new_context_state = None if context_graph is None else hyp.context_state
                 if new_token not in (blank_id, unk_id):
+                    if context_graph is not None:
+                        (
+                            context_score,
+                            new_context_state,
+                        ) = context_graph.forward_one_step(hyp.context_state, new_token)
+
                     ys.append(new_token)
                     state_cost = hyp.state_cost.forward_one_step(new_token)
@@ -2619,6 +2630,7 @@ def modified_beam_search_LODR(
                     hyp_log_prob += (
                         lm_score[new_token] * lm_scale
                         + LODR_lm_scale * current_ngram_score
+                        + context_score
                     )  # add the lm score
 
                     lm_score = scores[count]
@@ -2637,10 +2649,31 @@ def modified_beam_search_LODR(
                     state=state,
                     lm_score=lm_score,
                     state_cost=state_cost,
+                    context_state=new_context_state,
                 )
                 B[i].add(new_hyp)
 
     B = B + finalized_B
+
+    # Finalize context_state: for matched contexts that did not reach a final
+    # state, we need to add the score of the corresponding backoff arc.
+    if context_graph is not None:
+        finalized_B = [HypothesisList() for _ in range(len(B))]
+        for i, hyps in enumerate(B):
+            for hyp in list(hyps):
+                context_score, new_context_state = context_graph.finalize(
+                    hyp.context_state
+                )
+                finalized_B[i].add(
+                    Hypothesis(
+                        ys=hyp.ys,
+                        log_prob=hyp.log_prob + context_score,
+                        timestamp=hyp.timestamp,
+                        context_state=new_context_state,
+                    )
+                )
+        B = finalized_B
+
     best_hyps = [b.get_most_probable(length_norm=True) for b in B]
     sorted_ans = [h.ys[context_size:] for h in best_hyps]
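Note: the ContextGraph calls above follow a simple per-hypothesis contract: start at context_graph.root, advance with forward_one_step on every non-blank token (adding the returned bonus to the hypothesis score), and call finalize at the end of the utterance to cancel the bonuses of unfinished matches. A minimal sketch, with made-up token ids (not recipe code):

    from icefall import ContextGraph

    # Two biasing phrases as token-id sequences; the ids are illustrative only.
    context_graph = ContextGraph(1.0)  # per-token bonus, i.e. context_score
    context_graph.build([[5, 8], [5, 8, 3]])

    state = context_graph.root  # every new hypothesis starts at the root
    hyp_score = 0.0
    for token in [5, 8, 9]:  # tokens emitted by the search
        bonus, state = context_graph.forward_one_step(state, token)
        hyp_score += bonus  # boost partial/full matches so they survive pruning

    # End of utterance: cancel the provisional bonus of unfinished matches.
    bonus, state = context_graph.finalize(state)
    hyp_score += bonus

    # hyp_score ends up at 2.0: the full match [5, 8] keeps its bonus, while
    # the partial progress toward [5, 8, 3] is revoked via the fail arc.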

File: decode.py

@@ -97,6 +97,7 @@ Usage:
 import argparse
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -122,7 +123,7 @@ from beam_search import (
 )
 from train import add_model_arguments, get_model, get_params
 
-from icefall import LmScorer, NgramLm
+from icefall import ContextGraph, LmScorer, NgramLm
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -215,6 +216,7 @@ def get_parser():
           - greedy_search
           - beam_search
           - modified_beam_search
+          - modified_beam_search_LODR
           - fast_beam_search
           - fast_beam_search_nbest
           - fast_beam_search_nbest_oracle
@@ -251,7 +253,7 @@ def get_parser():
         type=float,
         default=0.01,
         help="""
-        Used only when --decoding_method is fast_beam_search_nbest_LG.
+        Used only when --decoding-method is fast_beam_search_nbest_LG.
         It specifies the scale for n-gram LM scores.
         """,
     )
@@ -285,7 +287,7 @@ def get_parser():
         type=int,
         default=1,
         help="""Maximum number of symbols per frame.
-        Used only when --decoding_method is greedy_search""",
+        Used only when --decoding-method is greedy_search""",
     )
 
     parser.add_argument(
@@ -347,6 +349,27 @@ def get_parser():
         help="ID of the backoff symbol in the ngram LM",
     )
 
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="""
+        The bonus score for each token of the context biasing words/phrases.
+        Used only when --decoding-method is modified_beam_search or
+        modified_beam_search_LODR.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path of the context biasing lists, one word/phrase per line.
+        Used only when --decoding-method is modified_beam_search or
+        modified_beam_search_LODR.
+        """,
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -359,6 +382,7 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
     LM: Optional[LmScorer] = None,
     ngram_lm=None,
     ngram_lm_scale: float = 0.0,
@@ -388,7 +412,7 @@ def decode_one_batch(
         The word symbol table.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
-        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
         fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
       LM:
         A neural network language model.
@@ -493,6 +517,7 @@ def decode_one_batch(
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -515,6 +540,7 @@ def decode_one_batch(
             LODR_lm=ngram_lm,
             LODR_lm_scale=ngram_lm_scale,
             LM=LM,
+            context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -578,16 +604,22 @@ def decode_one_batch(
         key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
 
         return {key: hyps}
-    elif params.decoding_method in (
-        "modified_beam_search_lm_rescore",
-        "modified_beam_search_lm_rescore_LODR",
-    ):
-        ans = dict()
-        assert ans_dict is not None
-        for key, hyps in ans_dict.items():
-            hyps = [sp.decode(hyp).split() for hyp in hyps]
-            ans[f"beam_size_{params.beam_size}_{key}"] = hyps
-        return ans
+    elif "modified_beam_search" in params.decoding_method:
+        prefix = f"beam_size_{params.beam_size}"
+        if params.decoding_method in (
+            "modified_beam_search_lm_rescore",
+            "modified_beam_search_lm_rescore_LODR",
+        ):
+            ans = dict()
+            assert ans_dict is not None
+            for key, hyps in ans_dict.items():
+                hyps = [sp.decode(hyp).split() for hyp in hyps]
+                ans[f"{prefix}_{key}"] = hyps
+            return ans
+        else:
+            if params.has_contexts:
+                prefix += f"-context-score-{params.context_score}"
+            return {prefix: hyps}
     else:
         return {f"beam_size_{params.beam_size}": hyps}
@@ -599,6 +631,7 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
     LM: Optional[LmScorer] = None,
     ngram_lm=None,
     ngram_lm_scale: float = 0.0,
@@ -618,7 +651,7 @@ def decode_dataset(
         The word symbol table.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
-        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
         fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
@@ -649,6 +682,7 @@ def decode_dataset(
             model=model,
             sp=sp,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
             word_table=word_table,
             batch=batch,
             LM=LM,
@@ -744,6 +778,11 @@ def main():
     )
 
     params.res_dir = params.exp_dir / params.decoding_method
 
+    if os.path.exists(params.context_file):
+        params.has_contexts = True
+    else:
+        params.has_contexts = False
+
     if params.iter > 0:
         params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:
@@ -770,6 +809,12 @@ def main():
             params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        if params.decoding_method in (
+            "modified_beam_search",
+            "modified_beam_search_LODR",
+        ):
+            if params.has_contexts:
+                params.suffix += f"-context-score-{params.context_score}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -952,6 +997,18 @@ def main():
         decoding_graph = None
         word_table = None
 
+    if "modified_beam_search" in params.decoding_method:
+        if os.path.exists(params.context_file):
+            contexts = []
+            for line in open(params.context_file).readlines():
+                contexts.append(line.strip())
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build(sp.encode(contexts))
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -976,6 +1033,7 @@ def main():
             sp=sp,
             word_table=word_table,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
             LM=LM,
             ngram_lm=ngram_lm,
             ngram_lm_scale=ngram_lm_scale,
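Note: pieced together, the decode.py changes wire the biasing list into the search roughly as follows. A standalone sketch in which the BPE model path and biasing_words.txt are placeholders, not names from this commit:

    import sentencepiece as spm

    from icefall import ContextGraph

    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bpe_500/bpe.model")  # placeholder path

    # --context-file: one biasing word/phrase per line.
    contexts = [line.strip() for line in open("biasing_words.txt")]

    # --context-score: the per-token bonus (default 2).
    context_graph = ContextGraph(2.0)
    context_graph.build(sp.encode(contexts))  # List[List[int]] of BPE ids

    # context_graph is then threaded through decode_dataset() and
    # decode_one_batch() into modified_beam_search / modified_beam_search_LODR.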

File: context_graph.py

@@ -29,7 +29,7 @@ class ContextState:
         token: int,
         token_score: float,
         node_score: float,
-        local_node_score: float,
+        output_score: float,
         is_end: bool,
     ):
         """Create a ContextState.
@@ -40,16 +40,15 @@ class ContextState:
             The id of the root node is always 0.
           token:
             The token id.
-          score:
+          token_score:
             The bonus for each token during decoding, which will hopefully
             boost the token up to survive beam search.
           node_score:
             The accumulated bonus from root of graph to current node, it will be
             used to calculate the score for fail arc.
-          local_node_score:
-            The accumulated bonus from last ``end_node``(node with is_end true)
-            to current_node, it will be used to calculate the score for fail arc.
-            Node: The local_node_score of a ``end_node`` is 0.
+          output_score:
+            The total score of matched phrases: the sum of node_score over all
+            output nodes of the current node.
           is_end:
             True if current token is the end of a context.
         """
@@ -57,7 +56,7 @@ class ContextState:
         self.token = token
         self.token_score = token_score
         self.node_score = node_score
-        self.local_node_score = local_node_score
+        self.output_score = output_score
         self.is_end = is_end
         self.next = {}
         self.fail = None
@@ -93,7 +92,7 @@ class ContextGraph:
             token=-1,
             token_score=0,
             node_score=0,
-            local_node_score=0,
+            output_score=0,
             is_end=False,
         )
         self.root.fail = self.root
@@ -131,6 +130,7 @@ class ContextGraph:
                         output = None
                         break
                 node.output = output
+                node.output_score += 0 if output is None else output.output_score
                 queue.append(node)
 
     def build(self, token_ids: List[List[int]]):
@@ -153,14 +153,13 @@ class ContextGraph:
                 if token not in node.next:
                     self.num_nodes += 1
                     is_end = i == len(tokens) - 1
+                    node_score = node.node_score + self.context_score
                     node.next[token] = ContextState(
                         id=self.num_nodes,
                         token=token,
                         token_score=self.context_score,
-                        node_score=node.node_score + self.context_score,
-                        local_node_score=0
-                        if is_end
-                        else (node.local_node_score + self.context_score),
+                        node_score=node_score,
+                        output_score=node_score if is_end else 0,
                         is_end=is_end,
                     )
                 node = node.next[token]
@@ -186,8 +185,6 @@ class ContextGraph:
         if token in state.next:
             node = state.next[token]
             score = node.token_score
-            if state.is_end:
-                score += state.node_score
         else:
             # token not matched
             # We will trace along the fail arc until it matches the token or reaching
@@ -202,14 +199,9 @@ class ContextGraph:
                 node = node.next[token]
             # The score of the fail path
-            score = node.node_score - state.local_node_score
+            score = node.node_score - state.node_score
         assert node is not None
-        matched_score = 0
-        output = node.output
-        while output is not None:
-            matched_score += output.node_score
-            output = output.output
-        return (score + matched_score, node)
+
+        return (score + node.output_score, node)
 
     def finalize(self, state: ContextState) -> Tuple[float, ContextState]:
         """When reaching the end of the decoded sequence, we need to finalize
@@ -227,8 +219,6 @@ class ContextGraph:
         """
         # The score of the fail arc
         score = -state.node_score
-        if state.is_end:
-            score = 0
 
         return (score, self.root)
 
     def draw(
@@ -307,10 +297,8 @@ class ContextGraph:
             for token, node in current_node.next.items():
                 if node.id not in seen:
                     node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".")
-                    local_node_score = f"{node.local_node_score:.2f}".rstrip(
-                        "0"
-                    ).rstrip(".")
-                    label = f"{node.id}/({node_score},{local_node_score})"
+                    output_score = f"{node.output_score:.2f}".rstrip("0").rstrip(".")
+                    label = f"{node.id}/({node_score}, {output_score})"
                     if node.is_end:
                         dot.node(str(node.id), label=label, **final_state_attr)
                     else:
@@ -391,6 +379,7 @@ if __name__ == "__main__":
         "HERSHE": 12,  # "HE", "HERS", "S", "SHE", "HE"
         "HISHE": 9,  # "HIS", "S", "SHE", "HE"
         "SHED": 6,  # "S", "SHE", "HE"
+        "SHELF": 6,  # "S", "SHE", "HE"
         "HELL": 2,  # "HE"
         "HELLO": 7,  # "HE", "HELLO"
         "DHRHISQ": 4,  # "HIS", "S"