From 91840faa97c32f3d2e11e9d198555c67eb35beb3 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 15 Oct 2022 15:27:05 +0800
Subject: [PATCH] Implement whitening of values in conformer.

---
 .../pruned_transducer_stateless7/conformer.py | 194 +++++++++---------
 1 file changed, 102 insertions(+), 92 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index a944597b0..7e18b7470 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -801,124 +801,126 @@ class RelPositionalEncoding(torch.nn.Module):
         return self.dropout(pos_emb)
 
 
-class EntropyPenaltyFunction(torch.autograd.Function):
+def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
+    if x.ndim == 2:
+        return x.diag()
+    else:
+        (batch, dim, dim) = x.shape
+        x = x.reshape(batch, dim * dim)
+        x = x[:, ::dim+1]
+        assert x.shape == (batch, dim)
+        return x
+
+class WhiteningPenaltyFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx,
-                attn_weights: Tensor,
-                num_heads: int,
-                entropy_limit: float,
+                x: Tensor,
+                whitening_limit: float,
                 grad_scale: float) -> Tensor:
-        ctx.save_for_backward(attn_weights)
-        ctx.num_heads = num_heads
-        ctx.entropy_limit = entropy_limit
+        ctx.save_for_backward(x)
+        ctx.whitening_limit = whitening_limit
         ctx.grad_scale = grad_scale
-        return attn_weights
+        return x
 
     @staticmethod
     def backward(ctx,
-                 attn_weights_grad: Tensor):
-        attn_weights, = ctx.saved_tensors
-        num_heads = ctx.num_heads
-        entropy_limit = ctx.entropy_limit
-        grad_scale = ctx.grad_scale
+                x_grad: Tensor):
+        x_orig, = ctx.saved_tensors
        with torch.enable_grad():
            with torch.cuda.amp.autocast(enabled=False):
-                attn_weights_orig = attn_weights.to(torch.float32).detach()
-                attn_weights_orig.requires_grad = True
-                bsz = attn_weights_orig.shape[0] // num_heads
-                seq_len = attn_weights_orig.shape[2]
-                attn_weights = attn_weights_orig.reshape(bsz, num_heads,
-                                                         seq_len, seq_len)
+                x_detached = x_orig.to(torch.float32).detach()
+                x_detached.requires_grad = True
+                assert x_detached.ndim >= 3
+                x = x_detached.reshape(-1, x_detached.shape[-2],
+                                       x_detached.shape[-1]).transpose(0, 1)
+                (num_groups, num_frames, channels_per_group) = x.shape
 
-                grad_norms = attn_weights_grad.detach().reshape(
-                    bsz, num_heads, seq_len * seq_len).norm(dim=(0,2))
+                # subtract the mean so we use the centered, not uncentered, covariance.
+                # My experience has been that when we "mess with the gradients" like this,
+                # it's better not to do anything that tries to move the mean around, because
+                # that can easily cause instability.
+                x = x - x.mean(dim=1, keepdim=True)
+
+                # x_covar: (num_groups, channels_per_group, channels_per_group)
+                x_covar = torch.matmul(x.transpose(1, 2), x)
+                # normalize x_covar so that its average diagonal element is 1.
+                x_covar = x_covar / (_diag(x_covar).mean() + 1.0e-20)
+                # x_covar_sq: (num_groups, channels_per_group, channels_per_group).
+                # if the normalized x_covar were just `num_groups` copies of the
+                # identity matrix, x_covar_sq would have the same value.  But
+                # in general, it will be larger than that.
+                x_covar_sq = torch.matmul(x_covar, x_covar)
+
+                metric = _diag(x_covar_sq).mean()
 
-                entropy = ((attn_weights + 1.0e-20).log() * attn_weights).sum(dim=-1)
-                # entropy: (bsz, num_heads, seq_len)
-                entropy = -entropy.mean(dim=(0,2))
-                # entropy: (num_heads,)
-                assert entropy.shape == (num_heads,)
-                excess_entropy = (entropy - entropy_limit).relu()
-                above_cutoff = (entropy > 0)  # tensor of shape (num_heads,)
-                small_grad_norm = (grad_norms < grad_norms.mean())
-                will_penalize = torch.logical_and(above_cutoff, small_grad_norm)
                 if random.random() < 0.005 or __name__ == "__main__":
-                    logging.info(f"entropy = {entropy}, entropy_limit={entropy_limit}, above_cutoff={above_cutoff}, small_grad_norm={small_grad_norm}, will_penalize={will_penalize}")
-                will_penalize_sum = will_penalize.to(torch.float32).sum().item()
-                if will_penalize_sum == 0:
-                    # grad would be 0.  I'm guessing that checking this, and
-                    # incurring a CUDA sync, may save time relative to doing the
-                    # backprop of the entropy, but I'm not sure.
-                    return attn_weights_grad, None, None, None
-                # Treat `excess_entropy` as a loss, to be minimized.
-                excess_entropy.backward(gradient=will_penalize.to(torch.float32))
-                entropy_grad = attn_weights_orig.grad
-                scale = ((grad_scale * will_penalize_sum / num_heads) *
-                         (attn_weights_grad.to(torch.float32).norm() /
-                          (entropy_grad.norm() + 1.0e-20)))
-                entropy_grad = entropy_grad * scale
-                return attn_weights_grad + entropy_grad.to(attn_weights_grad.dtype), None, None, None
+                    logging.info(f"Whitening: num_groups={num_groups}, channels_per_group={channels_per_group}, "
+                                 f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}")
+
+                (metric - ctx.whitening_limit).relu().backward()
+                penalty_grad = x_detached.grad
+                scale = ctx.grad_scale * (x.to(torch.float32).norm() /
+                                          (penalty_grad.norm() + 1.0e-20))
+                penalty_grad = penalty_grad * scale
+                return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
 
-
-class EntropyPenalty(nn.Module):
+class Whiten(nn.Module):
     def __init__(
             self,
-            num_heads: float,
-            entropy_delta: float,
+            whitening_limit: float,
             prob: float,
             grad_scale: float):
         """
         Args:
-          num_heads: the number of attention heads in the self-attention module that
-            this is attached to.
-          entropy_delta: the delta from the maximum entropy, that we aim to
-            decrease the entropy to if it is above.  So the maximum entropy
-            should be max(log(seq_len) - entropy_cutoff, 0.5 * log(seq_len));
-            the second term is to make sure the limit never becomes tiny or
-            negative in the case of short sequences.
-          prob: the probability with which we apply this object.
+          num_groups: the number of groups to divide the input into before
+            whitening it.  We will attempt to make the feature covariance
+            within each group, after mean subtraction, as "white" as possible
+            while having the same trace across all groups.
+          whitening_limit: a value greater than 1.0, that dictates how much
+            freedom we have to violate the constraints.  1.0 would mean perfectly
+            white, with exactly the same trace across groups; larger values
+            give more freedom.  E.g. 2.0.
+          prob: the probability with which we apply this object (also affects
+            grad scale).  e.g. 0.25
          grad_scale: determines the scale on the gradient term from this object,
            relative to the rest of the gradient on the attention weights;
-            will be divided by `prob`.
+            will be divided by `prob`.  e.g. 0.005
         """
-        super(EntropyPenalty, self).__init__()
-        self.num_heads = num_heads
-        self.entropy_delta = entropy_delta
+        super(Whiten, self).__init__()
+        assert whitening_limit >= 1
+        assert 0 < prob <= 1
+        assert grad_scale >= 0
+        self.whitening_limit = whitening_limit
         self.prob = prob
         self.grad_scale = grad_scale
 
     def forward(self,
-                attn_weights: Tensor) -> Tensor:
+                x: Tensor) -> Tensor:
         """
-        In the forward pass, this function just returns the attention weights.
+        In the forward pass, this function just returns the input unmodified.
         In the backward pass, it will modify the gradients to ensure that the
-        entropy of the attention heads is not too large.  (We have noticed
-        that too-large/almost-maximal entropy in the attention distribution
-        is associated with heads that are not doing anything useful.
+        distribution in each group has close to (lambda times I) as the covariance
+        after mean subtraction, with the same lambda across groups.
+        For whitening_limit > 1, there will be more freedom to violate this
+        constraint.
 
         Args:
-          attn_weights: the attention weights, after the log, with shape
-            (batch_size * num_heads, seq_len, seq_len), satisfying:
-            attn_weights.sum(dim=-1) == 1.
+          x: the input of shape (*, num_groups, channels_per_group)
+
        Returns:
-           the attn_weights, without any change.  You should make sure
-           you use the returned attention weights, or the graph will be freed
-           and nothing will happen in backprop.
+           x, unmodified.  You should make sure
+           you use the returned value, or the graph will be freed
+           and nothing will happen in backprop.
         """
-        if not attn_weights.requires_grad or random.random() > self.prob:
-            return attn_weights
+        if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
+            return x
         else:
-            seq_len = attn_weights.shape[2]
-            max_entropy = math.log(seq_len)
-            entropy_limit = max(max_entropy - self.entropy_delta,
-                                0.5 * max_entropy)
-            return EntropyPenaltyFunction.apply(attn_weights,
-                                                self.num_heads,
-                                                entropy_limit,
-                                                self.grad_scale / self.prob)
+            return WhiteningPenaltyFunction.apply(x,
+                                                  self.whitening_limit,
+                                                  self.grad_scale / self.prob)
 
 
@@ -955,6 +957,13 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim//2 must be divisible by num_heads"
 
         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
+
+        # self.whiten is applied on the values in forward()
+        self.whiten = Whiten(whitening_limit=2.0,
+                             prob=1.0 if __name__ == "__main__" else 0.1,
+                             grad_scale=0.0025)
+
+
         self.in_balancer = ActivationBalancer(3 * embed_dim // 2, channel_dim=-1,
                                               max_abs=5.0)
         self.in_max_eig = MaxEig(3 * embed_dim // 2,
@@ -966,17 +975,15 @@
         self.in_proj2 = nn.Linear(embed_dim, embed_dim // 2, bias=False)
         self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
                                       initial_scale=0.05)
+        # self.whiten2 is applied on the values in forward2()
+        self.whiten2 = Whiten(whitening_limit=2.0,
+                              prob=1.0 if __name__ == "__main__" else 0.1,
+                              grad_scale=0.0025)
+
         # linear transformation for positional encoding (projects to a scalar per head,
         # which will be added to the score).
         self.linear_pos = ScaledLinear(embed_dim, num_heads,
                                        initial_scale=0.05)
-        self.entropy_penalty = EntropyPenalty(num_heads,
-                                              entropy_delta=1.5,
-                                              prob=1.0 if __name__ == "__main__" else 0.2,
-                                              grad_scale=0.01)
-
-        self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
-        self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))
 
         # linear transformation for positional encoding.
         self.linear_pos = nn.Linear(embed_dim, num_heads, bias=False)
@@ -1196,7 +1203,9 @@ class RelPositionMultiheadAttention(nn.Module):
 
         q = (q * scaling).contiguous().view(seq_len, bsz, num_heads, head_dim)
         k = k.contiguous().view(-1, bsz, num_heads, head_dim)
-        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+        v = v.contiguous().view(-1, bsz, num_heads, head_dim)
+        v = self.whiten(v)  # does nothing in the forward pass.
+        v = v.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
 
 
         if key_padding_mask is not None:
@@ -1278,14 +1287,15 @@
        Returns:
           output of the same shape as x, i.e. (seq_len, batch_size, embed_dim)
        """
-        attn_weights = self.entropy_penalty(attn_weights)
-
        num_heads = self.num_heads
        (seq_len, bsz, embed_dim) = x.shape
        head_dim = embed_dim // (num_heads * 2)
 
        # v: (tgt_len, bsz, embed_dim // 2)
        v = self.in_proj2(x)
+        v = v.contiguous().view(-1, bsz, num_heads, head_dim)
+        v = self.whiten2(v)  # does nothing in the forward pass.
        v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
+        # now v: (bsz * num_heads, seq_len, head_dim)
 
        attn_output = torch.bmm(attn_weights, v)
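
For reference, below is a minimal standalone sketch (not part of the patch; the helper name whitening_metric and the toy tensor shapes are invented for illustration) of the quantity that WhiteningPenaltyFunction.backward penalizes. After the per-group centered covariances are jointly normalized so their average diagonal element is 1, the mean diagonal of the squared covariance is at least 1.0, with equality exactly when every group's covariance is the same multiple of the identity; the penalty only produces a gradient once this metric exceeds whitening_limit (2.0 in the patch). In the patch, v has shape (seq_len, bsz, num_heads, head_dim) when self.whiten(v) is applied, so the attention heads act as the groups and head_dim as the channels per group.

# Standalone sketch, assuming only PyTorch: re-implements the whitening metric
# from WhiteningPenaltyFunction.backward() for inspection.  The function name
# and test shapes below are illustrative, not part of the patch.
import torch

def whitening_metric(x: torch.Tensor) -> torch.Tensor:
    # x: (num_groups, num_frames, channels_per_group)
    x = x - x.mean(dim=1, keepdim=True)           # centered covariance, as in the patch
    x_covar = torch.matmul(x.transpose(1, 2), x)  # (num_groups, C, C)
    # normalize so the average diagonal element, over all groups, is 1.
    x_covar = x_covar / (x_covar.diagonal(dim1=1, dim2=2).mean() + 1.0e-20)
    x_covar_sq = torch.matmul(x_covar, x_covar)
    return x_covar_sq.diagonal(dim1=1, dim2=2).mean()

if __name__ == "__main__":
    num_groups, num_frames, channels = 8, 1000, 64
    white = torch.randn(num_groups, num_frames, channels)
    print(whitening_metric(white))       # close to 1.0: below the 2.0 limit, no penalty

    # strongly correlated features: every channel is a scaled copy of one latent signal.
    latent = torch.randn(num_groups, num_frames, 1)
    correlated = (latent * torch.randn(num_groups, 1, channels)
                  + 0.1 * torch.randn(num_groups, num_frames, channels))
    print(whitening_metric(correlated))  # far above 2.0: the penalty gradient activates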