Mirror of https://github.com/k2-fsa/icefall.git

Changes to whitening modules for memory efficiency, moving them inside; increase their prob.

Commit 35f0ea0015 (parent de73e2e424)
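For context: Whiten is icefall's whitening regularizer (defined in scaling.py, not shown in this diff). Roughly, it leaves activations unchanged in the forward pass and only adds a small gradient penalty when the channel covariance of the activations drifts too far from a scaled identity, i.e. when a whiteness metric exceeds whitening_limit; prob=(min, max) controls how often that check is applied and grad_scale the penalty size. That reading is inferred from the constructor arguments used in this commit. The sketch below only illustrates the kind of metric such a module constrains; whiteness_metric is a hypothetical helper, not icefall code.

# Hedged sketch only: NOT icefall's Whiten implementation, just the quantity
# such a module keeps bounded.
import torch
from torch import Tensor

def whiteness_metric(x: Tensor) -> Tensor:
    """How far the channel covariance of x (..., channel) is from a multiple of
    the identity: ~1.0 for perfectly 'white' features, larger when the variance
    is concentrated in a few directions."""
    x = x.reshape(-1, x.shape[-1]).to(torch.float32)
    x = x - x.mean(dim=0)                   # zero-mean per channel
    cov = (x.t() @ x) / x.shape[0]          # (channel, channel) covariance
    eigs = torch.linalg.eigvalsh(cov)       # eigenvalues of the covariance
    return (eigs ** 2).mean() / (eigs.mean() ** 2 + 1.0e-20)

x = torch.randn(100, 50, 256)               # e.g. (batch, time, channel)
print(whiteness_metric(x))                  # close to 1.0 for random data

Raising prob from (0.01, 0.1) to (0.025, 0.25), as this commit does, makes the whitening constraint fire on a larger fraction of batches.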
@@ -1041,6 +1041,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.pos_head_dim = pos_head_dim
         self.dropout = dropout
         self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
+        self.name = None  # will be overwritten in training code; for diagnostics.
 
         key_head_dim = query_head_dim
         in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
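The new self.name attribute is only a hook for diagnostics; per the comment it is overwritten by the training code so that logs such as the entropy message in the next hunk can identify which attention module they came from. A hypothetical sketch of how such names could be assigned (the actual icefall training code is not part of this diff):

import torch

def set_diagnostic_names(model: torch.nn.Module) -> None:
    # Hypothetical helper: copy each submodule's path into its `name` field,
    # for modules that declare one (like RelPositionMultiheadAttentionWeights).
    for name, module in model.named_modules():
        if hasattr(module, "name"):
            module.name = name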
@@ -1202,7 +1203,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             attn_weights = attn_weights.to(torch.float32)
             attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum(
                 dim=-1).mean(dim=(1,2))
-            logging.info(f"attn_weights_entropy = {attn_weights_entropy}")
+            logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}")
 
 
 class SelfAttention(nn.Module):
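The logged value is the entropy H = -sum_j p_j * log(p_j) of each head's attention distribution over key positions, averaged over the batch and query dimensions, giving one number per head; values near zero indicate a head that attends to very few positions. A self-contained example of the same computation on dummy weights:

import torch

num_heads, batch_size, seq_len = 4, 2, 10
# attn_weights: (num_heads, batch_size, seq_len, seq_len), rows sum to 1
attn_weights = torch.randn(num_heads, batch_size, seq_len, seq_len).softmax(dim=-1)

attn_weights = attn_weights.to(torch.float32)
attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum(
    dim=-1).mean(dim=(1, 2))          # one entropy value per head
print(attn_weights_entropy.shape)     # torch.Size([4])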
@@ -1328,17 +1329,17 @@ class AttentionSqueeze(nn.Module):
                                             min_abs=0.2, max_abs=1.0,
                                             min_prob=0.05,
         )
+        self.activation_whiten = Whiten(num_groups=1,
+                                        whitening_limit=_whitening_schedule(7.5),
+                                        prob=(0.025, 0.25),
+                                        grad_scale=0.01)
 
         self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim)
 
         self.out_proj = ScaledLinear(embed_dim, embed_dim,
                                      bias=False, initial_scale=0.05)
 
-        self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=_whitening_schedule(7.5),
-                                 prob=(0.01, 0.1),
-                                 grad_scale=0.01)
 
     def forward(self,
                 x: Tensor,
                 attn_weights: Tensor):
@@ -1367,11 +1368,11 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
 
         x = self.in_proj(x)
         x = self.activation_balancer(x)
+        x = self.activation_whiten(x)
         scales = self.scale_balancer(scales)
         x = x * scales
         x = self.activation(x)  # Identity only.  For diagnostics.
         x = self.out_proj(x)
-        x = self.out_whiten(x)
         return x
 
 
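Taken together, the two AttentionSqueeze hunks move whitening off the module output (the old out_whiten, prob=(0.01, 0.1)) and onto the activations just after the balancer (activation_whiten, prob=(0.025, 0.25)), before the multiplication by scales; moving it inside is what the commit message attributes to memory efficiency. A minimal, non-authoritative sketch of the resulting forward ordering, with nn.Identity standing in for the icefall-specific Balancer/Whiten modules, all dims kept equal, and the attention pooling that produces scales elided:

import torch
from torch import Tensor, nn

class SqueezeOrderingSketch(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.in_proj = nn.Linear(dim, dim)
        self.activation_balancer = nn.Identity()  # placeholder for Balancer
        self.activation_whiten = nn.Identity()    # placeholder for Whiten(..., prob=(0.025, 0.25))
        self.activation = nn.Identity()           # "Identity only.  For diagnostics."
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, scales: Tensor) -> Tensor:
        x = self.in_proj(x)
        x = self.activation_balancer(x)
        x = self.activation_whiten(x)  # whitening now applied here, inside
        x = x * scales
        x = self.activation(x)
        x = self.out_proj(x)
        return x                       # no out_whiten on the output any more

m = SqueezeOrderingSketch(256)
y = m(torch.randn(10, 2, 256), torch.rand(10, 2, 256))  # (seq_len, batch, embed_dim)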
@@ -1548,6 +1549,11 @@ class ConvolutionModule(nn.Module):
 
         self.activation = DoubleSwish()
 
+        self.whiten = Whiten(num_groups=1,
+                             whitening_limit=_whitening_schedule(7.5),
+                             prob=(0.025, 0.25),
+                             grad_scale=0.01)
+
         self.pointwise_conv2 = ScaledConv1d(
             channels,
             channels,
@@ -1558,11 +1564,6 @@ class ConvolutionModule(nn.Module):
             initial_scale=0.05,
         )
 
-        self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=_whitening_schedule(7.5),
-                                 prob=(0.01, 0.1),
-                                 grad_scale=0.01)
-
 
     def forward(self,
                 x: Tensor,
@@ -1597,10 +1598,13 @@ class ConvolutionModule(nn.Module):
         x = self.deriv_balancer2(x)
         x = self.activation(x)
 
+        x = x.transpose(1, 2)
+        x = self.whiten(x)  # (batch, time, channel)
+        x = x.transpose(1, 2)
+
         x = self.pointwise_conv2(x)  # (batch, channel, time)
 
-        x = x.permute(2, 0, 1)
-        x = self.out_whiten(x)
+        x = x.permute(2, 0, 1)  # (time, batch, channel)
         return x
 
 class Conv2dSubsampling(nn.Module):
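In ConvolutionModule the same move happens: out_whiten (prob=(0.01, 0.1)) on the (time, batch, channel) output is removed, and whitening (prob=(0.025, 0.25)) is instead applied between the activation and pointwise_conv2, with the tensor transposed to (batch, time, channel) and back around the call. A shape-only sketch of that transpose/whiten/transpose pattern, with nn.Identity standing in for Whiten (hedged; not the real module):

import torch
from torch import nn

whiten = nn.Identity()            # placeholder for Whiten(..., prob=(0.025, 0.25))
pointwise_conv2 = nn.Conv1d(256, 256, kernel_size=1)

x = torch.randn(8, 256, 100)      # (batch, channel, time), as inside the module
x = x.transpose(1, 2)             # -> (batch, time, channel) for whitening
x = whiten(x)
x = x.transpose(1, 2)             # -> back to (batch, channel, time)
x = pointwise_conv2(x)            # (batch, channel, time)
x = x.permute(2, 0, 1)            # -> (time, batch, channel), the module output
print(x.shape)                    # torch.Size([100, 8, 256])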