Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-27 02:34:21 +00:00)
delete copy_query

commit 24d6565126 (parent 8a7c43f3f3)

This commit removes the copy_query and copy_pos_query diagnostic Identity
modules from MultiHeadedAttention, along with the copy_query call in its
forward pass, and reorders the short-circuit checks in
DecoderLayer.get_bypass_scale.
@@ -438,7 +438,7 @@ class DecoderLayer(nn.Module):
         )

     def get_bypass_scale(self):
-        if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
+        if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
             return self.bypass_scale
         if random.random() < 0.1:
             # ensure we get grads if self.bypass_scale becomes out of range
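For context on the hunk above: get_bypass_scale is a clamp-with-escape-hatch
pattern. In eval, scripting, or tracing mode the raw learned parameter is
returned; during training it is normally clamped into a safe range, but
roughly 10% of the time the unclamped parameter is returned so gradients can
still reach it once it drifts out of range (clamp zeroes the gradient for
out-of-range elements). The reordering is behavior-preserving, since all
three checks are side-effect-free booleans. Below is a minimal runnable
sketch of the pattern, not icefall's code; the clamp limits (0.1, 1.0), the
channel-wise parameter, and the BypassDemo wiring are assumptions.

import random

import torch
import torch.nn as nn


class BypassDemo(nn.Module):
    """Sketch of a learned residual-bypass scale (limits are assumed)."""

    def __init__(self, num_channels: int):
        super().__init__()
        self.bypass_scale = nn.Parameter(torch.full((num_channels,), 0.5))

    def get_bypass_scale(self) -> torch.Tensor:
        # Eval / scripted / traced paths return the raw parameter unchanged.
        if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
            return self.bypass_scale
        if random.random() < 0.1:
            # ensure we get grads if self.bypass_scale becomes out of range
            return self.bypass_scale
        # clamp() zeroes grads for out-of-range elements, hence the 10%
        # escape hatch above.
        return self.bypass_scale.clamp(min=0.1, max=1.0)

    def forward(self, src: torch.Tensor, src_out: torch.Tensor) -> torch.Tensor:
        # Interpolate between the layer input and the layer output.
        scale = self.get_bypass_scale()
        return src + scale * (src_out - src)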
@@ -565,11 +565,6 @@ class MultiHeadedAttention(nn.Module):
             grad_scale=0.025,
         )

-        # the following are for diagnosics only, see --print-diagnostics option.
-        # they only copy their inputs.
-        self.copy_pos_query = Identity()
-        self.copy_query = Identity()
-
         self.out_proj = ScaledLinear(
             attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )
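The copy_query and copy_pos_query modules deleted above follow a common
diagnostics pattern: an Identity module is a no-op in the forward pass, but
it provides a named module boundary that diagnostic tooling (such as the
--print-diagnostics mode the deleted comment refers to) can attach forward
hooks to. A rough, self-contained sketch of the idea using standard PyTorch
hooks; the Identity class and attach_stats_hooks helper are illustrative
names, not icefall's API:

import torch
import torch.nn as nn


class Identity(nn.Module):
    """No-op module; exists only so hooks can observe tensors passing through."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


def attach_stats_hooks(model: nn.Module) -> None:
    """Print simple activation statistics at every Identity boundary."""
    for name, module in model.named_modules():
        if isinstance(module, Identity):
            # Bind `name` via a default argument so each hook keeps its own.
            module.register_forward_hook(
                lambda mod, inp, out, name=name: print(
                    f"{name}: mean={out.mean():.4f} std={out.std():.4f}"
                )
            )

Because Identity never changes values, deleting copy_query (and its call in
forward, below) removes only an observation point; the computation is
unaffected.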
@@ -597,7 +592,6 @@ class MultiHeadedAttention(nn.Module):
         k = self.linear_k(key)
         v = self.linear_v(value)

-        q = self.copy_query(q)  # for diagnostics only, does nothing.
         k = self.whiten_k(k)  # does nothing in the forward pass.
         v = self.whiten_v(v)  # does nothing in the forward pass.

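The surviving whiten_k / whiten_v calls use a related trick: a module that is
exactly the identity in the forward pass but modifies gradients in the
backward pass. The sketch below shows the general mechanism with a custom
autograd function; the gradient rule here (rescaling large gradients) is
invented for illustration and is not what icefall's Whiten module computes:

import torch


class GradClipIdentity(torch.autograd.Function):
    """Identity in forward; in backward, rescales unusually large gradients."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return x  # "does nothing in the forward pass"

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
        norm = grad_out.norm()
        # Hypothetical rule: shrink gradients whose norm exceeds 10.
        if norm > 10.0:
            grad_out = grad_out * (10.0 / norm)
        return grad_out


# Usage: numerically a no-op in forward, but the backward pass is shaped.
# k = GradClipIdentity.apply(k)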