diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index fd59d4b7f..68a39fa65 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -2385,6 +2385,7 @@ def modified_beam_search_LODR( LODR_lm_scale: float, LM: LmScorer, beam: int = 4, + context_graph: Optional[ContextGraph] = None, ) -> List[List[int]]: """This function implements LODR (https://arxiv.org/abs/2203.16776) with `modified_beam_search`. It uses a bi-gram language model as the estimate @@ -2453,6 +2454,7 @@ def modified_beam_search_LODR( state_cost=NgramLmStateCost( LODR_lm ), # state of the source domain ngram + context_state=None if context_graph is None else context_graph.root, ) ) @@ -2598,8 +2600,17 @@ def modified_beam_search_LODR( hyp_log_prob = topk_log_probs[k] # get score of current hyp new_token = topk_token_indexes[k] + + context_score = 0 + new_context_state = None if context_graph is None else hyp.context_state if new_token not in (blank_id, unk_id): + if context_graph is not None: + ( + context_score, + new_context_state, + ) = context_graph.forward_one_step(hyp.context_state, new_token) + ys.append(new_token) state_cost = hyp.state_cost.forward_one_step(new_token) @@ -2615,6 +2626,7 @@ def modified_beam_search_LODR( hyp_log_prob += ( lm_score[new_token] * lm_scale + LODR_lm_scale * current_ngram_score + + context_score ) # add the lm score lm_score = scores[count] @@ -2633,10 +2645,31 @@ def modified_beam_search_LODR( state=state, lm_score=lm_score, state_cost=state_cost, + context_state=new_context_state, ) B[i].add(new_hyp) B = B + finalized_B + + # finalize context_state, if the matched contexts do not reach final state + # we need to add the score on the corresponding backoff arc + if context_graph is not None: + finalized_B = [HypothesisList() for _ in range(len(B))] + for i, hyps in enumerate(B): + for hyp in list(hyps): + context_score, new_context_state = context_graph.finalize( + hyp.context_state + ) + finalized_B[i].add( + Hypothesis( + ys=hyp.ys, + log_prob=hyp.log_prob + context_score, + timestamp=hyp.timestamp, + context_state=new_context_state, + ) + ) + B = finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] sorted_ans = [h.ys[context_size:] for h in best_hyps] diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index 1da5b2669..a5cdeffe5 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -525,7 +525,6 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, beam=params.beam_size, LM=LM, - context_graph=context_graph, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -551,7 +550,6 @@ def decode_one_batch( beam=params.beam_size, LM=LM, lm_scale_list=lm_scale_list, - context_graph=context_graph, ) elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": lm_scale_list = [0.02 * i for i in range(2, 30)] @@ -564,7 +562,6 @@ def decode_one_batch( LODR_lm=ngram_lm, sp=sp, lm_scale_list=lm_scale_list, - context_graph=context_graph, ) else: batch_size = encoder_out.size(0)