Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
fixes on feat dim

commit 7e5c7e6f77 (parent 3bd2e8e6cc)
@@ -117,12 +117,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
     ) -> Tensor:
         r"""
         Args:
-            lm_pruned: input of shape (batch_size * prune_range, seq_len, decoder_embed_dim)
-            am_pruned: input of shape (batch_size * prune_range, seq_len, encoder_embed_dim)
-            pos_emb: Positional embedding tensor, of shape (1, 2 * batch_size * prune_range - 1, pos_dim)
-            key_padding_mask: a bool tensor of shape (seq_len, batch_size * prune_range).  Positions that
-               are True in this mask will be ignored as sources in the attention weighting.
-            attn_mask: mask of shape (batch_size * prune_range, batch_size * prune_range)
+            lm_pruned: input of shape (seq_len, batch_size * prune_range, decoder_embed_dim)
+            am_pruned: input of shape (seq_len, batch_size * prune_range, encoder_embed_dim)
+            pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
+            key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len).  Positions
+               that are True in this mask will be ignored as sources in the attention weighting.
+            attn_mask: mask of shape (seq_len, seq_len)
                or (seq_len, batch_size * prune_range, batch_size * prune_range),
                interpreted as ([seq_len,] batch_size * prune_range, batch_size * prune_range)
                saying which positions are allowed to attend to which other positions.
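The docstring now puts seq_len first and the merged batch_size * prune_range axis second, and ties the positional embedding length to seq_len rather than to the merged batch axis. A minimal sketch of tensors matching the documented shapes, with all dimension sizes invented purely for illustration:

import torch

# Illustrative sizes only; none of these values come from the commit.
seq_len, batch_size, prune_range = 10, 4, 5
decoder_embed_dim, encoder_embed_dim, pos_dim = 512, 512, 192

# New convention: time axis first, merged batch * prune_range axis second.
lm_pruned = torch.randn(seq_len, batch_size * prune_range, decoder_embed_dim)
am_pruned = torch.randn(seq_len, batch_size * prune_range, encoder_embed_dim)

# The relative positional embedding spans 2*seq_len - 1 offsets.
pos_emb = torch.randn(1, 2 * seq_len - 1, pos_dim)

# key_padding_mask flags padded source frames per merged batch entry.
key_padding_mask = torch.zeros(batch_size * prune_range, seq_len, dtype=torch.bool)

assert lm_pruned.shape[:2] == (seq_len, batch_size * prune_range)
assert am_pruned.shape[:2] == (seq_len, batch_size * prune_range)
assert pos_emb.shape[1] == 2 * seq_len - 1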
@@ -137,36 +137,36 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         num_heads = self.num_heads

         (
-            b_p_dim,
             seq_len,
+            b_p_dim,
             _,
-        ) = lm_pruned.shape  # actual dim: (batch * prune_range, seq_len, _)
+        ) = lm_pruned.shape  # actual dim: (seq_len, batch * prune_range, _)

         query_dim = query_head_dim * num_heads

         # self-attention
-        q = lm_pruned[..., 0:query_dim]  # (batch * prune_range, seq_len, query_dim)
-        k = am_pruned  # (batch * prune_range, seq_len, query_dim)
+        q = lm_pruned[..., 0:query_dim]  # (seq_len, batch * prune_range, query_dim)
+        k = am_pruned  # (seq_len, batch * prune_range, query_dim)
         # p is the position-encoding query
         p = lm_pruned[
             ..., query_dim:
-        ]  # (batch * prune_range, seq_len, pos_head_dim * num_heads)
+        ]  # (seq_len, batch * prune_range, pos_head_dim * num_heads)
         assert p.shape[-1] == num_heads * pos_head_dim

         q = self.copy_query(q)  # for diagnostics only, does nothing.
         k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.
         p = self.copy_pos_query(p)  # for diagnostics only, does nothing.

-        q = q.reshape(b_p_dim, seq_len, num_heads, query_head_dim)
-        p = p.reshape(b_p_dim, seq_len, num_heads, pos_head_dim)
-        k = k.reshape(b_p_dim, seq_len, num_heads, query_head_dim)
+        q = q.reshape(seq_len, b_p_dim, num_heads, query_head_dim)
+        p = p.reshape(seq_len, b_p_dim, num_heads, pos_head_dim)
+        k = k.reshape(seq_len, b_p_dim, num_heads, query_head_dim)

         # time1 refers to target, time2 refers to source.
         q = q.permute(
             2, 1, 0, 3
-        )  # (head, seq_len, batch * prune_range, query_head_dim)
-        p = p.permute(2, 1, 0, 3)  # (head, seq_len, batch * prune_range, pos_head_dim)
-        k = k.permute(2, 1, 3, 0)  # (head, seq_len, d_k, batch * prune_range)
+        )  # (head, batch * prune_range, seq_len, query_head_dim)
+        p = p.permute(2, 1, 0, 3)  # (head, batch * prune_range, seq_len, pos_head_dim)
+        k = k.permute(2, 1, 3, 0)  # (head, batch * prune_range, d_k, seq_len)

         attn_scores = torch.matmul(q, k)

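With the merged batch * prune_range axis moved to position 1 before the permutes, the plain matmul of q against k already yields one (seq_len, seq_len) score matrix per head and per merged batch entry. A standalone shape check, using made-up head counts and dims rather than the module's real hyperparameters:

import torch

seq_len, b_p_dim = 10, 20                      # b_p_dim = batch_size * prune_range
num_heads, query_head_dim, pos_head_dim = 4, 32, 4

query_dim = query_head_dim * num_heads
lm_pruned = torch.randn(seq_len, b_p_dim, query_dim + pos_head_dim * num_heads)
am_pruned = torch.randn(seq_len, b_p_dim, query_dim)

q = lm_pruned[..., 0:query_dim].reshape(seq_len, b_p_dim, num_heads, query_head_dim)
p = lm_pruned[..., query_dim:].reshape(seq_len, b_p_dim, num_heads, pos_head_dim)
k = am_pruned.reshape(seq_len, b_p_dim, num_heads, query_head_dim)

q = q.permute(2, 1, 0, 3)  # (head, b_p_dim, seq_len, query_head_dim)
p = p.permute(2, 1, 0, 3)  # (head, b_p_dim, seq_len, pos_head_dim)
k = k.permute(2, 1, 3, 0)  # (head, b_p_dim, query_head_dim, seq_len)

attn_scores = torch.matmul(q, k)
assert attn_scores.shape == (num_heads, b_p_dim, seq_len, seq_len)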
@@ -179,12 +179,14 @@ class RelPositionMultiheadAttentionWeights(nn.Module):

         if use_pos_scores:
             pos_emb = self.linear_pos(pos_emb)
-            seq_len2 = 2 * b_p_dim - 1
+            print("pos_emb before proj", pos_emb.shape)
+            seq_len2 = 2 * seq_len - 1
             pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
                 2, 0, 3, 1
             )
             # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
+            print("p", p.shape)
+            print("pos_emb after proj", pos_emb.shape)

             # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
             # [where seq_len2 represents relative position.]
@@ -193,24 +195,24 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             # to absolute position.  I don't know whether I might have got the time-offsets backwards or
             # not, but let this code define which way round it is supposed to be.
             if torch.jit.is_tracing():
-                (num_heads, seq_len, time1, n) = pos_scores.shape
+                (num_heads, b_p_dim, time1, n) = pos_scores.shape
                 rows = torch.arange(start=time1 - 1, end=-1, step=-1)
-                cols = torch.arange(b_p_dim)
-                rows = rows.repeat(seq_len * num_heads).unsqueeze(-1)
+                cols = torch.arange(seq_len)
+                rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
                 indexes = rows + cols
                 pos_scores = pos_scores.reshape(-1, n)
                 pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
-                pos_scores = pos_scores.reshape(num_heads, seq_len, time1, b_p_dim)
+                pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, seq_len)
             else:
                 pos_scores = pos_scores.as_strided(
-                    (num_heads, seq_len, b_p_dim, b_p_dim),
+                    (num_heads, b_p_dim, seq_len, seq_len),
                     (
                         pos_scores.stride(0),
                         pos_scores.stride(1),
                         pos_scores.stride(2) - pos_scores.stride(3),
                         pos_scores.stride(3),
                     ),
-                    storage_offset=pos_scores.stride(3) * (b_p_dim - 1),
+                    storage_offset=pos_scores.stride(3) * (seq_len - 1),
                 )

             attn_scores = attn_scores + pos_scores
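The as_strided branch re-indexes scores from relative offset to absolute source position; after this commit it slides along seq_len instead of b_p_dim, matching the pos_scores layout (head, b_p_dim, time1, 2*seq_len - 1). A self-contained check of that shift against a naive slicing reference, with small made-up sizes:

import torch

num_heads, b_p_dim, seq_len = 2, 6, 5
pos_scores = torch.randn(num_heads, b_p_dim, seq_len, 2 * seq_len - 1)

# Strided view: out[h, b, i, j] == pos_scores[h, b, i, (j - i) + (seq_len - 1)],
# i.e. relative offset (j - i) re-indexed to absolute source position j.
shifted = pos_scores.as_strided(
    (num_heads, b_p_dim, seq_len, seq_len),
    (
        pos_scores.stride(0),
        pos_scores.stride(1),
        pos_scores.stride(2) - pos_scores.stride(3),
        pos_scores.stride(3),
    ),
    storage_offset=pos_scores.stride(3) * (seq_len - 1),
)

# Naive reference: row i keeps the window of offsets covering sources 0..seq_len-1.
ref = torch.stack(
    [pos_scores[:, :, i, seq_len - 1 - i : 2 * seq_len - 1 - i] for i in range(seq_len)],
    dim=2,
)
assert torch.equal(shifted, ref)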
@@ -234,7 +236,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
                 attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
             )

-        assert attn_scores.shape == (num_heads, seq_len, b_p_dim, b_p_dim)
+        assert attn_scores.shape == (num_heads, b_p_dim, seq_len, seq_len)

         if attn_mask is not None:
             assert attn_mask.dtype == torch.bool
@@ -246,8 +248,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):

         if key_padding_mask is not None:
             assert key_padding_mask.shape == (
-                seq_len,
                 b_p_dim,
+                seq_len,
             ), key_padding_mask.shape
             attn_scores = attn_scores.masked_fill(
                 key_padding_mask.unsqueeze(1),
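Because attn_scores is now (num_heads, b_p_dim, seq_len, seq_len) and key_padding_mask is (b_p_dim, seq_len), the unsqueeze(1) lets the mask broadcast over both the head axis and the query axis. A small broadcast check; the fill value here is illustrative, since the hunk is cut off before the value the module actually passes to masked_fill:

import torch

num_heads, b_p_dim, seq_len = 4, 20, 10
attn_scores = torch.randn(num_heads, b_p_dim, seq_len, seq_len)

key_padding_mask = torch.zeros(b_p_dim, seq_len, dtype=torch.bool)
key_padding_mask[:, -2:] = True  # pretend the last two source frames are padding

# (b_p_dim, 1, seq_len) broadcasts against (num_heads, b_p_dim, seq_len, seq_len):
# every head and every query position is masked for the padded source frames.
masked = attn_scores.masked_fill(key_padding_mask.unsqueeze(1), float("-inf"))
assert torch.isinf(masked[..., -2:]).all()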
@@ -323,20 +325,24 @@ class AlignmentAttentionModule(nn.Module):
         (batch_size, T, prune_range, encoder_dim) = am_pruned.shape
         (batch_size, T, prune_range, decoder_dim) = lm_pruned.shape

-        # am_pruned : [B * prune_range, T, encoder_dim]
-        # lm_pruned : [B * prune_range, T, decoder_dim]
-        am_pruned = am_pruned.transpose(1, 0).reshape(
-            batch_size * prune_range, T, encoder_dim
+        # am_pruned : [T, B * prune_range, encoder_dim]
+        # lm_pruned : [T, B * prune_range, decoder_dim]
+        merged_am_pruned = am_pruned.permute(1, 0, 2, 3).reshape(
+            T, batch_size * prune_range, encoder_dim
         )
-        lm_pruned = lm_pruned.transpose(1, 0).reshape(
-            batch_size * prune_range, T, decoder_dim
+        merged_lm_pruned = lm_pruned.permute(1, 0, 2, 3).reshape(
+            T, batch_size * prune_range, decoder_dim
         )

-        pos_emb = self.pos_encode(am_pruned)
+        pos_emb = self.pos_encode(merged_am_pruned)

-        attn_weights = self.cross_attn_weights(lm_pruned, am_pruned, pos_emb)
-        label_level_am_representation = self.cross_attn(am_pruned, attn_weights)
-        return label_level_am_representation.reshape(batch_size, T, prune_range, encoder_dim)
+        attn_weights = self.cross_attn_weights(merged_lm_pruned, merged_am_pruned, pos_emb)
+        label_level_am_representation = self.cross_attn(merged_am_pruned, attn_weights)
+        # (T, batch_size * prune_range, encoder_dim)
+
+        return label_level_am_representation \
+            .reshape(T, batch_size, prune_range, encoder_dim) \
+            .permute(1, 0, 2, 3)


 if __name__ == "__main__":
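AlignmentAttentionModule now flattens (batch_size, T, prune_range, dim) into a time-major (T, batch_size * prune_range, dim) layout before the cross-attention and restores the original layout on the way out. A round-trip sketch of just that layout change, with the attention step omitted and sizes invented for illustration:

import torch

batch_size, T, prune_range, encoder_dim = 4, 25, 5, 512
am_pruned = torch.randn(batch_size, T, prune_range, encoder_dim)

# Merge: time first, then the flattened (batch, prune_range) pairs.
merged_am_pruned = am_pruned.permute(1, 0, 2, 3).reshape(
    T, batch_size * prune_range, encoder_dim
)

# ... cross_attn_weights / cross_attn would operate on the merged layout here ...

# Unmerge: invert the reshape and the permute to get back (batch, T, prune_range, dim).
restored = merged_am_pruned.reshape(T, batch_size, prune_range, encoder_dim).permute(
    1, 0, 2, 3
)
assert torch.equal(restored, am_pruned)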