Replace Pooling module with ModifiedSEModule
parent 4da4a3a5df
commit e08f5c1bce
@@ -824,7 +824,7 @@ class MaxEig(torch.nn.Module):
         self.min_prob = min_prob
         # cur_prob is the current probability we'll use to apply the ActivationBalancer.
-        # We'll regress this towards prob, each tiem we try to apply it and it is not
+        # We'll regress this towards prob, each time we try to apply it and it is not
         # active.
         self.cur_prob = 1.0

@@ -331,7 +331,7 @@ class ZipformerEncoderLayer(nn.Module):
             d_model, attention_dim, nhead, pos_dim, dropout=0.0,
         )

-        self.pooling = PoolingModule(d_model)
+        self.squeeze_excite = ModifiedSEModule(d_model)

         self.feed_forward1 = FeedforwardModule(d_model,
                                                feedforward_dim,
@@ -433,8 +433,8 @@ class ZipformerEncoderLayer(nn.Module):

         # pooling module
         if torch.jit.is_scripting() or random.random() > dynamic_dropout:
-            src = src + self.pooling(src,
-                                     key_padding_mask=src_key_padding_mask)
+            src = src + self.squeeze_excite(src,
+                                            key_padding_mask=src_key_padding_mask)

         # multi-headed self-attention module
         use_self_attn = (random.random() > dynamic_dropout)
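The random.random() > dynamic_dropout guard in the hunk above is a layer-drop style trick: during training the squeeze-excite branch is skipped with probability dynamic_dropout, while under TorchScript it always runs. A minimal sketch of that pattern, outside the diff itself; the helper name maybe_apply is invented here for illustration and is not part of the commit:

import random
import torch

def maybe_apply(branch, src, dynamic_dropout, **kwargs):
    # Skip the branch with probability `dynamic_dropout` during training;
    # always apply it when running under TorchScript (scripting/export).
    if torch.jit.is_scripting() or random.random() > dynamic_dropout:
        return src + branch(src, **kwargs)
    return src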
@@ -1429,15 +1429,50 @@ class RelPositionMultiheadAttention(nn.Module):



-class PoolingModule(nn.Module):
+class ModifiedSEModule(nn.Module):
     """
-    Averages the input over the time dimension and project with a square matrix.
+    A modified version of Squeeze-and-Excite, where the nonlinearity happens in the full dim and
+    we just project to a small bottleneck dimension.
     """
     def __init__(self,
-                 d_model: int):
+                 d_model: int,
+                 bottleneck_dim: int = 16):
         super().__init__()
-        self.proj = ScaledLinear(d_model, d_model,
-                                 initial_scale=0.1, bias=False)
+        self.squeeze_proj = nn.Linear(d_model, d_model,
+                                      bias=False)
+        self.in_proj = nn.Linear(d_model, d_model,
+                                 bias=False)
+
+        # Caution: this cannot work correctly with an extremely small batch size, e.g. if
+        # we were training with a single very long audio sequence, or just 2 or 3 sequences
+        # at a time. We make max_factor small to reduce the harm this could cause
+        # (although when the grads get back past the averaging operation they would
+        # be quite small and would probably not hurt the rest of the model much.)
+        self.balancer = ActivationBalancer(
+            d_model, channel_dim=-1,
+            min_positive=0.05, max_positive=0.95,
+            max_abs=50.0,
+            max_factor=0.01,
+            min_prob=0.2,
+        )
+        self.activation = DoubleSwish()
+        self.to_bottleneck_proj = ScaledLinear(d_model, bottleneck_dim)
+
+        self.bottleneck_balancer = ActivationBalancer(
+            bottleneck_dim, channel_dim=-1,
+            min_positive=0.05, max_positive=0.95,
+            max_abs=5.0,
+            min_abs=0.5,
+            max_factor=0.01,
+            min_prob=0.2,
+        )
+        self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, d_model)
+        self.sigmoid = nn.Sigmoid()  # make it a submodule for diagnostics purposes.
+
+        self.out_proj = ScaledLinear(d_model, d_model,
+                                     bias=False, initial_scale=0.1)



     def forward(self,
                 x: Tensor,
@@ -1459,9 +1494,22 @@ class PoolingModule(nn.Module):
         num_frames = x.shape[0]
         pooling_mask = 1.0 / num_frames

-        x = (x * pooling_mask).sum(dim=0, keepdim=True)
-        x = self.proj(x)
-        return x
+        squeezed = (x * pooling_mask).sum(dim=0, keepdim=True)
+        squeezed = self.squeeze_proj(squeezed)
+        squeezed = self.balancer(squeezed)
+        squeezed = self.activation(squeezed)
+        squeezed = self.to_bottleneck_proj(squeezed)
+        squeezed = self.bottleneck_balancer(squeezed)
+        squeezed = self.from_bottleneck_proj(squeezed)
+        if random.random() < 0.05:
+            # to stop a hopefully-unlikely failure mode where the inputs to the sigmoid
+            # get too large and the grads get mostly too small.
+            squeezed = penalize_abs_values_gt(squeezed, limit=10.0, penalty=1.0e-04)
+        scales = self.sigmoid(squeezed)
+
+        x = self.in_proj(x)
+        x = x * scales
+        return self.out_proj(x)


 class FeedforwardModule(nn.Module):
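To make the data flow easier to see outside the diff, here is a minimal self-contained sketch of the modified squeeze-and-excite computation. It is an illustration under stated assumptions, not the module above: ScaledLinear and DoubleSwish are replaced by plain nn.Linear and nn.SiLU, the ActivationBalancer and penalize_abs_values_gt regularizers are dropped, and the class name ModifiedSESketch is invented here. Unlike a standard SE block (bottleneck, then nonlinearity, then expand), the full-dimension nonlinearity comes first and the bottleneck is a purely linear down/up projection before the sigmoid gate; the real module in this commit also takes a key_padding_mask so padded frames can be excluded from the time average.

import torch
import torch.nn as nn
from torch import Tensor


class ModifiedSESketch(nn.Module):
    """Illustrative stand-in for ModifiedSEModule: nonlinearity in the full
    d_model dimension, then a linear squeeze to a small bottleneck and back,
    then a sigmoid gate that rescales every frame."""

    def __init__(self, d_model: int, bottleneck_dim: int = 16):
        super().__init__()
        self.squeeze_proj = nn.Linear(d_model, d_model, bias=False)
        self.in_proj = nn.Linear(d_model, d_model, bias=False)
        self.activation = nn.SiLU()                       # stands in for DoubleSwish
        self.to_bottleneck_proj = nn.Linear(d_model, bottleneck_dim)
        self.from_bottleneck_proj = nn.Linear(bottleneck_dim, d_model)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # x: (num_frames, batch, d_model), as in the Zipformer encoder layers.
        squeezed = x.mean(dim=0, keepdim=True)            # "squeeze": average over time
        squeezed = self.activation(self.squeeze_proj(squeezed))
        squeezed = self.from_bottleneck_proj(self.to_bottleneck_proj(squeezed))
        scales = torch.sigmoid(squeezed)                  # per-channel gates in (0, 1)
        return self.out_proj(self.in_proj(x) * scales)    # "excite": rescale every frame


x = torch.randn(100, 4, 256)   # (time, batch, channels)
y = ModifiedSESketch(256)(x)
assert y.shape == x.shape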