mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
a runnable version of decoding scripts
This commit is contained in:
parent
634931cb61
commit
18f9a1d319
@ -823,10 +823,10 @@ def deprecated_greedy_search_batch(
|
|||||||
logits = model.joiner(
|
logits = model.joiner(
|
||||||
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
|
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
|
||||||
)
|
)
|
||||||
print(current_encoder_out)
|
# print(current_encoder_out)
|
||||||
print(decoder_out.unsqueeze(1))
|
# print(decoder_out.unsqueeze(1))
|
||||||
print(logits)
|
# print(logits)
|
||||||
exit()
|
# exit()
|
||||||
# logits'shape (batch_size, 1, 1, vocab_size)
|
# logits'shape (batch_size, 1, 1, vocab_size)
|
||||||
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
|
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
|
||||||
assert logits.ndim == 2, logits.shape
|
assert logits.ndim == 2, logits.shape
|
||||||
@ -878,7 +878,7 @@ def deprecated_greedy_search_batch_for_cross_attn(
|
|||||||
unk_id = getattr(model, "unk_id", blank_id)
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
context_size = model.decoder.context_size
|
context_size = model.decoder.context_size
|
||||||
|
|
||||||
hyps = [[blank_id] * context_size for _ in range(batch_size)]
|
hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(batch_size)]
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
hyps,
|
hyps,
|
||||||
@ -893,6 +893,7 @@ def deprecated_greedy_search_batch_for_cross_attn(
|
|||||||
# encoder_out_for_attn = encoder_out.unsqueeze(2)
|
# encoder_out_for_attn = encoder_out.unsqueeze(2)
|
||||||
|
|
||||||
# decoder_out: (batch_size, 1, decoder_out_dim)
|
# decoder_out: (batch_size, 1, decoder_out_dim)
|
||||||
|
# emitted = False
|
||||||
for t in range(T):
|
for t in range(T):
|
||||||
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
||||||
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
||||||
@ -906,7 +907,7 @@ def deprecated_greedy_search_batch_for_cross_attn(
|
|||||||
logits = model.joiner(
|
logits = model.joiner(
|
||||||
current_encoder_out,
|
current_encoder_out,
|
||||||
decoder_out.unsqueeze(1),
|
decoder_out.unsqueeze(1),
|
||||||
torch.zeros_like(current_encoder_out),
|
attn_encoder_out if t > 0 else torch.zeros_like(current_encoder_out),
|
||||||
None,
|
None,
|
||||||
apply_attn=False,
|
apply_attn=False,
|
||||||
project_input=False,
|
project_input=False,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user