support context biasing in modified_beam_search_LODR

pkufool 2023-08-09 11:39:12 +08:00
parent 88067f7566
commit e90563cdff
2 changed files with 33 additions and 3 deletions

@@ -2385,6 +2385,7 @@ def modified_beam_search_LODR(
     LODR_lm_scale: float,
     LM: LmScorer,
     beam: int = 4,
+    context_graph: Optional[ContextGraph] = None,
 ) -> List[List[int]]:
     """This function implements LODR (https://arxiv.org/abs/2203.16776) with
     `modified_beam_search`. It uses a bi-gram language model as the estimate
@@ -2453,6 +2454,7 @@ def modified_beam_search_LODR(
                 state_cost=NgramLmStateCost(
                     LODR_lm
                 ),  # state of the source domain ngram
+                context_state=None if context_graph is None else context_graph.root,
             )
         )
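For orientation, the Hypothesis objects constructed in this hunk gain a new context_state slot holding the hypothesis's current node in the biasing graph. A minimal sketch of the fields this diff touches; the real dataclass in icefall's beam_search.py has more members than shown:

    from dataclasses import dataclass
    from typing import Any, List, Optional

    @dataclass
    class Hypothesis:
        # Sketch only: the real class also carries state, lm_score,
        # timestamp, etc.
        ys: List[int]        # decoded token ids so far
        log_prob: Any        # accumulated log-probability (a torch.Tensor)
        state_cost: Any      # NgramLmStateCost tracking the LODR bi-gram state
        # New in this commit: current node of the context graph;
        # None means context biasing is disabled for this run.
        context_state: Optional[Any] = None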
@@ -2598,8 +2600,17 @@ def modified_beam_search_LODR(
             hyp_log_prob = topk_log_probs[k]  # get score of current hyp
             new_token = topk_token_indexes[k]
+
+            context_score = 0
+            new_context_state = None if context_graph is None else hyp.context_state
             if new_token not in (blank_id, unk_id):
+                if context_graph is not None:
+                    (
+                        context_score,
+                        new_context_state,
+                    ) = context_graph.forward_one_step(hyp.context_state, new_token)
+
                 ys.append(new_token)
                 state_cost = hyp.state_cost.forward_one_step(new_token)
@@ -2615,6 +2626,7 @@ def modified_beam_search_LODR(
                 hyp_log_prob += (
                     lm_score[new_token] * lm_scale
                     + LODR_lm_scale * current_ngram_score
+                    + context_score
                 )  # add the lm score
             lm_score = scores[count]
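After this hunk, the per-token update stacks three terms on top of the acoustic score: neural-LM shallow fusion, the LODR bi-gram correction (LODR_lm_scale is conventionally negative so the source-domain bi-gram is subtracted), and now the context bonus. A schematic restatement, with the function name and argument names purely illustrative:

    def token_update(hyp_log_prob, nn_lm_log_prob, ngram_log_prob,
                     context_score, lm_scale, LODR_lm_scale):
        # Schematic version of the update above; hyp_log_prob already
        # includes the acoustic/joiner score for this token.
        return hyp_log_prob + (
            lm_scale * nn_lm_log_prob          # neural LM shallow fusion
            + LODR_lm_scale * ngram_log_prob   # LODR: scale is typically negative
            + context_score                    # context-biasing bonus (this commit)
        )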
@@ -2633,10 +2645,31 @@ def modified_beam_search_LODR(
                     state=state,
                     lm_score=lm_score,
                     state_cost=state_cost,
+                    context_state=new_context_state,
                 )
                 B[i].add(new_hyp)
     B = B + finalized_B
+
+    # finalize context_state, if the matched contexts do not reach final state
+    # we need to add the score on the corresponding backoff arc
+    if context_graph is not None:
+        finalized_B = [HypothesisList() for _ in range(len(B))]
+        for i, hyps in enumerate(B):
+            for hyp in list(hyps):
+                context_score, new_context_state = context_graph.finalize(
+                    hyp.context_state
+                )
+                finalized_B[i].add(
+                    Hypothesis(
+                        ys=hyp.ys,
+                        log_prob=hyp.log_prob + context_score,
+                        timestamp=hyp.timestamp,
+                        context_state=new_context_state,
+                    )
+                )
+        B = finalized_B
+
     best_hyps = [b.get_most_probable(length_norm=True) for b in B]
     sorted_ans = [h.ys[context_size:] for h in best_hyps]
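The diff above relies on three pieces of the ContextGraph interface: a root attribute for the start node, forward_one_step(state, token) returning a (score, next_state) pair, and finalize(state) returning the correction taken on the backoff arc when a partially matched phrase never reaches a final state. A minimal sketch of how these calls compose over one hypothesis; only the three ContextGraph calls come from the diff, the surrounding function is illustrative:

    def score_tokens_with_context(context_graph, tokens):
        state = context_graph.root  # start every hypothesis at the root node
        total = 0.0
        for t in tokens:
            # Per-step biasing bonus plus the next node in the graph.
            bonus, state = context_graph.forward_one_step(state, t)
            total += bonus
        # Undo partial matches that never reached a final state: finalize()
        # returns the correction on the backoff arc and a reset state.
        correction, state = context_graph.finalize(state)
        return total + correction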

@@ -525,7 +525,6 @@ def decode_one_batch(
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             LM=LM,
-            context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -551,7 +550,6 @@ def decode_one_batch(
             beam=params.beam_size,
             LM=LM,
             lm_scale_list=lm_scale_list,
-            context_graph=context_graph,
         )
     elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
         lm_scale_list = [0.02 * i for i in range(2, 30)]
@@ -564,7 +562,6 @@ def decode_one_batch(
             LODR_lm=ngram_lm,
             sp=sp,
             lm_scale_list=lm_scale_list,
-            context_graph=context_graph,
         )
     else:
         batch_size = encoder_out.size(0)
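The deletions above drop context_graph from decoding paths that do not accept it, leaving modified_beam_search_LODR as the consumer of the new argument. A hedged usage sketch of the updated call: the ContextGraph construction follows icefall's context_graph helper, but the exact constructor and build() signature should be verified in your checkout, and model, sp, params, encoder_out, ngram_lm, and LM are assumed to be set up as elsewhere in decode.py:

    # Hypothetical wiring; names other than modified_beam_search_LODR's own
    # parameters are assumptions, not part of this commit.
    from icefall.context_graph import ContextGraph

    context_words = ["HELLO WORLD", "ICEFALL"]       # phrases to bias toward
    context_graph = ContextGraph(context_score=2.0)  # per-token bonus (assumed)
    context_graph.build(sp.encode(context_words, out_type=int))

    hyp_tokens = modified_beam_search_LODR(
        model=model,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        LODR_lm=ngram_lm,
        LODR_lm_scale=params.ngram_lm_scale,  # conventionally negative
        LM=LM,
        beam=params.beam_size,
        context_graph=context_graph,          # new in this commit
    )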