Update padding modified beam search (#1217)

commit 9a47c08d08
parent 3b5645f594
@@ -1008,7 +1008,7 @@ def modified_beam_search(
     for i in range(N):
         B[i].add(
             Hypothesis(
-                ys=[blank_id] * context_size,
+                ys=[-1] * (context_size - 1) + [blank_id],
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                 context_state=None if context_graph is None else context_graph.root,
                 timestamp=[],
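Every hunk in this commit makes the same change: the initial hypothesis context is padded with -1, so only the final position holds blank_id instead of the whole context being filled with it. A minimal sketch of the old versus new initial ys, assuming context_size = 2 and blank_id = 0 (illustrative values, not taken from this diff):

# Sketch of the initial decoder context before and after this change.
# context_size = 2 and blank_id = 0 are assumptions for illustration only.
context_size = 2
blank_id = 0

old_ys = [blank_id] * context_size               # [0, 0]
new_ys = [-1] * (context_size - 1) + [blank_id]  # [-1, 0]

print(old_ys, new_ys)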
@@ -1217,7 +1217,7 @@ def modified_beam_search_lm_rescore(
     for i in range(N):
         B[i].add(
             Hypothesis(
-                ys=[blank_id] * context_size,
+                ys=[-1] * (context_size - 1) + [blank_id],
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                 timestamp=[],
             )
@@ -1417,7 +1417,7 @@ def modified_beam_search_lm_rescore_LODR(
     for i in range(N):
         B[i].add(
             Hypothesis(
-                ys=[blank_id] * context_size,
+                ys=[-1] * (context_size - 1) + [blank_id],
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                 timestamp=[],
             )
@@ -1617,7 +1617,7 @@ def _deprecated_modified_beam_search(
     B = HypothesisList()
     B.add(
         Hypothesis(
-            ys=[blank_id] * context_size,
+            ys=[-1] * (context_size - 1) + [blank_id],
             log_prob=torch.zeros(1, dtype=torch.float32, device=device),
             timestamp=[],
         )
@@ -1753,7 +1753,11 @@ def beam_search(
     t = 0

     B = HypothesisList()
-    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[]))
+    B.add(
+        Hypothesis(
+            ys=[-1] * (context_size - 1) + [blank_id], log_prob=0.0, timestamp=[]
+        )
+    )

     max_sym_per_utt = 20000

@@ -2265,7 +2269,7 @@ def modified_beam_search_ngram_rescoring(
     for i in range(N):
         B[i].add(
             Hypothesis(
-                ys=[blank_id] * context_size,
+                ys=[-1] * (context_size - 1) + [blank_id],
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                 state_cost=NgramLmStateCost(ngram_lm),
             )
@@ -2446,7 +2450,7 @@ def modified_beam_search_LODR(
     for i in range(N):
         B[i].add(
             Hypothesis(
-                ys=[blank_id] * context_size,
+                ys=[-1] * (context_size - 1) + [blank_id],
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                 state=init_states,  # state of the NN LM
                 lm_score=init_score.reshape(-1),
@@ -2709,7 +2713,7 @@ def modified_beam_search_lm_shallow_fusion(
     for i in range(N):
         B[i].add(
             Hypothesis(
-                ys=[blank_id] * context_size,
+                ys=[-1] * (context_size - 1) + [blank_id],
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                 state=init_states,
                 lm_score=init_score.reshape(-1),
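The padding only affects the very first decoder call: a stateless transducer decoder conditions on the last context_size tokens of each hypothesis, so once real tokens are emitted the -1 padding falls out of the window. A short sketch under those assumptions (the slicing and token id below are illustrative, not code from this commit):

# Sketch of why the -1 padding only matters before the first emitted token.
# The decoder is assumed to read the last `context_size` tokens of ys.
context_size = 2
blank_id = 0

ys = [-1] * (context_size - 1) + [blank_id]  # initial context: [-1, 0]
decoder_input = ys[-context_size:]           # first call sees [-1, 0]

ys.append(17)                                # emit a hypothetical token id
decoder_input = ys[-context_size:]           # now [0, 17]; the pad is gone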