Add Whiten module after squeeze_proj.

This commit is contained in:
Daniel Povey 2022-11-03 19:04:34 +08:00
parent 11cb30bf49
commit a9c384e69e

View File

@ -1438,6 +1438,13 @@ class ModifiedSEModule(nn.Module):
super().__init__() super().__init__()
self.squeeze_proj = nn.Linear(d_model, d_model, self.squeeze_proj = nn.Linear(d_model, d_model,
bias=False) bias=False)
# caution: this won't work well if the batch size is extremely small.
self.squeeze_whiten = Whiten(num_groups=1,
whitening_limit=10.0,
prob=(0.025, 0.25),
grad_scale=0.01)
self.in_proj = nn.Linear(d_model, d_model, self.in_proj = nn.Linear(d_model, d_model,
bias=False) bias=False)
@ -1494,6 +1501,7 @@ class ModifiedSEModule(nn.Module):
squeezed = (x * pooling_mask).sum(dim=0, keepdim=True) squeezed = (x * pooling_mask).sum(dim=0, keepdim=True)
squeezed = self.squeeze_proj(squeezed) squeezed = self.squeeze_proj(squeezed)
squeezed = self.squeeze_whiten(squeezed)
squeezed = self.balancer(squeezed) squeezed = self.balancer(squeezed)
squeezed = self.activation(squeezed) squeezed = self.activation(squeezed)
squeezed = self.to_bottleneck_proj(squeezed) squeezed = self.to_bottleneck_proj(squeezed)