mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge branch 'zlm51' into zlm52
This commit is contained in:
commit
09294c0b51
@ -1754,7 +1754,11 @@ class NonlinAttention(nn.Module):
|
|||||||
|
|
||||||
self.identity1 = Identity() # for diagnostics.
|
self.identity1 = Identity() # for diagnostics.
|
||||||
self.identity2 = Identity() # for diagnostics.
|
self.identity2 = Identity() # for diagnostics.
|
||||||
self.identity3 = Identity() # for diagnostics.
|
|
||||||
|
|
||||||
|
# ensure the activations after multiplication don't get too large.
|
||||||
|
self.hidden_penalty = AbsValuePenalizer(
|
||||||
|
limit=10.0, penalty=1.0e-04, prob=0.1)
|
||||||
|
|
||||||
self.out_proj = ScaledLinear(hidden_channels, channels,
|
self.out_proj = ScaledLinear(hidden_channels, channels,
|
||||||
bias=True,
|
bias=True,
|
||||||
@ -1815,7 +1819,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
|
|||||||
|
|
||||||
y = self.identity2(y)
|
y = self.identity2(y)
|
||||||
x = x * y
|
x = x * y
|
||||||
x = self.identity3(x)
|
x = self.hidden_penalty(x)
|
||||||
|
|
||||||
x = self.out_proj(x)
|
x = self.out_proj(x)
|
||||||
x = self.whiten2(x)
|
x = self.whiten2(x)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user