Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Harmonize whitening modules: add them to 3 submodules, change the configuration of 2 others, and move the one in NonlinAttention.
parent 9ceb41acb4
commit a6657e6b40
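For readers skimming the hunks below, this is the pattern the commit applies everywhere: construct a Whiten regularizer in a module's __init__ and call it on the module's output right after the final projection. This sketch is illustrative only: the real Whiten is defined in icefall's scaling.py, and the stub here merely mirrors the constructor arguments that appear in the diff while acting as an identity in forward, so the example stays self-contained and runnable. SelfAttentionLike, and its use of nn.Linear in place of the repo's ScaledLinear, are stand-in names, not code from the repository.

from typing import Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor


class Whiten(nn.Module):
    """Stand-in for icefall's Whiten (scaling.py): the real module is a
    regularizer that acts through the backward pass to keep features close
    to 'white'; its forward returns the input unchanged, which is all this
    stub reproduces."""

    def __init__(self,
                 num_groups: int,
                 whitening_limit: float,
                 prob: Union[float, Tuple[float, float]],
                 grad_scale: float) -> None:
        super().__init__()
        self.num_groups = num_groups
        self.whitening_limit = whitening_limit
        self.prob = prob
        self.grad_scale = grad_scale

    def forward(self, x: Tensor) -> Tensor:
        return x  # the real Whiten also returns x unchanged; it acts on gradients.


class SelfAttentionLike(nn.Module):
    """Illustrates where the commit places whitening: after the output
    projection, immediately before the value is returned."""

    def __init__(self, embed_dim: int) -> None:
        super().__init__()
        # nn.Linear stands in for the repo's ScaledLinear(..., initial_scale=0.05).
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        # Same settings the diff uses for SelfAttention and NonlinAttentionModule.
        self.whiten = Whiten(num_groups=1,
                             whitening_limit=15.0,
                             prob=(0.025, 0.25),
                             grad_scale=0.01)

    def forward(self, x: Tensor) -> Tensor:
        x = self.out_proj(x)
        x = self.whiten(x)  # whitening applied to the module's output
        return x


if __name__ == "__main__":
    m = SelfAttentionLike(embed_dim=8)
    y = m(torch.randn(10, 2, 8))  # (seq_len, batch_size, embed_dim)
    print(y.shape)

Read together, the hunks make this placement uniform: each touched submodule (SelfAttention, AttentionSqueeze, FeedforwardModule, NonlinAttentionModule, ConvolutionModule) whitens its output rather than an internal activation, with whitening_limit=15.0 throughout.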
@@ -1226,6 +1226,11 @@ class SelfAttention(nn.Module):
                                      embed_dim, bias=True,
                                      initial_scale=0.05)
 
+        self.whiten = Whiten(num_groups=1,
+                             whitening_limit=15.0,
+                             prob=(0.025, 0.25),
+                             grad_scale=0.01)
+
 
     def forward(
         self,
@@ -1259,6 +1264,7 @@ class SelfAttention(nn.Module):
 
         # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
         x = self.out_proj(x)
+        x = self.whiten(x)
 
         return x
 
@@ -1325,7 +1331,7 @@ class AttentionSqueeze(nn.Module):
                                      bias=False, initial_scale=0.05)
 
         self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=10.0,
+                                 whitening_limit=15.0,
                                  prob=(0.01, 0.1),
                                  grad_scale=0.01)
 
@@ -1382,7 +1388,7 @@ class FeedforwardModule(nn.Module):
         self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
                                      initial_scale=0.01)
         self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=10.0,
+                                 whitening_limit=15.0,
                                  prob=(0.025, 0.25),
                                  grad_scale=0.01)
 
@@ -1420,19 +1426,17 @@ class NonlinAttentionModule(nn.Module):
             min_abs=0.2, max_abs=10.0,
             min_prob=0.05,
         )
-        # give it a high limit, because it is quite high-dimensional and is
-        # a projection of a lower-dimensional embedding.
-        self.whiten = Whiten(num_groups=2,
-                             whitening_limit=20.0,
-                             prob=(0.025, 0.25),
-                             grad_scale=0.01)
-
 
         self.activation = Identity() # for diagnostics.
         self.out_proj = ScaledLinear(channels, channels,
                                      bias=True,
                                      initial_scale=0.05)
 
+        self.whiten = Whiten(num_groups=1,
+                             whitening_limit=15.0,
+                             prob=(0.025, 0.25),
+                             grad_scale=0.01)
+
 
 
     def forward(self,
@@ -1447,7 +1451,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
            a Tensor with the same shape as x
         """
         x = self.in_proj(x)
-        x = self.whiten(x)
+
         v, s = x.chunk(2, dim=-1)
 
         if self.training and random.random() < 0.02:
@@ -1472,6 +1476,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
 
         x = self.activation(x) # diagnostics only, it's the identity.
         x = self.out_proj(x)
+        x = self.whiten(x)
         return x
 
 
@@ -1549,6 +1554,12 @@ class ConvolutionModule(nn.Module):
             initial_scale=0.05,
         )
 
+        self.out_whiten = Whiten(num_groups=1,
+                                 whitening_limit=15.0,
+                                 prob=(0.01, 0.1),
+                                 grad_scale=0.01)
+
+
     def forward(self,
                 x: Tensor,
                 src_key_padding_mask: Optional[Tensor] = None,
@@ -1584,8 +1595,9 @@ class ConvolutionModule(nn.Module):
 
         x = self.pointwise_conv2(x) # (batch, channel, time)
 
-        return x.permute(2, 0, 1)
+        x = x.permute(2, 0, 1)
+        x = self.out_whiten(x)
+        return x
 
 
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/2 length).