diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 97e259b40..3298568a3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -2389,6 +2389,7 @@ def modified_beam_search_LODR( LODR_lm_scale: float, LM: LmScorer, beam: int = 4, + context_graph: Optional[ContextGraph] = None, ) -> List[List[int]]: """This function implements LODR (https://arxiv.org/abs/2203.16776) with `modified_beam_search`. It uses a bi-gram language model as the estimate @@ -2457,6 +2458,7 @@ def modified_beam_search_LODR( state_cost=NgramLmStateCost( LODR_lm ), # state of the source domain ngram + context_state=None if context_graph is None else context_graph.root, ) ) @@ -2602,8 +2604,17 @@ def modified_beam_search_LODR( hyp_log_prob = topk_log_probs[k] # get score of current hyp new_token = topk_token_indexes[k] + + context_score = 0 + new_context_state = None if context_graph is None else hyp.context_state if new_token not in (blank_id, unk_id): + if context_graph is not None: + ( + context_score, + new_context_state, + ) = context_graph.forward_one_step(hyp.context_state, new_token) + ys.append(new_token) state_cost = hyp.state_cost.forward_one_step(new_token) @@ -2619,6 +2630,7 @@ def modified_beam_search_LODR( hyp_log_prob += ( lm_score[new_token] * lm_scale + LODR_lm_scale * current_ngram_score + + context_score ) # add the lm score lm_score = scores[count] @@ -2637,10 +2649,31 @@ def modified_beam_search_LODR( state=state, lm_score=lm_score, state_cost=state_cost, + context_state=new_context_state, ) B[i].add(new_hyp) B = B + finalized_B + + # finalize context_state, if the matched contexts do not reach final state + # we need to add the score on the corresponding backoff arc + if context_graph is not None: + finalized_B = [HypothesisList() for _ in range(len(B))] + for i, hyps in enumerate(B): + for hyp in list(hyps): + context_score, new_context_state = context_graph.finalize( + hyp.context_state + ) + finalized_B[i].add( + Hypothesis( + ys=hyp.ys, + log_prob=hyp.log_prob + context_score, + timestamp=hyp.timestamp, + context_state=new_context_state, + ) + ) + B = finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] sorted_ans = [h.ys[context_size:] for h in best_hyps] diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index 2cc157e7a..3531d657f 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -97,6 +97,7 @@ Usage: import argparse import logging import math +import os from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -122,7 +123,7 @@ from beam_search import ( ) from train import add_model_arguments, get_model, get_params -from icefall import LmScorer, NgramLm +from icefall import ContextGraph, LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -215,6 +216,7 @@ def get_parser(): - greedy_search - beam_search - modified_beam_search + - modified_beam_search_LODR - fast_beam_search - fast_beam_search_nbest - fast_beam_search_nbest_oracle @@ -251,7 +253,7 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. + Used only when --decoding-method is fast_beam_search_nbest_LG. It specifies the scale for n-gram LM scores. """, ) @@ -285,7 +287,7 @@ def get_parser(): type=int, default=1, help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", + Used only when --decoding-method is greedy_search""", ) parser.add_argument( @@ -347,6 +349,27 @@ def get_parser(): help="ID of the backoff symbol in the ngram LM", ) + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) add_model_arguments(parser) return parser @@ -359,6 +382,7 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, LM: Optional[LmScorer] = None, ngram_lm=None, ngram_lm_scale: float = 0.0, @@ -388,7 +412,7 @@ def decode_one_batch( The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. LM: A neural network language model. @@ -493,6 +517,7 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, + context_graph=context_graph, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -515,6 +540,7 @@ def decode_one_batch( LODR_lm=ngram_lm, LODR_lm_scale=ngram_lm_scale, LM=LM, + context_graph=context_graph, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -578,16 +604,22 @@ def decode_one_batch( key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} - elif params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ): - ans = dict() - assert ans_dict is not None - for key, hyps in ans_dict.items(): - hyps = [sp.decode(hyp).split() for hyp in hyps] - ans[f"beam_size_{params.beam_size}_{key}"] = hyps - return ans + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} else: return {f"beam_size_{params.beam_size}": hyps} @@ -599,6 +631,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, LM: Optional[LmScorer] = None, ngram_lm=None, ngram_lm_scale: float = 0.0, @@ -618,7 +651,7 @@ def decode_dataset( The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search @@ -649,6 +682,7 @@ def decode_dataset( model=model, sp=sp, decoding_graph=decoding_graph, + context_graph=context_graph, word_table=word_table, batch=batch, LM=LM, @@ -744,6 +778,11 @@ def main(): ) params.res_dir = params.exp_dir / params.decoding_method + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + if params.iter > 0: params.suffix = f"iter-{params.iter}-avg-{params.avg}" else: @@ -770,6 +809,12 @@ def main(): params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -952,6 +997,18 @@ def main(): decoding_graph = None word_table = None + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append(line.strip()) + context_graph = ContextGraph(params.context_score) + context_graph.build(sp.encode(contexts)) + else: + context_graph = None + else: + context_graph = None + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -976,6 +1033,7 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, + context_graph=context_graph, LM=LM, ngram_lm=ngram_lm, ngram_lm_scale=ngram_lm_scale, diff --git a/icefall/context_graph.py b/icefall/context_graph.py index c78de30f6..01836df04 100644 --- a/icefall/context_graph.py +++ b/icefall/context_graph.py @@ -29,7 +29,7 @@ class ContextState: token: int, token_score: float, node_score: float, - local_node_score: float, + output_score: float, is_end: bool, ): """Create a ContextState. @@ -40,16 +40,15 @@ class ContextState: The id of the root node is always 0. token: The token id. - score: + token_score: The bonus for each token during decoding, which will hopefully boost the token up to survive beam search. node_score: The accumulated bonus from root of graph to current node, it will be used to calculate the score for fail arc. - local_node_score: - The accumulated bonus from last ``end_node``(node with is_end true) - to current_node, it will be used to calculate the score for fail arc. - Node: The local_node_score of a ``end_node`` is 0. + output_score: + The total scores of matched phrases, sum of the node_score of all + the output node for current node. is_end: True if current token is the end of a context. """ @@ -57,7 +56,7 @@ class ContextState: self.token = token self.token_score = token_score self.node_score = node_score - self.local_node_score = local_node_score + self.output_score = output_score self.is_end = is_end self.next = {} self.fail = None @@ -93,7 +92,7 @@ class ContextGraph: token=-1, token_score=0, node_score=0, - local_node_score=0, + output_score=0, is_end=False, ) self.root.fail = self.root @@ -131,6 +130,7 @@ class ContextGraph: output = None break node.output = output + node.output_score += 0 if output is None else output.output_score queue.append(node) def build(self, token_ids: List[List[int]]): @@ -153,14 +153,13 @@ class ContextGraph: if token not in node.next: self.num_nodes += 1 is_end = i == len(tokens) - 1 + node_score = node.node_score + self.context_score node.next[token] = ContextState( id=self.num_nodes, token=token, token_score=self.context_score, - node_score=node.node_score + self.context_score, - local_node_score=0 - if is_end - else (node.local_node_score + self.context_score), + node_score=node_score, + output_score=node_score if is_end else 0, is_end=is_end, ) node = node.next[token] @@ -186,8 +185,6 @@ class ContextGraph: if token in state.next: node = state.next[token] score = node.token_score - if state.is_end: - score += state.node_score else: # token not matched # We will trace along the fail arc until it matches the token or reaching @@ -202,14 +199,9 @@ class ContextGraph: node = node.next[token] # The score of the fail path - score = node.node_score - state.local_node_score + score = node.node_score - state.node_score assert node is not None - matched_score = 0 - output = node.output - while output is not None: - matched_score += output.node_score - output = output.output - return (score + matched_score, node) + return (score + node.output_score, node) def finalize(self, state: ContextState) -> Tuple[float, ContextState]: """When reaching the end of the decoded sequence, we need to finalize @@ -227,8 +219,6 @@ class ContextGraph: """ # The score of the fail arc score = -state.node_score - if state.is_end: - score = 0 return (score, self.root) def draw( @@ -307,10 +297,8 @@ class ContextGraph: for token, node in current_node.next.items(): if node.id not in seen: node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".") - local_node_score = f"{node.local_node_score:.2f}".rstrip( - "0" - ).rstrip(".") - label = f"{node.id}/({node_score},{local_node_score})" + output_score = f"{node.output_score:.2f}".rstrip("0").rstrip(".") + label = f"{node.id}/({node_score}, {output_score})" if node.is_end: dot.node(str(node.id), label=label, **final_state_attr) else: @@ -391,6 +379,7 @@ if __name__ == "__main__": "HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE" "HISHE": 9, # "HIS", "S", "SHE", "HE" "SHED": 6, # "S", "SHE", "HE" + "SHELF": 6, # "S", "SHE", "HE" "HELL": 2, # "HE" "HELLO": 7, # "HE", "HELLO" "DHRHISQ": 4, # "HIS", "S"