From fc728f2738080d1e93ab41742f9b0dcfaba49d10 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 15 Oct 2022 23:20:18 +0800 Subject: [PATCH] Reorganize Whiten() code; configs are not the same as before. Also remove MaxEig for self_attn module --- .../pruned_transducer_stateless7/conformer.py | 162 ++------------- .../pruned_transducer_stateless7/scaling.py | 196 ++++++++++++++++++ 2 files changed, 215 insertions(+), 143 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 0093b644c..77bbf4625 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -31,6 +31,8 @@ from scaling import ( DoubleSwish, ScaledConv1d, ScaledLinear, # not as in other dirs.. just scales down initial parameter values. + Whiten, + _diag, ) from torch import Tensor, nn @@ -801,129 +803,6 @@ class RelPositionalEncoding(torch.nn.Module): return self.dropout(pos_emb) -def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. - if x.ndim == 2: - return x.diag() - else: - (batch, dim, dim) = x.shape - x = x.reshape(batch, dim * dim) - x = x[:, ::dim+1] - assert x.shape == (batch, dim) - return x - -class WhiteningPenaltyFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, - x: Tensor, - whitening_limit: float, - grad_scale: float) -> Tensor: - ctx.save_for_backward(x) - ctx.whitening_limit = whitening_limit - ctx.grad_scale = grad_scale - return x - - @staticmethod - def backward(ctx, - x_grad: Tensor): - x_orig, = ctx.saved_tensors - with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): - x_detached = x_orig.to(torch.float32).detach() - x_detached.requires_grad = True - assert x_detached.ndim >= 3 - x = x_detached.reshape(-1, x_detached.shape[-2], - x_detached.shape[-1]).transpose(0, 1) - (num_groups, num_frames, channels_per_group) = x.shape - - # subtract the mean so we use the centered, not uncentered, covariance. - # My experience has been that when we "mess with the gradients" like this, - # it's better not do anything that tries to move the mean around, because - # that can easily cause instability. - x = x - x.mean(dim=1, keepdim=True) - - # x_covar: (num_groups, channels_per_group, channels_per_group) - x_covar = torch.matmul(x.transpose(1, 2), x) - # normalize x_covar so that its average diagonal element is 1. - x_covar = x_covar / (_diag(x_covar).mean() + 1.0e-20) - # x_covar_sq: (num_groups, channels_per_group, channels_per_group). - # if the normalized x_covar were just `num_groups` copies of the - # identity matrix, x_covar_sq will have the same value. But - # in general, it will be larger than that. - x_covar_sq = torch.matmul(x_covar, x_covar) - - metric = _diag(x_covar_sq).mean() - - if random.random() < 0.005 or __name__ == "__main__": - logging.info(f"Whitening: num_groups={num_groups}, channels_per_group={channels_per_group}, " - f"metric={metric.item():.2f} vs. 
limit={ctx.whitening_limit}") - - (metric - ctx.whitening_limit).relu().backward() - penalty_grad = x_detached.grad - scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() / - (penalty_grad.norm() + 1.0e-20)) - penalty_grad = penalty_grad * scale - return x_grad + penalty_grad.to(x_grad.dtype), None, None, None - - - - -class Whiten(nn.Module): - def __init__( - self, - whitening_limit: float, - prob: float, - grad_scale: float): - """ - Args: - num_groups: the number of groups to divide the input into before - whitening it. We will attempt to make the feature covariance - within each group, after mean subtraction, as "white" as possible - while having the same trace across all groups. - whitening_limit: a value greater than 1.0, that dictates how much - freedom we have to violate the constraints. 1.0 would mean perfectly - white, with exactly the same trace across groups; larger values - give more freedom. E.g. 2.0. - prob: the probability with which we apply this object (also affects - grad scale). e.g. 0.25 - grad_scale: determines the scale on the gradient term from this object, - relative to the rest of the gradient on the attention weights; - will be divided by `prob`. e.g. 0.005 - """ - super(Whiten, self).__init__() - assert whitening_limit >= 1 - assert 0 < prob <= 1 - assert grad_scale >= 0 - self.whitening_limit = whitening_limit - self.prob = prob - self.grad_scale = grad_scale - - def forward(self, - x: Tensor) -> Tensor: - """ - In the forward pass, this function just returns the input unmodified. - In the backward pass, it will modify the gradients to ensure that the - distribution in each group has close to (lambda times I) as the covariance - after mean subtraction, with the same lambda across groups. - For whitening_limit > 1, there will be more freedom to violate this - constraint. - - Args: - x: the input of shape (*, num_groups, channels_per_group) - - Returns: - x, unmodified. You should make sure - you use the returned value, or the graph will be freed - and nothing will happen in backprop. 
- """ - if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: - return x - else: - return WhiteningPenaltyFunction.apply(x, - self.whitening_limit, - self.grad_scale / self.prob) - - - class RelPositionMultiheadAttention(nn.Module): r"""Multi-Head Attention layer with relative position encoding @@ -958,20 +837,20 @@ class RelPositionMultiheadAttention(nn.Module): self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True) - # self.whiten is applied on the values in forward() - self.whiten_values = Whiten(whitening_limit=1.1, - prob=1.0 if __name__ == "__main__" else 0.1, - grad_scale=0.0025) + # self.whiten_values is applied on the values in forward() + self.whiten_values = Whiten(num_groups=num_heads, + whitening_limit=1.1, + prob=(0.025, 0.25), + grad_scale=0.025) # self.whiten_keys is applied on the keys in forward() - self.whiten_keys = Whiten(whitening_limit=1.1, - prob=1.0 if __name__ == "__main__" else 0.1, - grad_scale=0.0025) + self.whiten_keys = Whiten(num_groups=num_heads, + whitening_limit=1.1, + prob=(0.025, 0.25), + grad_scale=0.025) self.in_balancer = ActivationBalancer(3 * embed_dim // 2, channel_dim=-1, max_abs=5.0) - self.in_max_eig = MaxEig(3 * embed_dim // 2, - channel_dim=-1) self.out_proj = ScaledLinear( embed_dim // 2, embed_dim, bias=True, initial_scale=0.05 ) @@ -980,10 +859,10 @@ class RelPositionMultiheadAttention(nn.Module): self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True, initial_scale=0.05) # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten(whitening_limit=1.1, - prob=1.0 if __name__ == "__main__" else 0.1, - grad_scale=0.0025) - + self.whiten_values2 = Whiten(num_groups=num_heads, + whitening_limit=1.1, + prob=(0.025, 0.25), + grad_scale=0.025) # linear transformation for positional encoding (projects to a scalar per head, # which will be added to the score). @@ -1037,7 +916,7 @@ class RelPositionMultiheadAttention(nn.Module): and S is the sequence length. """ x, weights = self.multi_head_attention_forward( - self.in_max_eig(self.in_balancer(self.in_proj(x))), + self.in_balancer(self.in_proj(x)), self.linear_pos(pos_emb), self.embed_dim, self.num_heads, @@ -1155,6 +1034,8 @@ class RelPositionMultiheadAttention(nn.Module): # self-attention q, k, v = x.chunk(3, dim=-1) + k = self.whiten_keys(k) # does nothing in the forward pass. + v = self.whiten_values(v) # does nothing in the forward pass. if attn_mask is not None: assert ( @@ -1207,11 +1088,7 @@ class RelPositionMultiheadAttention(nn.Module): q = (q * scaling).contiguous().view(seq_len, bsz, num_heads, head_dim) k = k.contiguous().view(-1, bsz, num_heads, head_dim) - k = self.whiten_keys(k) # does nothing in the forward pass. - v = v.contiguous().view(-1, bsz, num_heads, head_dim) - v = self.whiten_values(v) # does nothing in the forward pass. - v = v.view(-1, bsz * num_heads, head_dim).transpose(0, 1) - + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz, "{} == {}".format( @@ -1297,7 +1174,6 @@ class RelPositionMultiheadAttention(nn.Module): head_dim = embed_dim // (num_heads * 2) # v: (tgt_len, bsz, embed_dim // 2) v = self.in_proj2(x) - v = v.contiguous().view(-1, bsz, num_heads, head_dim) v = self.whiten_values2(v) # does nothing in the forward pass. 
         v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index cdbd781f1..fe8867291 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -421,6 +421,172 @@ class ActivationBalancer(torch.nn.Module):
         return x
 
 
+
+def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
+    if x.ndim == 2:
+        return x.diag()
+    else:
+        (batch, dim, dim) = x.shape
+        x = x.reshape(batch, dim * dim)
+        x = x[:, ::dim+1]
+        assert x.shape == (batch, dim)
+        return x
+
+
+def _whitening_metric(x: Tensor,
+                      num_groups: int):
+    """
+    Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues
+    of the centered feature covariance are the same within each group's covariance matrix
+    and also between groups.
+    Args:
+      x: a Tensor of shape (*, num_channels)
+      num_groups: the number of groups of channels, a number >= 1 that divides num_channels
+    Returns:
+      a scalar Tensor that will be 1.0 if the data is "perfectly white" and
+      greater than 1.0 otherwise.
+    """
+    assert x.dtype != torch.float16
+    x = x.reshape(-1, x.shape[-1])
+    (num_frames, num_channels) = x.shape
+    assert num_channels % num_groups == 0
+    channels_per_group = num_channels // num_groups
+    x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
+    # x now has shape (num_groups, num_frames, channels_per_group)
+    # subtract the mean so we use the centered, not uncentered, covariance.
+    # My experience has been that when we "mess with the gradients" like this,
+    # it's better not to do anything that tries to move the mean around, because
+    # that can easily cause instability.
+    x = x - x.mean(dim=1, keepdim=True)
+    # x_covar: (num_groups, channels_per_group, channels_per_group)
+    x_covar = torch.matmul(x.transpose(1, 2), x)
+    x_covar_mean_diag = _diag(x_covar).mean()
+    # the following expression is what we'd get if we took the matrix product of
+    # each covariance with itself and measured the mean of its diagonal, i.e.
+    # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
+    x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group)
+    # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
+    metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20)
+    return metric
+
+
+class WhiteningPenaltyFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx,
+                x: Tensor,
+                num_groups: int,
+                whitening_limit: float,
+                grad_scale: float) -> Tensor:
+        ctx.save_for_backward(x)
+        ctx.num_groups = num_groups
+        ctx.whitening_limit = whitening_limit
+        ctx.grad_scale = grad_scale
+        return x
+
+    @staticmethod
+    def backward(ctx,
+                 x_grad: Tensor):
+        x_orig, = ctx.saved_tensors
+        with torch.enable_grad():
+            with torch.cuda.amp.autocast(enabled=False):
+                x_detached = x_orig.to(torch.float32).detach()
+                x_detached.requires_grad = True
+
+                metric = _whitening_metric(x_detached, ctx.num_groups)
+
+                if random.random() < 0.005 or __name__ == "__main__":
+                    logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
+                                 f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}")
+
+                (metric - ctx.whitening_limit).relu().backward()
+                penalty_grad = x_detached.grad
+                scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() /
+                                          (penalty_grad.norm() + 1.0e-20))
+                penalty_grad = penalty_grad * scale
+        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
+
+
+class Whiten(nn.Module):
+    def __init__(
+            self,
+            num_groups: int,
+            whitening_limit: float,
+            prob: Union[float, Tuple[float, float]],
+            grad_scale: float):
+        """
+        Args:
+          num_groups: the number of groups to divide the channel dim into before
+            whitening.  We will attempt to make the feature covariance
+            within each group, after mean subtraction, as "white" as possible,
+            while having the same trace across all groups.
+          whitening_limit: a value greater than 1.0, that dictates how much
+            freedom we have to violate the constraints.  1.0 would mean perfectly
+            white, with exactly the same trace across groups; larger values
+            give more freedom.  E.g. 2.0.
+          prob: the probability with which we apply the gradient modification
+            (also affects the grad scale).  May be supplied as a float,
+            or as a pair (min_prob, max_prob).
+          grad_scale: determines the scale on the gradient term from this object,
+            relative to the rest of the gradient on the attention weights.
+            E.g. 0.02 (you may want to use smaller values than this if prob is large).
+        """
+        super(Whiten, self).__init__()
+        assert num_groups >= 1
+        assert whitening_limit >= 1
+        assert grad_scale >= 0
+        self.num_groups = num_groups
+        self.whitening_limit = whitening_limit
+        if isinstance(prob, float):
+            assert 0 < prob <= 1
+            self.prob = prob
+        else:
+            (self.min_prob, self.max_prob) = prob
+            assert 0 < self.min_prob < self.max_prob <= 1
+            self.prob = self.max_prob
+
+        self.grad_scale = grad_scale
+
+    def forward(self,
+                x: Tensor) -> Tensor:
+        """
+        In the forward pass, this function just returns the input unmodified.
+        In the backward pass, it will modify the gradients so that the covariance
+        of each group, after mean subtraction, is encouraged to be close to
+        (lambda times I), with the same lambda across groups.
+        For whitening_limit > 1, there will be more freedom to violate this
+        constraint.
+
+        Args:
+          x: the input of shape (*, num_channels)
+
+        Returns:
+          x, unmodified.  You should make sure you use the returned value, or
+          the graph will be freed and nothing will happen in backprop.
+        """
+        if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
+            return x
+        else:
+            if hasattr(self, 'min_prob') and random.random() < 0.25:
+                # occasionally switch between min_prob and max_prob, based on whether
+                # we are above or below the threshold.
+                if _whitening_metric(x, self.num_groups) > self.whitening_limit:
+                    # the metric exceeds the limit, so the penalty would actually change the grad.
+ self.prob = self.max_prob + else: + self.prob = self.min_prob + + return WhiteningPenaltyFunction.apply(x, + self.num_groups, + self.whitening_limit, + self.grad_scale) + + + + + class MaxEig(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to discourage @@ -632,6 +798,35 @@ def _test_max_eig(): assert not torch.allclose(x.grad, y_grad) +def _test_whiten(): + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"_test_whiten(): proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + num_channels = 128 + m = Whiten(1, # num_groups + 5.0, # whitening_limit, + prob=1.0, + grad_scale=0.1) # grad_scale + + + for _ in range(4): + y = m(x) + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) + + def _test_activation_balancer_sign(): probs = torch.arange(0, 1, 0.01) @@ -714,6 +909,7 @@ if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) torch.set_num_interop_threads(1) + _test_whiten() _test_max_eig() _test_activation_balancer_sign() _test_activation_balancer_magnitude()
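Usage sketch (not part of the patch): a minimal example of the reorganized Whiten API,
assuming scaling.py from this patch is on the import path; the tensor shape below mirrors
the keys/values that RelPositionMultiheadAttention passes in (seq_len, batch,
num_heads * head_dim) and is otherwise illustrative, as are the hyperparameter values,
which are copied from the constructor calls in conformer.py above.

    import torch
    from scaling import Whiten, _whitening_metric

    num_heads, head_dim = 8, 32
    x = torch.randn(50, 4, num_heads * head_dim, requires_grad=True)

    # prob may be a float or a (min_prob, max_prob) pair; with a pair, the module
    # starts at max_prob and occasionally switches between the two depending on
    # whether the whitening metric currently exceeds whitening_limit.
    whiten = Whiten(num_groups=num_heads,
                    whitening_limit=1.1,
                    prob=(0.025, 0.25),
                    grad_scale=0.025)

    y = whiten(x)       # identity in the forward pass
    y.sum().backward()  # with probability `prob`, x.grad may also carry the whitening penalty

    # 1.0 means "perfectly white"; larger values mean less white.
    print(_whitening_metric(x.detach().to(torch.float32), num_groups=num_heads))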