mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
delete copy_query
This commit is contained in:
parent
8a7c43f3f3
commit
24d6565126
@ -438,7 +438,7 @@ class DecoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
def get_bypass_scale(self):
|
||||
if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
|
||||
if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return self.bypass_scale
|
||||
if random.random() < 0.1:
|
||||
# ensure we get grads if self.bypass_scale becomes out of range
|
||||
@ -565,11 +565,6 @@ class MultiHeadedAttention(nn.Module):
|
||||
grad_scale=0.025,
|
||||
)
|
||||
|
||||
# the following are for diagnosics only, see --print-diagnostics option.
|
||||
# they only copy their inputs.
|
||||
self.copy_pos_query = Identity()
|
||||
self.copy_query = Identity()
|
||||
|
||||
self.out_proj = ScaledLinear(
|
||||
attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
|
||||
)
|
||||
@ -597,7 +592,6 @@ class MultiHeadedAttention(nn.Module):
|
||||
k = self.linear_k(key)
|
||||
v = self.linear_v(value)
|
||||
|
||||
q = self.copy_query(q) # for diagnostics only, does nothing.
|
||||
k = self.whiten_k(k) # does nothing in the forward pass.
|
||||
v = self.whiten_v(v) # does nothing in the forward pass.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user