From 48d699c94bd57bf1461f3971465f091ba15c2db4 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 26 Nov 2022 18:42:03 +0800
Subject: [PATCH 1/2] Change for speed/memory

---
 .../ASR/pruned_transducer_stateless7/scaling.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index eb556b4e9..db341a1c9 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -396,6 +396,8 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
         In the backward pass it will include an auxiliary loss based on
         predicting x from matmul(y, weight).
         """
+        if torch.is_autocast_enabled():
+            x = x.to(torch.float16)
         ctx.save_for_backward(x, weight, alpha)
         ctx.aux_grad_scale = aux_grad_scale
         return torch.matmul(x, weight.t())
@@ -491,10 +493,14 @@ class LinearWithAuxLoss(nn.Module):
         aux_grad_scale = float(self.aux_grad_scale)
         if (not self.training or torch.jit.is_scripting() or aux_grad_scale == 0.0 or
             random.random() > float(self.prob)):
-            return torch.matmul(x, self.weight.t()) + self.bias
+            return torch.nn.functional.linear(x, self.weight, self.bias)
         else:
-            return LinearWithAuxLossFunction.apply(x, self.weight, self.alpha,
-                                                   aux_grad_scale) + self.bias
+            ans = LinearWithAuxLossFunction.apply(x, self.weight, self.alpha,
+                                                  aux_grad_scale)
+            if self.bias is not None:
+                ans += self.bias
+            return ans
+
 
 def ScaledLinear(*args,
                  initial_scale: float = 1.0,

From faed28ba6a2a0db5e2ec0567c69bdd577ec17e88 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 26 Nov 2022 18:59:15 +0800
Subject: [PATCH 2/2] Changes for debugging/stats.

---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 6a4f08f6d..f1f548ac0 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1434,6 +1434,7 @@ class NonlinAttentionModule(nn.Module):
             min_abs=0.2, max_abs=10.0,
             min_prob=0.05,
         )
+        self.pre_sigmoid = Identity()  # for diagnostics.
         self.sigmoid = nn.Sigmoid()
         self.activation = Identity()  # for diagnostics.
 
@@ -1475,6 +1476,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         v = self.whiten1(v)
 
         # GLU mechanism
+        s = self.pre_sigmoid(s)
         x = self.sigmoid(s) * v
 
         x = self.balancer(x)
@@ -1537,6 +1539,9 @@ class ConvolutionModule(nn.Module):
             max_abs=10.0, min_positive=0.05,
             max_positive=1.0
         )
+        self.pre_sigmoid = Identity()  # before sigmoid; for diagnostics.
+        self.sigmoid = nn.Sigmoid()
+
         self.depthwise_conv = nn.Conv1d(
             channels,
             channels,
@@ -1585,7 +1590,13 @@ class ConvolutionModule(nn.Module):
         x = self.in_proj(x)  # (time, batch, 2*channels)
 
         x = self.deriv_balancer1(x)
-        x = nn.functional.glu(x, dim=-1)  # (time, batch, channels)
+
+        x, s = x.chunk(2, dim=-1)
+        s = self.pre_sigmoid(s)
+        s = self.sigmoid(s)
+        x = x * s
+
+        # (time, batch, channels)
 
         # exchange the temporal dimension and the feature dimension
         x = x.permute(1, 2, 0)  # (#batch, channels, time).
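
Note on the no-aux-loss path in PATCH 1/2: torch.nn.functional.linear accepts bias=None, whereas the old expression torch.matmul(x, self.weight.t()) + self.bias raises a TypeError if the module was built without a bias, which is why the aux-loss branch now also guards the bias addition. A minimal sketch (shapes are illustrative, not taken from the recipe):

import torch

x = torch.randn(4, 8)
w = torch.randn(16, 8)

# F.linear handles a missing bias gracefully ...
y = torch.nn.functional.linear(x, w, None)
assert y.shape == (4, 16)

# ... while the pre-patch form would fail when bias is None:
#     torch.matmul(x, w.t()) + None   ->  TypeError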
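
On the "Change for speed/memory" commit itself: the new autocast branch stores the activation that forward() saves for backward in float16, roughly halving that tensor's memory when mixed precision is enabled. Below is a hypothetical, stripped-down autograd.Function sketching only that pattern; it is not the real LinearWithAuxLossFunction, which also carries the alpha / aux_grad_scale auxiliary-loss machinery.

import torch


class HalfSavedLinearFunction(torch.autograd.Function):
    """Toy y = x @ W^T that saves x for backward in float16 under autocast."""

    @staticmethod
    def forward(ctx, x, weight):
        ctx.x_dtype = x.dtype
        if torch.is_autocast_enabled():
            # Saved activations dominate training memory; keep them in half
            # precision when autocast is active anyway.
            x = x.to(torch.float16)
        ctx.save_for_backward(x, weight)
        return torch.matmul(x, weight.t().to(x.dtype))

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        # Compute in the weight's dtype, then cast gradients back so their
        # dtypes match the corresponding forward inputs.
        g = grad_output.to(weight.dtype)
        x = x.to(weight.dtype)
        grad_x = torch.matmul(g, weight).to(ctx.x_dtype)
        grad_weight = torch.matmul(
            g.reshape(-1, g.shape[-1]).t(), x.reshape(-1, x.shape[-1])
        )
        return grad_x, grad_weight

With torch.cuda.amp.autocast() around the forward call, the tensor stashed by save_for_backward is half precision; without autocast the function behaves like a plain bias-free linear layer.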
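
For PATCH 2/2: replacing nn.functional.glu with an explicit chunk/sigmoid in ConvolutionModule.forward does not change the computation, since GLU is defined as a * sigmoid(b) where a and b are the two halves of the input along the chosen dim. The split form only routes the pre-sigmoid half through the new Identity() module so its statistics can be observed. A quick sketch of the equivalence (tensor sizes are made up):

import torch

x = torch.randn(50, 4, 2 * 8)  # (time, batch, 2*channels), sizes illustrative

# Split-and-gate form used after the patch.
a, s = x.chunk(2, dim=-1)
y_split = a * torch.sigmoid(s)

# Original GLU form; the same split-and-gate under the hood.
y_glu = torch.nn.functional.glu(x, dim=-1)

assert torch.allclose(y_split, y_glu)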