Harmonize whitening modules: add them to 3 submodules, change the configuration of 2 others, and change the location of the one in NonlinAttention.

Daniel Povey 2022-11-23 19:08:19 +08:00
parent 9ceb41acb4
commit a6657e6b40

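For orientation, here is a minimal sketch of the pattern the touched modules converge on after this commit: a Whiten module constructed next to the output projection and applied to its output in forward(). The hyperparameter values are copied from the diff below; the import path and the ToyOutputBlock class are illustrative assumptions, not code from this repository.

import torch
import torch.nn as nn

# Assumed import path: ScaledLinear and Whiten are the utilities referenced
# in the diff; the module they live in is not shown on this page.
from scaling import ScaledLinear, Whiten


class ToyOutputBlock(nn.Module):
    # Illustrative only: mirrors how SelfAttention, NonlinAttentionModule and
    # ConvolutionModule now pair an output projection with a Whiten module.
    def __init__(self, embed_dim: int) -> None:
        super().__init__()
        self.out_proj = ScaledLinear(embed_dim, embed_dim,
                                     bias=True,
                                     initial_scale=0.05)
        # Settings used for the attention-style modules in this commit;
        # AttentionSqueeze and ConvolutionModule use prob=(0.01, 0.1) instead.
        self.whiten = Whiten(num_groups=1,
                             whitening_limit=15.0,
                             prob=(0.025, 0.25),
                             grad_scale=0.01)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch_size, embed_dim), the layout the whiteners see
        # in the modules touched by this diff.
        x = self.out_proj(x)
        x = self.whiten(x)
        return x

The hunks below suggest Whiten leaves the tensor shape unchanged (the surrounding shape comments and return statements are untouched), which is why it can be inserted after out_proj without affecting callers.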

@@ -1226,6 +1226,11 @@ class SelfAttention(nn.Module):
                                     embed_dim, bias=True,
                                     initial_scale=0.05)
+        self.whiten = Whiten(num_groups=1,
+                             whitening_limit=15.0,
+                             prob=(0.025, 0.25),
+                             grad_scale=0.01)
     def forward(
         self,
@@ -1259,6 +1264,7 @@ class SelfAttention(nn.Module):
         # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
         x = self.out_proj(x)
+        x = self.whiten(x)
         return x
@@ -1325,7 +1331,7 @@ class AttentionSqueeze(nn.Module):
                                      bias=False, initial_scale=0.05)
         self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=10.0,
+                                 whitening_limit=15.0,
                                  prob=(0.01, 0.1),
                                  grad_scale=0.01)
@@ -1382,7 +1388,7 @@ class FeedforwardModule(nn.Module):
         self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
                                      initial_scale=0.01)
         self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=10.0,
+                                 whitening_limit=15.0,
                                  prob=(0.025, 0.25),
                                  grad_scale=0.01)
@@ -1420,19 +1426,17 @@ class NonlinAttentionModule(nn.Module):
                                    min_abs=0.2, max_abs=10.0,
                                    min_prob=0.05,
         )
-        # give it a high limit, because it is quite high-dimensional and is
-        # a projection of a lower-dimensional embedding.
-        self.whiten = Whiten(num_groups=2,
-                             whitening_limit=20.0,
-                             prob=(0.025, 0.25),
-                             grad_scale=0.01)
         self.activation = Identity()  # for diagnostics.
         self.out_proj = ScaledLinear(channels, channels,
                                      bias=True,
                                      initial_scale=0.05)
+        self.whiten = Whiten(num_groups=1,
+                             whitening_limit=15.0,
+                             prob=(0.025, 0.25),
+                             grad_scale=0.01)
     def forward(self,
@@ -1447,7 +1451,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
           a Tensor with the same shape as x
         """
         x = self.in_proj(x)
-        x = self.whiten(x)
         v, s = x.chunk(2, dim=-1)
         if self.training and random.random() < 0.02:
@@ -1472,6 +1476,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         x = self.activation(x)  # diagnostics only, it's the identity.
         x = self.out_proj(x)
+        x = self.whiten(x)
         return x
@@ -1549,6 +1554,12 @@ class ConvolutionModule(nn.Module):
             initial_scale=0.05,
         )
+        self.out_whiten = Whiten(num_groups=1,
+                                 whitening_limit=15.0,
+                                 prob=(0.01, 0.1),
+                                 grad_scale=0.01)
     def forward(self,
                 x: Tensor,
                 src_key_padding_mask: Optional[Tensor] = None,
@@ -1584,8 +1595,9 @@ class ConvolutionModule(nn.Module):
         x = self.pointwise_conv2(x)  # (batch, channel, time)
-        return x.permute(2, 0, 1)
+        x = x.permute(2, 0, 1)
+        x = self.out_whiten(x)
+        return x
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/2 length).