Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Add another whitening module, move balancer to output.
commit 5f5d02ed0c
parent 8859177bfa
@@ -1496,7 +1496,12 @@ class NonlinAttentionModule(nn.Module):
                                  min_abs=0.01,
         )
 
-        self.whiten = Whiten(num_groups=1,
+        self.whiten1 = Whiten(num_groups=1,
+                              whitening_limit=_whitening_schedule(5.0),
+                              prob=(0.025, 0.25),
+                              grad_scale=0.01)
+
+        self.whiten2 = Whiten(num_groups=1,
                              whitening_limit=_whitening_schedule(5.0),
                              prob=(0.025, 0.25),
                              grad_scale=0.01)
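For context, Whiten is the whitening-constraint module from the recipe's scaling utilities, and the constructor arguments above are what both new instances use. Below is a minimal usage sketch, not part of the commit; it assumes the recipe's scaling.py (which defines Whiten) is importable, and it substitutes a plain float for _whitening_schedule(5.0).

import torch
from scaling import Whiten  # assumption: the zipformer recipe's scaling.py is on the path

# Same hyper-parameters as the two instances added above, except that
# whitening_limit is given here as a plain float instead of a schedule.
whiten = Whiten(num_groups=1,
                whitening_limit=5.0,
                prob=(0.025, 0.25),
                grad_scale=0.01)

x = torch.randn(100, 8, 256)  # (seq_len, batch_size, num_channels)
y = whiten(x)                 # shape-preserving; the constraint acts via the backward pass during training
assert y.shape == x.shape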
@@ -1539,10 +1544,11 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         x = torch.matmul(attn_weights, x)
         # now x: (num_heads, batch_size, seq_len, head_dim)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+        x = self.whiten1(x)
 
         x = self.out_proj(x)
         x = self.balancer2(x)
-        x = self.whiten(x)
+        x = self.whiten2(x)
 
         return x
 
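Putting the two hunks together, the forward tail now whitens the concatenated attention output before the output projection and again after the balancer. The following is a self-contained toy sketch of that ordering, not the repository's implementation: nn.Identity stands in for Whiten and the Balancer so it runs without scaling.py, and the attribute names mirror the diff.

import torch
import torch.nn as nn

class NonlinAttentionTail(nn.Module):
    """Toy stand-in showing the op ordering after this commit:
    whiten1 -> out_proj -> balancer2 -> whiten2."""
    def __init__(self, num_heads: int, head_dim: int, embed_dim: int):
        super().__init__()
        # In the real module these are Whiten/Balancer from scaling.py;
        # nn.Identity keeps the sketch self-contained.
        self.whiten1 = nn.Identity()
        self.out_proj = nn.Linear(num_heads * head_dim, embed_dim)
        self.balancer2 = nn.Identity()
        self.whiten2 = nn.Identity()

    def forward(self, attn_weights: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        # x:            (num_heads, batch_size, seq_len, head_dim)
        (_, batch_size, seq_len, _) = x.shape
        x = torch.matmul(attn_weights, x)
        # now x: (num_heads, batch_size, seq_len, head_dim)
        x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
        x = self.whiten1(x)      # new: whiten before the output projection
        x = self.out_proj(x)
        x = self.balancer2(x)
        x = self.whiten2(x)      # formerly self.whiten
        return x

heads, batch, seq, head_dim, embed_dim = 4, 2, 10, 16, 64
attn = torch.softmax(torch.randn(heads, batch, seq, seq), dim=-1)
value = torch.randn(heads, batch, seq, head_dim)
out = NonlinAttentionTail(heads, head_dim, embed_dim)(attn, value)
print(out.shape)  # torch.Size([10, 2, 64])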