mirror of https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00

support context biasing in modified_beam_search_LODR

parent 88067f7566
commit e90563cdff
@@ -2385,6 +2385,7 @@ def modified_beam_search_LODR(
     LODR_lm_scale: float,
     LM: LmScorer,
     beam: int = 4,
+    context_graph: Optional[ContextGraph] = None,
 ) -> List[List[int]]:
     """This function implements LODR (https://arxiv.org/abs/2203.16776) with
     `modified_beam_search`. It uses a bi-gram language model as the estimate
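
The new keyword argument threads an icefall ContextGraph through the LODR search. For orientation, a minimal construction sketch; it assumes icefall's icefall.context_graph.ContextGraph (whose root, forward_one_step, and finalize are used later in this diff) and a SentencePiece processor sp, with the example phrases and the 2.0 per-token bonus purely illustrative:

    from icefall.context_graph import ContextGraph

    # Biasing phrases, encoded with the same BPE model as the ASR system
    # so that graph arcs line up with decoder tokens.
    phrases = ["CHRISTOPHER", "MARGUERITE"]  # illustrative
    context_graph = ContextGraph(context_score=2.0)  # per-token bonus (assumed value)
    context_graph.build(sp.encode(phrases, out_type=int))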
@@ -2453,6 +2454,7 @@ def modified_beam_search_LODR(
                 state_cost=NgramLmStateCost(
                     LODR_lm
                 ),  # state of the source domain ngram
+                context_state=None if context_graph is None else context_graph.root,
             )
         )
 
@@ -2598,8 +2600,17 @@ def modified_beam_search_LODR(
 
                 hyp_log_prob = topk_log_probs[k]  # get score of current hyp
                 new_token = topk_token_indexes[k]
+
+                context_score = 0
+                new_context_state = None if context_graph is None else hyp.context_state
                 if new_token not in (blank_id, unk_id):
 
+                    if context_graph is not None:
+                        (
+                            context_score,
+                            new_context_state,
+                        ) = context_graph.forward_one_step(hyp.context_state, new_token)
+
                     ys.append(new_token)
                     state_cost = hyp.state_cost.forward_one_step(new_token)
 
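
Each hypothesis now carries its own context_state, initialized to the graph root in the hunk at 2453 above. forward_one_step behaves like an Aho-Corasick matcher: a token that extends a partial phrase match earns a bonus, while a token that breaks the match follows a fail arc and returns the accumulated bonus as a negative delta, so dead-end prefixes are not rewarded. A toy sketch of that (score_delta, next_state) contract, not icefall's implementation (the real graph also handles overlapping matches through proper fail/backoff arcs, and a completed phrase keeps its bonus while returning to the root):

    # Toy prefix-trie scorer illustrating the forward_one_step contract.
    class ToyState:
        def __init__(self, score=0.0):
            self.next = {}      # token id -> ToyState
            self.score = score  # bonus accumulated along this path

    def toy_build(phrases, bonus=2.0):
        root = ToyState()
        for tokens in phrases:
            node = root
            for t in tokens:
                if t not in node.next:
                    node.next[t] = ToyState(node.score + bonus)
                node = node.next[t]
        return root

    def toy_forward_one_step(root, state, token, bonus=2.0):
        if token in state.next:
            return bonus, state.next[token]  # extend a match: positive delta
        return -state.score, root            # mismatch: roll the bonus back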
@@ -2615,6 +2626,7 @@ def modified_beam_search_LODR(
                     hyp_log_prob += (
                         lm_score[new_token] * lm_scale
                         + LODR_lm_scale * current_ngram_score
+                        + context_score
                     )  # add the lm score
 
                 lm_score = scores[count]
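
Taken together with the transducer score already accumulated in topk_log_probs, each accepted token y therefore updates a hypothesis h as below. In LODR (arXiv:2203.16776, cited in the docstring) LODR_lm_scale is normally negative, since the source-domain bi-gram stands in for the internal LM being subtracted; context_score is zero whenever the token does not touch the biasing graph:

    score(h + y) = log P_transducer(y | h)             # already in topk_log_probs
                 + lm_scale      * log P_NNLM(y | h)   # external neural LM
                 + LODR_lm_scale * log P_bigram(y | h) # typically negative
                 + context_score                       # biasing bonus or rollback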
@@ -2633,10 +2645,31 @@ def modified_beam_search_LODR(
                     state=state,
                     lm_score=lm_score,
                     state_cost=state_cost,
+                    context_state=new_context_state,
                 )
                 B[i].add(new_hyp)
 
     B = B + finalized_B
+
+    # finalize context_state, if the matched contexts do not reach final state
+    # we need to add the score on the corresponding backoff arc
+    if context_graph is not None:
+        finalized_B = [HypothesisList() for _ in range(len(B))]
+        for i, hyps in enumerate(B):
+            for hyp in list(hyps):
+                context_score, new_context_state = context_graph.finalize(
+                    hyp.context_state
+                )
+                finalized_B[i].add(
+                    Hypothesis(
+                        ys=hyp.ys,
+                        log_prob=hyp.log_prob + context_score,
+                        timestamp=hyp.timestamp,
+                        context_state=new_context_state,
+                    )
+                )
+        B = finalized_B
+
     best_hyps = [b.get_most_probable(length_norm=True) for b in B]
 
     sorted_ans = [h.ys[context_size:] for h in best_hyps]
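
The finalize pass exists because a hypothesis can end in the middle of a biasing phrase; the provisional per-token bonuses earned on that unfinished match must be returned before the final ranking. In terms of the toy scorer sketched earlier, finalizing is just rolling back the pending score (a hypothetical helper mirroring context_graph.finalize, where None stands for a terminal state since no more tokens follow):

    def toy_finalize(state):
        # Unfinished match: take back the provisional bonus. A hypothesis
        # sitting at the root (no pending match) is unchanged.
        return -state.score, None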
@@ -525,7 +525,6 @@ def decode_one_batch(
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             LM=LM,
-            context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -551,7 +550,6 @@ def decode_one_batch(
             beam=params.beam_size,
             LM=LM,
             lm_scale_list=lm_scale_list,
-            context_graph=context_graph,
         )
     elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
         lm_scale_list = [0.02 * i for i in range(2, 30)]
@@ -564,7 +562,6 @@ def decode_one_batch(
             LODR_lm=ngram_lm,
             sp=sp,
             lm_scale_list=lm_scale_list,
-            context_graph=context_graph,
         )
     else:
         batch_size = encoder_out.size(0)
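
The three decode_one_batch hunks drop context_graph from rescoring calls that do not consume it; with this commit, a LODR decode that wants biasing passes the graph to modified_beam_search_LODR itself. A hedged call sketch, using the usual icefall decode.py variable names (model, encoder_out, encoder_out_lens, ngram_lm, and params.ngram_lm_scale are assumed from the surrounding script, not shown in this diff):

    hyp_tokens = modified_beam_search_LODR(
        model=model,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        beam=params.beam_size,
        LODR_lm=ngram_lm,
        LODR_lm_scale=params.ngram_lm_scale,
        LM=LM,
        context_graph=context_graph,
    )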