mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add Whiten module after squeeze_proj.
This commit is contained in:
parent
11cb30bf49
commit
a9c384e69e
@ -1438,6 +1438,13 @@ class ModifiedSEModule(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.squeeze_proj = nn.Linear(d_model, d_model,
|
self.squeeze_proj = nn.Linear(d_model, d_model,
|
||||||
bias=False)
|
bias=False)
|
||||||
|
|
||||||
|
# caution: this won't work well if the batch size is extremely small.
|
||||||
|
self.squeeze_whiten = Whiten(num_groups=1,
|
||||||
|
whitening_limit=10.0,
|
||||||
|
prob=(0.025, 0.25),
|
||||||
|
grad_scale=0.01)
|
||||||
|
|
||||||
self.in_proj = nn.Linear(d_model, d_model,
|
self.in_proj = nn.Linear(d_model, d_model,
|
||||||
bias=False)
|
bias=False)
|
||||||
|
|
||||||
@ -1494,6 +1501,7 @@ class ModifiedSEModule(nn.Module):
|
|||||||
|
|
||||||
squeezed = (x * pooling_mask).sum(dim=0, keepdim=True)
|
squeezed = (x * pooling_mask).sum(dim=0, keepdim=True)
|
||||||
squeezed = self.squeeze_proj(squeezed)
|
squeezed = self.squeeze_proj(squeezed)
|
||||||
|
squeezed = self.squeeze_whiten(squeezed)
|
||||||
squeezed = self.balancer(squeezed)
|
squeezed = self.balancer(squeezed)
|
||||||
squeezed = self.activation(squeezed)
|
squeezed = self.activation(squeezed)
|
||||||
squeezed = self.to_bottleneck_proj(squeezed)
|
squeezed = self.to_bottleneck_proj(squeezed)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user