Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)

Commit: d70f6e21f2
Parent: b9a8f107a2

updated code for decoding
@@ -825,13 +825,15 @@ def deprecated_greedy_search_batch_for_cross_attn(
         # current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
         # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
         current_encoder_out = model.joiner.label_level_am_attention(
-            encoder_out[:, : t + 1, :].unsqueeze(2),
+            encoder_out.unsqueeze(2),
             decoder_out.unsqueeze(2),
-            encoder_out_lens,
+            # encoder_out_lens,
+            None,
         )
         logits = model.joiner(
             current_encoder_out,
             decoder_out.unsqueeze(1),
+            None,
             apply_attn=False,
             project_input=False,
         )

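The updated call feeds the whole encoder_out (rather than a per-frame slice) into label_level_am_attention and passes None in place of encoder_out_lens. A minimal shape sketch of the unsqueeze calls above, with made-up sizes; decoder_out is assumed here to be (batch_size, 1, decoder_dim):

import torch

# Made-up sizes for illustration: batch_size = 2, T = 5 encoder frames, dim = 512.
encoder_out = torch.randn(2, 5, 512)
decoder_out = torch.randn(2, 1, 512)  # assumed (batch_size, 1, decoder_dim)

print(encoder_out.unsqueeze(2).shape)  # torch.Size([2, 5, 1, 512])
print(decoder_out.unsqueeze(2).shape)  # torch.Size([2, 1, 1, 512])
print(decoder_out.unsqueeze(1).shape)  # torch.Size([2, 1, 1, 512])
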
@@ -73,7 +73,7 @@ class CrossAttention(nn.Module):
             batch_size,
             lm_seq_len,
             am_seq_len,
-        ), f"{attn_weights.shape}"
+        ), f"{attn_weights.shape} {x.shape}"

         x = self.in_proj(x)  # (am_seq_len, batch_size, num_heads * value_head_dim)
         # print("projected x.shape", x.shape)
@@ -406,19 +406,23 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         if key_padding_mask is not None:
             # (batch, max_len)

-            key_padding_mask = (
-                key_padding_mask.unsqueeze(1)
-                .expand(
-                    key_padding_mask.shape[0],  # b
-                    self.prune_range,
-                    key_padding_mask.shape[1],  # l
-                )
-                .reshape(b_p_dim, am_seq_len)
-                .unsqueeze(1)
-                .unsqueeze(0)
-            )
+            if b_p_dim == key_padding_mask.shape[0] * self.prune_range:
+                key_padding_mask = (
+                    key_padding_mask.unsqueeze(1)
+                    .expand(
+                        key_padding_mask.shape[0],  # b
+                        self.prune_range,
+                        key_padding_mask.shape[1],  # l
+                    )
+                    .reshape(b_p_dim, am_seq_len)
+                    .unsqueeze(1)
+                    .unsqueeze(0)
+                )
+                # (1, b * p, 1, T)
+            else:
+                key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(0)
+                # (1, b, 1, T)
             # print(key_padding_mask.shape)
-            # (1, b * p, 1, T)

             attn_scores = attn_scores.masked_fill(
                 key_padding_mask,
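The new branch repeats each utterance's padding mask prune_range times so it lines up with a pruned (b * p) batch dimension, while the else branch keeps the plain per-utterance mask. A runnable sketch of that expansion with made-up sizes (b = 2, prune_range = 3, am_seq_len = 4):

import torch

b, prune_range, am_seq_len = 2, 3, 4  # made-up sizes for illustration
b_p_dim = b * prune_range

# True marks padded positions, one row per utterance.
key_padding_mask = torch.tensor(
    [[False, False, False, True],
     [False, False, True, True]]
)

expanded = (
    key_padding_mask.unsqueeze(1)        # (b, 1, T)
    .expand(b, prune_range, am_seq_len)  # (b, p, T): each row repeated p times
    .reshape(b_p_dim, am_seq_len)        # (b * p, T)
    .unsqueeze(1)                        # (b * p, 1, T)
    .unsqueeze(0)                        # (1, b * p, 1, T)
)
print(expanded.shape)  # torch.Size([1, 6, 1, 4])
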
@@ -492,7 +496,7 @@ class AlignmentAttentionModule(nn.Module):
     def forward(
         self, am_pruned: Tensor, lm_pruned: Tensor, lengths: torch.Tensor
     ) -> Tensor:
-        src_key_padding_mask = make_pad_mask(lengths)
+        src_key_padding_mask = make_pad_mask(lengths) if lengths is not None else None
         # (batch, max_len)

         if am_pruned.ndim == 4 and lm_pruned.ndim == 4:
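forward now tolerates lengths being None, matching the None that the decoding code above passes in. As a reference for what the mask looks like when lengths is given, here is an illustrative stand-in for make_pad_mask (not the icefall implementation), following the usual convention that True marks padded positions:

import torch


def make_pad_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in only: True marks positions beyond each sequence's length.
    max_len = int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)  # (batch, max_len)


lengths = torch.tensor([4, 2])  # made-up lengths
print(make_pad_mask_sketch(lengths))
# tensor([[False, False, False, False],
#         [False, False,  True,  True]])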