Yifan Yang 2023-06-15 11:51:40 +08:00
parent aaec7c299f
commit 38e039881f


@@ -1490,16 +1490,17 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.name = None  # will be overwritten in training code; for diagnostics.
 
         key_head_dim = query_head_dim
-        in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
+        in_proj_dim_a = (query_head_dim + pos_head_dim) * num_heads
+        in_proj_dim_b = key_head_dim * num_heads
 
         # the initial_scale is supposed to take over the "scaling" factor of
         # head_dim ** -0.5 that has been used in previous forms of attention,
         # dividing it between the query and key.  Note: this module is intended
         # to be used with the ScaledAdam optimizer; with most other optimizers,
         # it would be necessary to apply the scaling factor in the forward function.
-        self.in_proj_a = ScaledLinear(embed_dim, in_proj_dim, bias=True,
+        self.in_proj_a = ScaledLinear(embed_dim, in_proj_dim_a, bias=True,
                                       initial_scale=query_head_dim**-0.25)
-        self.in_proj_b = ScaledLinear(embed_dim, in_proj_dim, bias=True,
+        self.in_proj_b = ScaledLinear(embed_dim, in_proj_dim_b, bias=True,
                                       initial_scale=query_head_dim**-0.25)
 
         self.whiten_keys = Whiten(num_groups=num_heads,
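
For context, an illustrative sketch of what this hunk changes: the single combined input projection is split so that in_proj_a emits only the query and position-query channels while in_proj_b emits only the key channels. The dimension values below are hypothetical, and plain nn.Linear stands in for ScaledLinear (which adds a learned initial scale); the dimension arithmetic itself follows the diff.

    import torch
    import torch.nn as nn

    # Hypothetical sizes; the real values come from the model config.
    embed_dim, num_heads = 384, 8
    query_head_dim, pos_head_dim = 32, 4
    key_head_dim = query_head_dim

    # Before: one projection of size
    #   (query_head_dim + key_head_dim + pos_head_dim) * num_heads
    # produced q, k and p together.  After this commit the output is split:
    in_proj_dim_a = (query_head_dim + pos_head_dim) * num_heads  # q and p
    in_proj_dim_b = key_head_dim * num_heads                     # k only

    in_proj_a = nn.Linear(embed_dim, in_proj_dim_a, bias=True)  # stand-in for ScaledLinear
    in_proj_b = nn.Linear(embed_dim, in_proj_dim_b, bias=True)

    src = torch.randn(100, 2, embed_dim)  # (seq_len, batch_size, embed_dim)
    x = in_proj_a(src)  # queries + position-encoding queries
    y = in_proj_b(src)  # keys
    assert x.shape[-1] + y.shape[-1] == (
        query_head_dim + key_head_dim + pos_head_dim) * num_heads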
@@ -1567,9 +1568,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         # self-attention
         q = x[...,0:query_dim]
-        k = y[...,query_dim:2*query_dim]
+        k = y
         # p is the position-encoding query
-        p = x[...,2*query_dim:]
+        p = x[...,query_dim:]
         assert p.shape[-1] == num_heads * pos_head_dim
 
         q = self.copy_query(q)  # for diagnostics only, does nothing.
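
The second hunk follows from the split above: x, the output of in_proj_a, now contains only queries followed by position-encoding queries, so p begins at query_dim instead of 2*query_dim, and y, the output of in_proj_b, is used as the keys without any slicing. A self-contained sketch with the same hypothetical sizes:

    import torch

    num_heads, query_head_dim, pos_head_dim = 8, 32, 4
    key_head_dim = query_head_dim
    query_dim = query_head_dim * num_heads

    # Stand-ins for the in_proj_a / in_proj_b outputs in the sketch above.
    x = torch.randn(100, 2, (query_head_dim + pos_head_dim) * num_heads)
    y = torch.randn(100, 2, key_head_dim * num_heads)

    q = x[..., 0:query_dim]  # queries
    p = x[..., query_dim:]   # position-encoding queries; starts at query_dim
                             # now that keys no longer live in x
    k = y                    # in_proj_b output is exactly the keys

    assert p.shape[-1] == num_heads * pos_head_dim
    assert k.shape[-1] == num_heads * key_head_dim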