a runnable version of decoding scripts

This commit is contained in:
JinZr 2023-07-28 09:24:59 +08:00
parent 634931cb61
commit 18f9a1d319

View File

@ -823,10 +823,10 @@ def deprecated_greedy_search_batch(
logits = model.joiner( logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1), project_input=False current_encoder_out, decoder_out.unsqueeze(1), project_input=False
) )
print(current_encoder_out) # print(current_encoder_out)
print(decoder_out.unsqueeze(1)) # print(decoder_out.unsqueeze(1))
print(logits) # print(logits)
exit() # exit()
# logits'shape (batch_size, 1, 1, vocab_size) # logits'shape (batch_size, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape assert logits.ndim == 2, logits.shape
@ -878,7 +878,7 @@ def deprecated_greedy_search_batch_for_cross_attn(
unk_id = getattr(model, "unk_id", blank_id) unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size context_size = model.decoder.context_size
hyps = [[blank_id] * context_size for _ in range(batch_size)] hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(batch_size)]
decoder_input = torch.tensor( decoder_input = torch.tensor(
hyps, hyps,
@ -893,6 +893,7 @@ def deprecated_greedy_search_batch_for_cross_attn(
# encoder_out_for_attn = encoder_out.unsqueeze(2) # encoder_out_for_attn = encoder_out.unsqueeze(2)
# decoder_out: (batch_size, 1, decoder_out_dim) # decoder_out: (batch_size, 1, decoder_out_dim)
# emitted = False
for t in range(T): for t in range(T):
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
@ -906,7 +907,7 @@ def deprecated_greedy_search_batch_for_cross_attn(
logits = model.joiner( logits = model.joiner(
current_encoder_out, current_encoder_out,
decoder_out.unsqueeze(1), decoder_out.unsqueeze(1),
torch.zeros_like(current_encoder_out), attn_encoder_out if t > 0 else torch.zeros_like(current_encoder_out),
None, None,
apply_attn=False, apply_attn=False,
project_input=False, project_input=False,