aligned with the beam search approach for algn exp

This commit is contained in:
JinZr 2023-07-28 17:25:28 +08:00
parent 18f9a1d319
commit 70d603dc28

View File

@ -804,7 +804,7 @@ def deprecated_greedy_search_batch(
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
hyps = [[blank_id] * context_size for _ in range(batch_size)]
hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(batch_size)]
decoder_input = torch.tensor(
hyps,
@ -909,7 +909,7 @@ def deprecated_greedy_search_batch_for_cross_attn(
decoder_out.unsqueeze(1),
attn_encoder_out if t > 0 else torch.zeros_like(current_encoder_out),
None,
apply_attn=False,
apply_attn=True,
project_input=False,
)
# logits'shape (batch_size, 1, 1, vocab_size)