delete copy_query

This commit is contained in:
yaozengwei 2023-01-17 10:49:56 +08:00
parent 8a7c43f3f3
commit 24d6565126

View File

@ -438,7 +438,7 @@ class DecoderLayer(nn.Module):
)
def get_bypass_scale(self):
if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
return self.bypass_scale
if random.random() < 0.1:
# ensure we get grads if self.bypass_scale becomes out of range
@ -565,11 +565,6 @@ class MultiHeadedAttention(nn.Module):
grad_scale=0.025,
)
# the following are for diagnosics only, see --print-diagnostics option.
# they only copy their inputs.
self.copy_pos_query = Identity()
self.copy_query = Identity()
self.out_proj = ScaledLinear(
attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
)
@ -597,7 +592,6 @@ class MultiHeadedAttention(nn.Module):
k = self.linear_k(key)
v = self.linear_v(value)
q = self.copy_query(q) # for diagnostics only, does nothing.
k = self.whiten_k(k) # does nothing in the forward pass.
v = self.whiten_v(v) # does nothing in the forward pass.