diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 174ccf39e..ed3784a78 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -122,40 +122,6 @@ def _compute_sign_factor(x: Tensor,
 
 
-class ActivationScaleBalancerFunction(torch.autograd.Function):
-    """
-    This object is used in class ActivationBalancer when the user specified
-    min_positive=0, max_positive=1, so there are no constraints on the signs
-    of the activations and only the absolute value has a constraint.
-    """
-    @staticmethod
-    def forward(
-        ctx,
-        x: Tensor,
-        sign_factor: Tensor,
-        scale_factor: Tensor,
-        channel_dim: int,
-    ) -> Tensor:
-        if channel_dim < 0:
-            channel_dim += x.ndim
-        ctx.channel_dim = channel_dim
-        xgt0 = (x > 0)
-        ctx.save_for_backward(xgt0, sign_factor, scale_factor)
-        return x
-
-
-    @staticmethod
-    def backward(
-        ctx, x_grad: Tensor
-    ) -> Tuple[Tensor, None, None, None]:
-        xgt0, sign_factor, scale_factor = ctx.saved_tensors
-        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
-            sign_factor = sign_factor.unsqueeze(-1)
-            scale_factor = scale_factor.unsqueeze(-1)
-
-        factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
-        neg_delta_grad = x_grad.abs() * factor
-        return x_grad - neg_delta_grad, None, None, None,
 
 
 def random_cast_to_half(x: Tensor,
@@ -465,6 +431,7 @@ class ActivationBalancer(torch.nn.Module):
         # loop.
         self.batch_count = 0
+        # actually self.num_channels is no longer needed except for an assertion.
         self.num_channels = num_channels
         self.channel_dim = channel_dim
         self.min_positive = min_positive
@@ -489,6 +456,7 @@ class ActivationBalancer(torch.nn.Module):
             prob = max(self.min_prob, 0.5 ** (1 + (self.batch_count / 4000.0)))
 
             if random.random() < prob:
+                assert x.shape[self.channel_dim] == self.num_channels
                 sign_gain_factor = 0.5
                 if self.min_positive != 0.0 or self.max_positive != 1.0:
                     sign_factor = _compute_sign_factor(x, self.channel_dim,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index efcb25754..38394564e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1288,25 +1288,46 @@ class AttentionSqueeze(nn.Module):
                                             bottleneck_dim, bias=False)
+
+        # the main reason for this balancer is to keep the bottleneck activations in a "reasonable"
+        # dynamic range, to avoid parameter-size 'drift' where to_bottleneck_proj gets large and from_bottleneck_proj
+        # gets small or vice versa.
         # Caution: this cannot work correctly with an extremeley small batch size, e.g. if
         # we were training with a single very long audio sequence, or just 2 or 3 sequences
         # at a time.  We make max_factor small to reduce the harm this could cause
         # (although when the grads get back past the averaging operation they would
         # be quite small and would probably not hurt the rest of the model much.)
-        self.balancer = ActivationBalancer(
-            embed_dim, channel_dim=-1,
+        self.bottleneck_balancer = ActivationBalancer(
+            bottleneck_dim, channel_dim=-1,
             min_positive=0.05, max_positive=0.95,
-            min_abs=0.1,
-            max_abs=50.0,
+            min_abs=0.2,
+            max_abs=1.0,
             max_factor=0.02,
-            min_prob=0.2,
+            min_prob=0.1,
+        )
+
+        # the next two balancers are only to stop parameter-magnitude 'drift': we have
+        # too many degrees of freedom for the scales of the various activations.
+        # Make them run with very low probability, since only a small application of
+        # these balancers should be enough to stop such "drift"; and, for speed,
+        # put no limitation on the signs (so: min_positive=0, max_positive=1).
+        self.scale_balancer = ActivationBalancer(
+            embed_dim, channel_dim=-1,
+            min_positive=0.0, max_positive=1.0,
+            min_abs=0.2, max_abs=1.0,
+            min_prob=0.025,
+        )
+        self.activation_balancer = ActivationBalancer(
+            embed_dim, channel_dim=-1,
+            min_positive=0.0, max_positive=1.0,
+            min_abs=0.2, max_abs=1.0,
+            min_prob=0.025,
         )
-        self.activation = DoubleSwish()
         self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim)
         self.out_proj = ScaledLinear(embed_dim, embed_dim,
-                                     bias=False, initial_scale=0.1)
+                                     bias=False, initial_scale=0.05)
 
         self.out_whiten = Whiten(num_groups=1,
                                  whitening_limit=10.0,
@@ -1334,12 +1355,14 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         # (num_heads, batch_size, seq_len, seq_len) x (num_heads, batch_size, seq_len, head_dim)
         # -> (num_heads, batch_size, seq_len, head_dim)
         bottleneck = torch.matmul(attn_weights, bottleneck)
-        bottleneck = self.balancer(bottleneck)
-        bottleneck = self.activation(bottleneck)
+        bottleneck = self.bottleneck_balancer(bottleneck)
         bottleneck = bottleneck.permute(2, 1, 0, 3)  # (seq_len, batch_size, num_heads, head_dim)
         bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim)
         scales = self.from_bottleneck_proj(bottleneck)
+        scales = self.scale_balancer(scales)
+        x = self.activation_balancer(x)
+
         x = self.in_proj(x)
         x = x * scales
         x = self.out_proj(x)
@@ -1398,7 +1421,7 @@ class NonlinAttentionModule(nn.Module):
 
         # deriv_balancer corresponds to deriv_balancer2 in ConvolutionMOdule
         self.deriv_balancer = ActivationBalancer(
-            channels, channel_dim=1,
+            channels, channel_dim=-1,
             min_positive=0.05, max_positive=1.0,
             max_abs=20.0,
         )
@@ -1439,8 +1462,8 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         # now x: (num_heads, batch_size, seq_len, head_dim)
         x = torch.matmul(attn_weights, x)
         # now x: (num_heads, batch_size, seq_len, head_dim)
-        x = self.deriv_balancer(x)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+        x = self.deriv_balancer(x)
         x = self.activation(x)
         x = self.out_proj(x)
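
Note (not part of the patch): the sketch below only evaluates the scheduling expression that appears unchanged in ActivationBalancer.forward() above, prob = max(self.min_prob, 0.5 ** (1 + (self.batch_count / 4000.0))), at the min_prob values this diff chooses (0.1 for bottleneck_balancer, 0.025 for the two anti-drift balancers), to show how often each balancer would still fire as training progresses. The helper name balancer_prob and the example batch counts are illustrative only.

def balancer_prob(batch_count: float, min_prob: float) -> float:
    # Same expression as in ActivationBalancer.forward() in scaling.py:
    # the probability of applying the balancer decays from 0.5 towards
    # a floor of min_prob as batch_count grows.
    return max(min_prob, 0.5 ** (1 + (batch_count / 4000.0)))


if __name__ == "__main__":
    for batch_count in (0, 4000, 8000, 16000, 32000):
        p_bottleneck = balancer_prob(batch_count, min_prob=0.1)   # bottleneck_balancer
        p_drift = balancer_prob(batch_count, min_prob=0.025)      # scale_/activation_balancer
        print(f"{batch_count:>6}  {p_bottleneck:.4f}  {p_drift:.5f}")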