delete copy_query

This commit is contained in:
yaozengwei 2023-01-17 10:49:56 +08:00
parent 8a7c43f3f3
commit 24d6565126

View File

@ -438,7 +438,7 @@ class DecoderLayer(nn.Module):
) )
def get_bypass_scale(self): def get_bypass_scale(self):
if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
return self.bypass_scale return self.bypass_scale
if random.random() < 0.1: if random.random() < 0.1:
# ensure we get grads if self.bypass_scale becomes out of range # ensure we get grads if self.bypass_scale becomes out of range
@ -565,11 +565,6 @@ class MultiHeadedAttention(nn.Module):
grad_scale=0.025, grad_scale=0.025,
) )
# the following are for diagnostics only, see --print-diagnostics option.
# they only copy their inputs.
self.copy_pos_query = Identity()
self.copy_query = Identity()
self.out_proj = ScaledLinear( self.out_proj = ScaledLinear(
attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
) )
@ -597,7 +592,6 @@ class MultiHeadedAttention(nn.Module):
k = self.linear_k(key) k = self.linear_k(key)
v = self.linear_v(value) v = self.linear_v(value)
q = self.copy_query(q) # for diagnostics only, does nothing.
k = self.whiten_k(k) # does nothing in the forward pass. k = self.whiten_k(k) # does nothing in the forward pass.
v = self.whiten_v(v) # does nothing in the forward pass. v = self.whiten_v(v) # does nothing in the forward pass.