diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index db341a1c9..23f0b281b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -421,7 +421,7 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
         # recompute y as we need the gradient; this is easier to implement than
         # saving y in the context.
         y = torch.matmul(x, weight.t())
-        z = alpha * torch.matmul(y, weight)
+        z = alpha.exp() * torch.matmul(y, weight)
         diff = x - z
         dims_to_mean = tuple(range(x.ndim-1))
         mean = diff.mean(dim=dims_to_mean)
@@ -452,7 +452,7 @@ class LinearWithAuxLoss(nn.Module):
     Suppose the input is x, and this layer computes:
            y = M x
     (the bias is applied separately), then we define:
-          z = alpha * M^T y
+          z = exp(alpha) * M^T y
     where alpha is learnable; and the auxiliary loss will be:
        aux_loss = normalize_mean(z - x)^2.
     (normalize_mean refers to subtracting the average value per channel,
@@ -485,7 +485,7 @@ class LinearWithAuxLoss(nn.Module):
                 0.01 * initial_scale)
         else:
             self.register_parameter('bias', None)
-        self.alpha = nn.Parameter(torch.tensor(1.0))
+        self.alpha = nn.Parameter(torch.tensor(0.0))
 
     def forward(self,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 76493678e..fe19ed8aa 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -343,7 +343,7 @@ def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
                           default=x)
 
 def _aux_grad_scale() -> float:
-    return 0.1
+    return 0.2
 
 def _aux_grad_prob() -> ScheduledFloat:
     return ScheduledFloat((0.0, 0.25), (1000.0, 0.0125))
@@ -1432,33 +1432,27 @@ class NonlinAttentionModule(nn.Module):
     ) -> None:
         super().__init__()
 
-        self.in_proj = LinearWithAuxLoss(channels, 2 * channels, bias=True,
-                                         aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob())
+        self.in_proj = nn.Linear(channels, 2 * channels, bias=True)
 
-        # balancer that goes after the glu mechanism.
+        # balancer that goes before the sigmoid.
         self.balancer = ActivationBalancer(
             channels, channel_dim=-1,
-            min_positive=0.2, max_positive=0.8,
-            min_abs=0.2, max_abs=10.0,
-            min_prob=0.05,
+            min_positive=0.05, max_positive=1.0,
+            min_abs=0.2, max_abs=ScheduledFloat((0.0, 2.0),
+                                                (4000.0, 10.0),
+                                                default=1.0),
         )
-        self.pre_sigmoid = Identity()  # for diagnostics.
         self.sigmoid = nn.Sigmoid()
 
         self.activation = Identity()  # for diagnostics.
 
-        self.out_proj = LinearWithAuxLoss(channels, channels,
-                                          bias=True,
-                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob(),
-                                          initial_scale=0.05)
+        self.out_proj = ScaledLinear(channels, channels,
+                                     bias=True,
+                                     initial_scale=0.05)
 
-        self.whiten1 = Whiten(num_groups=1,
-                              whitening_limit=_whitening_schedule(5.0),
-                              prob=(0.025, 0.25),
-                              grad_scale=0.01)
-        self.whiten2 = Whiten(num_groups=1,
-                              whitening_limit=_whitening_schedule(7.5),
-                              prob=(0.025, 0.25),
-                              grad_scale=0.01)
+        self.whiten = Whiten(num_groups=1,
+                             whitening_limit=_whitening_schedule(7.5),
+                             prob=(0.025, 0.25),
+                             grad_scale=0.01)
@@ -1475,19 +1469,12 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         """
         x = self.in_proj(x)
 
-        v, s = x.chunk(2, dim=-1)
+        x, s = x.chunk(2, dim=-1)
 
-        if self.training and random.random() < 0.02:
-            # prevent the inputs to the sigmoid from getting very large (this is
-            # hopefully quite a rare phenomenon, so we are giving this path a
-            # very small probability to save time).
-            s = penalize_abs_values_gt(s, limit=20.0, penalty=1.0e-04)
+        s = self.balancer(s)
+        s = self.sigmoid(s)
 
-        v = self.whiten1(v)
-        # GLU mechanism
-        s = self.pre_sigmoid(s)
-        x = self.sigmoid(s) * v
-        x = self.balancer(x)
+        x = x * s
 
         (seq_len, batch_size, embed_dim) = x.shape
         num_heads = attn_weights.shape[0]
@@ -1500,7 +1487,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
 
         x = self.activation(x)  # diagnostics only, it's the identity.
-        x = self.whiten2(x)
+        x = self.whiten(x)
         x = self.out_proj(x)
         return x
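
Note for reviewers: the quantity this patch reparameterizes is the auxiliary
reconstruction loss described in the LinearWithAuxLoss docstring above. A
minimal self-contained sketch of that computation in plain PyTorch (the name
aux_loss_sketch is illustrative only; the real code lives inside
LinearWithAuxLossFunction, which also scales and randomly gates the auxiliary
gradient via aux_grad_scale and prob):

    import torch

    def aux_loss_sketch(x: torch.Tensor, weight: torch.Tensor,
                        alpha: torch.Tensor) -> torch.Tensor:
        # x: (..., in_channels); weight: (out_channels, in_channels), as in nn.Linear.
        y = torch.matmul(x, weight.t())            # y = M x
        z = alpha.exp() * torch.matmul(y, weight)  # z = exp(alpha) * M^T y
        diff = x - z
        # "normalize_mean": subtract the per-channel mean, taken over all
        # non-channel dimensions.
        mean = diff.mean(dim=tuple(range(x.ndim - 1)), keepdim=True)
        return ((diff - mean) ** 2).mean()

Writing the scale as exp(alpha) keeps it strictly positive, and initializing
alpha to 0.0 gives exp(0) = 1, so the effective starting scale matches the old
torch.tensor(1.0) parameterization.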
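
Similarly, a minimal sketch of the simplified NonlinAttentionModule gating
path after this change (ActivationBalancer and Whiten are omitted here, since
they only modify gradients during training and act as the identity on
activations; the sizes are made up for illustration):

    import torch
    import torch.nn as nn

    channels = 8
    in_proj = nn.Linear(channels, 2 * channels, bias=True)

    x = torch.randn(10, 2, channels)   # (seq_len, batch_size, channels)
    h = in_proj(x)
    h, s = h.chunk(2, dim=-1)          # value half and gate half
    h = h * torch.sigmoid(s)           # GLU-style gating of the value half

The previous version whitened the value half and balanced the gated output;
this version balances the pre-sigmoid gate instead and keeps a single Whiten
on the post-attention output.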