fixes on feat dim

zr_jin 2023-07-23 20:19:19 +08:00
parent 3bd2e8e6cc
commit 7e5c7e6f77


@@ -117,12 +117,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
     ) -> Tensor:
         r"""
         Args:
-            lm_pruned: input of shape (batch_size * prune_range, seq_len, decoder_embed_dim)
-            am_pruned: input of shape (batch_size * prune_range, seq_len, encoder_embed_dim)
-            pos_emb: Positional embedding tensor, of shape (1, 2 * batch_size * prune_range - 1, pos_dim)
-            key_padding_mask: a bool tensor of shape (seq_len, batch_size * prune_range). Positions that
-                are True in this mask will be ignored as sources in the attention weighting.
-            attn_mask: mask of shape (batch_size * prune_range, batch_size * prune_range)
+            lm_pruned: input of shape (seq_len, batch_size * prune_range, decoder_embed_dim)
+            am_pruned: input of shape (seq_len, batch_size * prune_range, encoder_embed_dim)
+            pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
+            key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions
+                that are True in this mask will be ignored as sources in the attention weighting.
+            attn_mask: mask of shape (seq_len, seq_len)
                 or (seq_len, batch_size * prune_range, batch_size * prune_range),
                 interpreted as ([seq_len,] batch_size * prune_range, batch_size * prune_range)
                 saying which positions are allowed to attend to which other positions.
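As a quick illustration (not part of the commit itself), here is a minimal sketch of dummy inputs that match the updated docstring shapes; all sizes below are arbitrary assumptions.

import torch

# Arbitrary example sizes (assumptions, not taken from the diff).
seq_len, batch_size, prune_range = 10, 4, 5
decoder_embed_dim, encoder_embed_dim, pos_dim = 512, 512, 192

# Inputs now carry the time axis first, with batch and prune_range merged.
lm_pruned = torch.randn(seq_len, batch_size * prune_range, decoder_embed_dim)
am_pruned = torch.randn(seq_len, batch_size * prune_range, encoder_embed_dim)

# Relative positional embedding spans 2*seq_len - 1 offsets.
pos_emb = torch.randn(1, 2 * seq_len - 1, pos_dim)

# Padding mask is (batch_size * prune_range, seq_len) after this commit.
key_padding_mask = torch.zeros(batch_size * prune_range, seq_len, dtype=torch.bool)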
@@ -137,36 +137,36 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         num_heads = self.num_heads
         (
-            b_p_dim,
             seq_len,
+            b_p_dim,
             _,
-        ) = lm_pruned.shape  # actual dim: (batch * prune_range, seq_len, _)
+        ) = lm_pruned.shape  # actual dim: (seq_len, batch * prune_range, _)
         query_dim = query_head_dim * num_heads
         # self-attention
-        q = lm_pruned[..., 0:query_dim]  # (batch * prune_range, seq_len, query_dim)
-        k = am_pruned  # (batch * prune_range, seq_len, query_dim)
+        q = lm_pruned[..., 0:query_dim]  # (seq_len, batch * prune_range, query_dim)
+        k = am_pruned  # (seq_len, batch * prune_range, query_dim)
         # p is the position-encoding query
         p = lm_pruned[
             ..., query_dim:
-        ]  # (batch * prune_range, seq_len, pos_head_dim * num_heads)
+        ]  # (seq_len, batch * prune_range, pos_head_dim * num_heads)
         assert p.shape[-1] == num_heads * pos_head_dim
         q = self.copy_query(q)  # for diagnostics only, does nothing.
         k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.
         p = self.copy_pos_query(p)  # for diagnostics only, does nothing.
-        q = q.reshape(b_p_dim, seq_len, num_heads, query_head_dim)
-        p = p.reshape(b_p_dim, seq_len, num_heads, pos_head_dim)
-        k = k.reshape(b_p_dim, seq_len, num_heads, query_head_dim)
+        q = q.reshape(seq_len, b_p_dim, num_heads, query_head_dim)
+        p = p.reshape(seq_len, b_p_dim, num_heads, pos_head_dim)
+        k = k.reshape(seq_len, b_p_dim, num_heads, query_head_dim)
         # time1 refers to target, time2 refers to source.
         q = q.permute(
             2, 1, 0, 3
-        )  # (head, seq_len, batch * prune_range, query_head_dim)
-        p = p.permute(2, 1, 0, 3)  # (head, seq_len, batch * prune_range, pos_head_dim)
-        k = k.permute(2, 1, 3, 0)  # (head, seq_len, d_k, batch * prune_range)
+        )  # (head, batch * prune_range, seq_len, query_head_dim)
+        p = p.permute(2, 1, 0, 3)  # (head, batch * prune_range, seq_len, pos_head_dim)
+        k = k.permute(2, 1, 3, 0)  # (head, batch * prune_range, d_k, seq_len)
         attn_scores = torch.matmul(q, k)
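A small shape-tracing sketch (assumed toy sizes, not part of the commit) showing how the new permutes line up for the q/k matmul:

import torch

# Toy sizes (assumptions) to trace the permutes in the hunk above.
seq_len, b_p_dim, num_heads, query_head_dim = 10, 20, 4, 32

q = torch.randn(seq_len, b_p_dim, num_heads, query_head_dim)
k = torch.randn(seq_len, b_p_dim, num_heads, query_head_dim)

# (head, batch * prune_range, seq_len, query_head_dim)
q = q.permute(2, 1, 0, 3)
# (head, batch * prune_range, query_head_dim, seq_len)
k = k.permute(2, 1, 3, 0)

attn_scores = torch.matmul(q, k)
assert attn_scores.shape == (num_heads, b_p_dim, seq_len, seq_len)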
@@ -179,12 +179,14 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         if use_pos_scores:
             pos_emb = self.linear_pos(pos_emb)
-            seq_len2 = 2 * b_p_dim - 1
+            print("pos_emb before proj", pos_emb.shape)
+            seq_len2 = 2 * seq_len - 1
             pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
                 2, 0, 3, 1
             )
             # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
+            print("p", p.shape)
+            print("pos_emb after proj", pos_emb.shape)
             # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
             # [where seq_len2 represents relative position.]
@@ -193,24 +195,24 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             # to absolute position. I don't know whether I might have got the time-offsets backwards or
             # not, but let this code define which way round it is supposed to be.
             if torch.jit.is_tracing():
-                (num_heads, seq_len, time1, n) = pos_scores.shape
+                (num_heads, b_p_dim, time1, n) = pos_scores.shape
                 rows = torch.arange(start=time1 - 1, end=-1, step=-1)
-                cols = torch.arange(b_p_dim)
-                rows = rows.repeat(seq_len * num_heads).unsqueeze(-1)
+                cols = torch.arange(seq_len)
+                rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
                 indexes = rows + cols
                 pos_scores = pos_scores.reshape(-1, n)
                 pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
-                pos_scores = pos_scores.reshape(num_heads, seq_len, time1, b_p_dim)
+                pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, seq_len)
             else:
                 pos_scores = pos_scores.as_strided(
-                    (num_heads, seq_len, b_p_dim, b_p_dim),
+                    (num_heads, b_p_dim, seq_len, seq_len),
                     (
                         pos_scores.stride(0),
                         pos_scores.stride(1),
                         pos_scores.stride(2) - pos_scores.stride(3),
                         pos_scores.stride(3),
                     ),
-                    storage_offset=pos_scores.stride(3) * (b_p_dim - 1),
+                    storage_offset=pos_scores.stride(3) * (seq_len - 1),
                 )
             attn_scores = attn_scores + pos_scores
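To see that the relative-to-absolute indexing still works with the swapped dimensions, here is a standalone sketch of the gather branch with assumed toy sizes (the real code only takes this path under torch.jit.is_tracing()):

import torch

# Toy sizes (assumptions) to check the relative->absolute gather above.
num_heads, b_p_dim, seq_len = 2, 3, 5
time1, n = seq_len, 2 * seq_len - 1  # n indexes relative offsets

pos_scores = torch.randn(num_heads, b_p_dim, time1, n)

# Mirrors the tracing branch after this commit.
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
indexes = rows + cols  # (num_heads * b_p_dim * time1, seq_len)

out = torch.gather(pos_scores.reshape(-1, n), dim=1, index=indexes)
out = out.reshape(num_heads, b_p_dim, time1, seq_len)
assert out.shape == (num_heads, b_p_dim, time1, seq_len)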
@@ -234,7 +236,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
                 attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
             )
-        assert attn_scores.shape == (num_heads, seq_len, b_p_dim, b_p_dim)
+        assert attn_scores.shape == (num_heads, b_p_dim, seq_len, seq_len)
         if attn_mask is not None:
             assert attn_mask.dtype == torch.bool
@@ -246,8 +248,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         if key_padding_mask is not None:
             assert key_padding_mask.shape == (
-                seq_len,
                 b_p_dim,
+                seq_len,
             ), key_padding_mask.shape
             attn_scores = attn_scores.masked_fill(
                 key_padding_mask.unsqueeze(1),
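For reference, a hedged sketch (assumed sizes, and an assumed -inf fill value rather than whatever the surrounding code actually uses) of how the (batch_size * prune_range, seq_len) padding mask broadcasts after unsqueeze(1):

import torch

# Toy sizes (assumptions) to check the padding-mask broadcast.
num_heads, b_p_dim, seq_len = 2, 3, 5

attn_scores = torch.randn(num_heads, b_p_dim, seq_len, seq_len)
key_padding_mask = torch.zeros(b_p_dim, seq_len, dtype=torch.bool)
key_padding_mask[:, -1] = True  # pretend the last frame is padding

# (b_p_dim, 1, seq_len) broadcasts over the head and query (time1) axes.
masked = attn_scores.masked_fill(key_padding_mask.unsqueeze(1), float("-inf"))
assert masked.shape == attn_scores.shape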
@@ -323,20 +325,24 @@ class AlignmentAttentionModule(nn.Module):
         (batch_size, T, prune_range, encoder_dim) = am_pruned.shape
         (batch_size, T, prune_range, decoder_dim) = lm_pruned.shape
-        # am_pruned : [B * prune_range, T, encoder_dim]
-        # lm_pruned : [B * prune_range, T, decoder_dim]
-        am_pruned = am_pruned.transpose(1, 0).reshape(
-            batch_size * prune_range, T, encoder_dim
+        # am_pruned : [T, B * prune_range, encoder_dim]
+        # lm_pruned : [T, B * prune_range, decoder_dim]
+        merged_am_pruned = am_pruned.permute(1, 0, 2, 3).reshape(
+            T, batch_size * prune_range, encoder_dim
         )
-        lm_pruned = lm_pruned.transpose(1, 0).reshape(
-            batch_size * prune_range, T, decoder_dim
+        merged_lm_pruned = lm_pruned.permute(1, 0, 2, 3).reshape(
+            T, batch_size * prune_range, decoder_dim
         )
-        pos_emb = self.pos_encode(am_pruned)
-        attn_weights = self.cross_attn_weights(lm_pruned, am_pruned, pos_emb)
-        label_level_am_representation = self.cross_attn(am_pruned, attn_weights)
-        return label_level_am_representation.reshape(batch_size, T, prune_range, encoder_dim)
+        pos_emb = self.pos_encode(merged_am_pruned)
+        attn_weights = self.cross_attn_weights(merged_lm_pruned, merged_am_pruned, pos_emb)
+        label_level_am_representation = self.cross_attn(merged_am_pruned, attn_weights)
+        # (T, batch_size * prune_range, encoder_dim)
+        return label_level_am_representation \
+            .reshape(T, batch_size, prune_range, encoder_dim) \
+            .permute(1, 0, 2, 3)

 if __name__ == "__main__":
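Finally, a minimal round-trip sketch (toy sizes are assumptions) of the merge/unmerge that AlignmentAttentionModule.forward now performs around the cross-attention:

import torch

# Toy sizes (assumptions) to check the merge/unmerge round trip.
batch_size, T, prune_range, encoder_dim = 2, 7, 5, 16

am_pruned = torch.randn(batch_size, T, prune_range, encoder_dim)

# (B, T, prune_range, dim) -> (T, B * prune_range, dim)
merged = am_pruned.permute(1, 0, 2, 3).reshape(T, batch_size * prune_range, encoder_dim)

# ... cross-attention would run on `merged` here ...

# (T, B * prune_range, dim) -> (B, T, prune_range, dim)
restored = merged.reshape(T, batch_size, prune_range, encoder_dim).permute(1, 0, 2, 3)
assert torch.equal(restored, am_pruned)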