mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add Whiten module, with whitening_limit=10.0, at output of ModifiedSEModule
This commit is contained in:
parent
a27670d097
commit
a2dbce2a9a
@@ -1463,12 +1463,19 @@ class ModifiedSEModule(nn.Module):
|
|||||||
max_factor=0.01,
|
max_factor=0.01,
|
||||||
min_prob=0.2,
|
min_prob=0.2,
|
||||||
)
|
)
|
||||||
|
#self.bottleneck_norm = BasicNorm(bottleneck_dim)
|
||||||
|
|
||||||
self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, d_model)
|
self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, d_model)
|
||||||
self.sigmoid = nn.Sigmoid() # make it a submodule for diagnostics purposes.
|
|
||||||
|
|
||||||
self.out_proj = ScaledLinear(d_model, d_model,
|
self.out_proj = ScaledLinear(d_model, d_model,
|
||||||
bias=False, initial_scale=0.1)
|
bias=False, initial_scale=0.1)
|
||||||
|
|
||||||
|
self.out_whiten = Whiten(num_groups=1,
|
||||||
|
whitening_limit=10.0,
|
||||||
|
prob=(0.025, 0.25),
|
||||||
|
grad_scale=0.01)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
@@ -1497,16 +1504,16 @@ class ModifiedSEModule(nn.Module):
|
|||||||
squeezed = self.activation(squeezed)
|
squeezed = self.activation(squeezed)
|
||||||
squeezed = self.to_bottleneck_proj(squeezed)
|
squeezed = self.to_bottleneck_proj(squeezed)
|
||||||
squeezed = self.bottleneck_balancer(squeezed)
|
squeezed = self.bottleneck_balancer(squeezed)
|
||||||
|
#squeezed = self.bottleneck_norm(squeezed)
|
||||||
squeezed = self.from_bottleneck_proj(squeezed)
|
squeezed = self.from_bottleneck_proj(squeezed)
|
||||||
if random.random() < 0.05:
|
if random.random() < 0.05:
|
||||||
# to stop a hopefully-unlikely failure mode where the inputs to the sigmoid
|
# to stop a hopefully-unlikely failure mode where the inputs to the sigmoid
|
||||||
# get too large and the grads get mostly too small.
|
# get too large and the grads get mostly too small.
|
||||||
squeezed = penalize_abs_values_gt(squeezed, limit=10.0, penalty=1.0e-04)
|
squeezed = penalize_abs_values_gt(squeezed, limit=10.0, penalty=1.0e-04)
|
||||||
scales = self.sigmoid(squeezed)
|
|
||||||
|
|
||||||
x = self.in_proj(x)
|
x = self.in_proj(x)
|
||||||
x = x * squeezed
|
x = x * squeezed
|
||||||
return self.out_proj(x)
|
return self.out_whiten(self.out_proj(x))
|
||||||
|
|
||||||
|
|
||||||
class FeedforwardModule(nn.Module):
|
class FeedforwardModule(nn.Module):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user