Update padding modified beam search (#1217)

This commit is contained in:
Erwan Zerhouni 2023-08-14 16:10:50 +02:00 committed by GitHub
parent 3b5645f594
commit 9a47c08d08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -1008,7 +1008,7 @@ def modified_beam_search(
for i in range(N):
B[i].add(
Hypothesis(
-ys=[blank_id] * context_size,
+ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
context_state=None if context_graph is None else context_graph.root,
timestamp=[],
@@ -1217,7 +1217,7 @@ def modified_beam_search_lm_rescore(
for i in range(N):
B[i].add(
Hypothesis(
-ys=[blank_id] * context_size,
+ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
timestamp=[],
)
@@ -1417,7 +1417,7 @@ def modified_beam_search_lm_rescore_LODR(
for i in range(N):
B[i].add(
Hypothesis(
-ys=[blank_id] * context_size,
+ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
timestamp=[],
)
@@ -1617,7 +1617,7 @@ def _deprecated_modified_beam_search(
B = HypothesisList()
B.add(
Hypothesis(
-ys=[blank_id] * context_size,
+ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
timestamp=[],
)
@@ -1753,7 +1753,11 @@ def beam_search(
t = 0
B = HypothesisList()
-B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[]))
+B.add(
+Hypothesis(
+ys=[-1] * (context_size - 1) + [blank_id], log_prob=0.0, timestamp=[]
+)
+)
max_sym_per_utt = 20000
@@ -2265,7 +2269,7 @@ def modified_beam_search_ngram_rescoring(
for i in range(N):
B[i].add(
Hypothesis(
-ys=[blank_id] * context_size,
+ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state_cost=NgramLmStateCost(ngram_lm),
)
@@ -2446,7 +2450,7 @@ def modified_beam_search_LODR(
for i in range(N):
B[i].add(
Hypothesis(
-ys=[blank_id] * context_size,
+ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state=init_states, # state of the NN LM
lm_score=init_score.reshape(-1),
@@ -2709,7 +2713,7 @@ def modified_beam_search_lm_shallow_fusion(
for i in range(N):
B[i].add(
Hypothesis(
-ys=[blank_id] * context_size,
+ys=[-1] * (context_size - 1) + [blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state=init_states,
lm_score=init_score.reshape(-1),