From 0f567e27a5a9d53450e4c512954bbb7f6ab49b49 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 18 Sep 2022 21:22:01 +0800 Subject: [PATCH 1/7] Add max_var_per_eig in self-attn --- .../pruned_transducer_stateless7/conformer.py | 28 ++- .../pruned_transducer_stateless7/scaling.py | 233 ++++++++++++++++++ 2 files changed, 252 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 57302b0cd..e98ff46ee 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -173,7 +173,8 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), - ActivationBalancer(channel_dim=-1, max_abs=10.0), + ActivationBalancer(dim_feedforward, + channel_dim=-1, max_abs=10.0), DoubleSwish(), nn.Dropout(dropout), ScaledLinear(dim_feedforward, d_model, @@ -182,7 +183,8 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), - ActivationBalancer(channel_dim=-1, max_abs=10.0), + ActivationBalancer(dim_feedforward, + channel_dim=-1, max_abs=10.0), DoubleSwish(), nn.Dropout(dropout), ScaledLinear(dim_feedforward, d_model, @@ -196,7 +198,7 @@ class ConformerEncoderLayer(nn.Module): # try to ensure the output is close to zero-mean (or at least, zero-median). self.balancer = ActivationBalancer( - channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + d_model, channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 ) self.dropout = nn.Dropout(dropout) @@ -464,8 +466,11 @@ class RelPositionMultiheadAttention(nn.Module): ), "embed_dim must be divisible by num_heads" self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) - self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0) - self.proj_balancer = ActivationBalancer(channel_dim=-1, max_abs=10.0, + self.in_balancer = ActivationBalancer(3 * embed_dim, + channel_dim=-1, max_abs=5.0, + max_var_per_eig=0.1) + self.proj_balancer = ActivationBalancer(embed_dim, + channel_dim=-1, max_abs=10.0, min_positive=0.0, max_positive=1.0) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.5 @@ -901,6 +906,7 @@ class ConvolutionModule(nn.Module): # it will be in a better position to start learning something, i.e. to latch onto # the correct range. 
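        # Note on the signature change in this patch: ActivationBalancer now takes
        # the channel count as its first positional argument.  For this balancer it
        # is 2 * channels, because it acts on the output of the first pointwise
        # convolution, which doubles the channel count ahead of the GLU.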
self.deriv_balancer1 = ActivationBalancer( + 2 * channels, channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 ) @@ -915,7 +921,7 @@ class ConvolutionModule(nn.Module): ) self.deriv_balancer2 = ActivationBalancer( - channel_dim=1, min_positive=0.05, max_positive=1.0 + channels, channel_dim=1, min_positive=0.05, max_positive=1.0 ) self.activation = DoubleSwish() @@ -1001,7 +1007,8 @@ class Conv2dSubsampling(nn.Module): kernel_size=3, padding=1, ), - ActivationBalancer(channel_dim=1), + ActivationBalancer(layer1_channels, + channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer1_channels, @@ -1009,7 +1016,8 @@ class Conv2dSubsampling(nn.Module): kernel_size=3, stride=2, ), - ActivationBalancer(channel_dim=1), + ActivationBalancer(layer2_channels, + channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer2_channels, @@ -1017,7 +1025,8 @@ class Conv2dSubsampling(nn.Module): kernel_size=3, stride=2, ), - ActivationBalancer(channel_dim=1), + ActivationBalancer(layer3_channels, + channel_dim=1), DoubleSwish(), ) out_height = (((in_channels - 1) // 2 - 1) // 2) @@ -1028,6 +1037,7 @@ class Conv2dSubsampling(nn.Module): self.out_norm = BasicNorm(out_channels, learn_eps=False) # constrain median of output to be close to zero. self.out_balancer = ActivationBalancer( + out_channels, channel_dim=-1, min_positive=0.45, max_positive=0.55 ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index e93f41718..601426318 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -114,6 +114,173 @@ class ActivationBalancerFunction(torch.autograd.Function): return x_grad - neg_delta_grad, None, None, None, None, None, None +def find_direction_coeffs(x: Tensor, + prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """ + Figure out (an approximation to) the proportion of the variance of a set of + feature vectors that can be attributed to the top eigen-direction. + Args: + x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. + prev_direction: a Tensor of shape (num_channels,), that is our previous estimate + of the top eigen-direction, or a random direction if this is the first + iteration. Does not have to be normalized, but should be nonzero. + + Returns: (cur_direction, coeffs), where: + cur_direction: a Tensor of shape (num_channels,) that is the current + estimate of the top eigen-direction. + coeffs: a Tensor of shape (num_frames, 1) that minimizes, or + approximately minimizes, (x - coeffs * cur_direction).norm() + """ + (num_frames, num_channels) = x.shape + assert num_channels > 1 and num_frames > 1 + + assert prev_direction.shape == (num_channels,) + + # `coeffs` are the coefficients of `prev_direction` in x. + # actually represent the coeffs up to a constant positive factor. + coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 + + cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20) + + return cur_direction, coeffs + + +def get_max_eig_proportion(x: Tensor, + prev_direction: Tensor, + subtract_mean: bool) -> Tuple[Tensor, Tensor]: + """ + Figure out (an approximation to) the proportion of the variance of a set of + feature vectors that can be attributed to the top eigen-direction. + Args: + x: a Tensor of shape (*, num_channels). There must be more than one frame, + i.e. x.numel() // num_channels > 1. 
+ prev_direction: a Tensor of shape (num_channels,), that is our previous estimate + of the top eigen-direction, or a random direction if this is the first + iteration. Expected to be without gradient. Does not have to be + normalized. + subtract_mean: if True, we will first subtract the mean of x, over the + frames. Suggest to make this true in most circumstances. + + Returns: (cur_direction, max_proportion), where: + cur_direction: a Tensor of shape (num_channels,) that is the current + estimate of the top eigen-direction. Detached / not intended to be + differentiable. + proportion: a scalar Tensor containing the proportion of the variance + of the input that is in direction `cur_direction`. This is with + gradient, that can be propagated back to x. + """ + num_channels = x.shape[-1] + assert prev_direction.shape == (num_channels,) + x = x.reshape(-1, num_channels) + if subtract_mean: + x = x - x.mean(dim=0) + + with torch.no_grad(): + cur_norm = prev_direction.norm() + + prev_direction = prev_direction / cur_norm + is_ok = (cur_norm / cur_norm == 1.0) + # if there was a problem like NaN or inf, restart. this should be very rare. + prev_direction = torch.where(is_ok.unsqueeze(-1).expand(prev_direction.shape), + prev_direction, + torch.randn_like(prev_direction) * (num_channels ** -0.5)) + + # `coeffs` are the coefficients of `prev_direction` in x. + coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + + x_norm = x.norm() + x_coeffs1_norm = (x - coeffs * prev_direction).norm() + + with torch.no_grad(): + cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20) + + x_coeffs2_norm = (x - coeffs * cur_direction).norm() + + # for the returned direction interpolate with prev_direction so that + # even if x == 0, we get a nonzero new direction. + ans_direction = 0.5 * (prev_direction + cur_direction) + + x_sumsq = (x**2).sum() + 1.0e-20 + x_remaining_sumsq = ((x - coeffs * cur_direction) ** 2).sum() + 1.0e-20 + + proportion = (x - x_remaining_sumsq) / x_sumsq + + return (ans_direction, proportion) + + print(f"x_norm={x_norm}, x_coeffs1_norm={x_coeffs1_norm}, x_coeffs2_norm={x_coeffs2_norm}") + + + + +class MaxEigLimiterFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + direction: Tensor, + channel_dim: int, + prob: float, + subtract_mean: bool, + max_variance_proportion: float, + grad_scale: float) -> Tuple[Tensor, Tensor]: + if random.random() > prob: + return x, direction + eps = 1.0e-20 + num_channels = x.shape[channel_dim] + assert max_variance_proportion > 1.0 / num_channels + orig_x = x + x = x.transpose(channel_dim, -1).reshape(-1, num_channels) + if subtract_mean: + x = x - x.mean(dim=0) + new_direction, coeffs = find_direction_coeffs(x, direction) + x_var = (x**2).sum() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual**2).sum() + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. + variance_proportion = (x_var - x_residual_var) / x_var + + ans_direction = direction + new_direction # ensure nonzero even if x == 0 + ans_direction = ans_direction / ans_direction.norm() + + logging.info(f"variance_proportion = {variance_proportion.item()}") + + # Caution: this causes a CUDA sync, which is not ideal. 
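        # The Python-level comparison below calls bool() on a zero-dim tensor that
        # may live on the GPU, which is what forces the host to wait for the device.
        # Tensors for the backward pass are saved only when the top eigen-direction
        # accounts for at least `max_variance_proportion` of the variance; otherwise
        # backward() passes x_grad through unchanged.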
+ if variance_proportion >= max_variance_proportion: + ctx.channel_dim = channel_dim + ctx.subtract_mean = subtract_mean + ctx.grad_scale = grad_scale + ctx.save_for_backward(orig_x.detach(), + coeffs.detach(), + new_direction.detach()) + + return orig_x, ans_direction + + @staticmethod + def backward(ctx, x_grad, *args): + # the *args is all the other derivs, which should be None or zero. + if not hasattr(ctx, 'channel_dim'): + # the top eig's proportion of the variance was below the threshold. + return x_grad, None, None, None, None, None, None + + with torch.enable_grad(): + (x_orig, coeffs, new_direction) = ctx.saved_tensors + x_orig.requires_grad = True + num_channels = x_orig.shape[ctx.channel_dim] + x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) + new_direction.requires_grad = False + if ctx.subtract_mean: + x = x - x.mean(dim=0) + x_var = (x**2).sum() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual**2).sum() + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. This is to be minimized. + variance_proportion = (x_var - x_residual_var) / x_var + variance_proportion.backward() + x_orig_grad = x_orig.grad + x_extra_grad = x_orig.grad * x_orig.grad.norm() / (x_orig_grad.norm() + 1.0e-20) + return x_grad + x_extra_grad, None, None, None, None, None, None class BasicNorm(torch.nn.Module): @@ -236,6 +403,7 @@ class ActivationBalancer(torch.nn.Module): Args: + num_channels: the number of channels channel_dim: the dimension/axis corresponding to the channel, e.g. -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. min_positive: the minimum, per channel, of the proportion of the time @@ -252,29 +420,56 @@ class ActivationBalancer(torch.nn.Module): max_abs: the maximum average-absolute-value difference from the mean value per channel, which we allow, before we start to modify the derivatives to prevent this. + max_var_per_eig: the maximum proportion of the variance of the + features/channels, after mean subtraction, that can come from + any given eigenvalue. 
""" def __init__( self, + num_channels: int, channel_dim: int, min_positive: float = 0.05, max_positive: float = 0.95, max_factor: float = 0.01, min_abs: float = 0.2, max_abs: float = 100.0, + max_var_per_eig: float = 0.0, ): super(ActivationBalancer, self).__init__() + self.num_channels = num_channels self.channel_dim = channel_dim self.min_positive = min_positive self.max_positive = max_positive self.max_factor = max_factor self.min_abs = min_abs self.max_abs = max_abs + assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels + self.max_var_per_eig = max_var_per_eig + if max_var_per_eig > 0.0: + with torch.no_grad(): + direction = torch.randn(num_channels) + direction = direction / direction.norm() + self.register_buffer('max_eig_direction', direction) + else: + self.max_eig_direction = None + def forward(self, x: Tensor) -> Tensor: if torch.jit.is_scripting(): return x + if self.max_var_per_eig > 0: + x, new_direction = MaxEigLimiterFunction.apply( + x, self.max_eig_direction, + self.channel_dim, + 0.1, # prob + True, # subtract_mean + self.max_var_per_eig, + self.max_factor, + ) + self.max_eig_direction[:] = new_direction + return ActivationBalancerFunction.apply( x, self.channel_dim, @@ -326,6 +521,35 @@ class DoubleSwish(torch.nn.Module): return DoubleSwishFunction.apply(x) +def _test_max_eig_limiter(): + + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + y, new_direction = MaxEigLimiterFunction.apply(x, direction, + 1, # channel_dim + 1.0, # prob + True, # subtract_mean + 0.5, # max_variance_proportion + 0.1, # grad_scale + ) + + cosine = (new_direction * direction).sum() / (new_direction.norm() * direction.norm()) + logging.info(f"Direction cosine = {cosine}") + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) @@ -336,6 +560,7 @@ def _test_activation_balancer_sign(): x = x.detach() x.requires_grad = True m = ActivationBalancer( + probs.numel(), channel_dim=0, min_positive=0.05, max_positive=0.98, @@ -361,6 +586,7 @@ def _test_activation_balancer_magnitude(): x = x.detach() x.requires_grad = True m = ActivationBalancer( + magnitudes.numel(), channel_dim=0, min_positive=0.0, max_positive=1.0, @@ -402,10 +628,17 @@ def _test_double_swish_deriv(): torch.autograd.gradcheck(m, x) +def _test_get_max_eig_proportion(): + x = torch.randn(100, 128) + d = torch.randn(128) * (128 ** -0.5) + get_max_eig_proportion(x, d, True) + if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) torch.set_num_interop_threads(1) + _test_max_eig_limiter() + _test_get_max_eig_proportion() _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() From 3d72a65de850d90101131c0ea564ae829af8cb75 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 19 Sep 2022 10:26:37 +0800 Subject: [PATCH 2/7] Implement max-eig-proportion.. 
--- .../pruned_transducer_stateless7/conformer.py | 92 ++------------ .../ASR/pruned_transducer_stateless7/optim.py | 2 +- .../pruned_transducer_stateless7/scaling.py | 115 ++++-------------- 3 files changed, 38 insertions(+), 171 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index e98ff46ee..77b786a91 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -249,8 +249,6 @@ class ConformerEncoderLayer(nn.Module): # multi-headed self-attention module src_att = self.self_attn( - src, - src, src, pos_emb=pos_emb, attn_mask=src_mask, @@ -490,9 +488,7 @@ class RelPositionMultiheadAttention(nn.Module): def forward( self, - query: Tensor, - key: Tensor, - value: Tensor, + x: Tensor, pos_emb: Tensor, key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, @@ -500,7 +496,7 @@ class RelPositionMultiheadAttention(nn.Module): ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: - query, key, value: map a query and a set of key-value pairs to an output. + x: input to be projected to query, key, value pos_emb: Positional embedding tensor key_padding_mask: if provided, specified padding elements in the key will be ignored by the attention. When given a binary mask and a value is True, @@ -513,11 +509,7 @@ class RelPositionMultiheadAttention(nn.Module): Shape: - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. @@ -540,9 +532,7 @@ class RelPositionMultiheadAttention(nn.Module): L is the target sequence length, S is the source sequence length. """ return self.multi_head_attention_forward( - query, - key, - value, + self.in_balancer(self.in_proj(x)), pos_emb, self.embed_dim, self.num_heads, @@ -584,11 +574,9 @@ class RelPositionMultiheadAttention(nn.Module): def multi_head_attention_forward( self, - query: Tensor, - key: Tensor, - value: Tensor, + x: Tensor, pos_emb: Tensor, - embed_dim_to_check: int, + embed_dim: int, num_heads: int, in_proj_weight: Tensor, in_proj_bias: Tensor, @@ -604,7 +592,7 @@ class RelPositionMultiheadAttention(nn.Module): Args: query, key, value: map a query and a set of key-value pairs to an output. pos_emb: Positional embedding tensor - embed_dim_to_check: total dimension of the model. + embed_dim: total dimension of the model. num_heads: parallel attention heads. in_proj_weight, in_proj_bias: input projection weight and bias. dropout_p: probability of an element to be zeroed. @@ -646,9 +634,7 @@ class RelPositionMultiheadAttention(nn.Module): L is the target sequence length, S is the source sequence length. 
""" - tgt_len, bsz, embed_dim = query.size() - assert embed_dim == embed_dim_to_check - assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + tgt_len, bsz, _ = x.size() head_dim = embed_dim // num_heads assert ( @@ -657,62 +643,10 @@ class RelPositionMultiheadAttention(nn.Module): scaling = float(head_dim) ** -0.5 - def linear(x, w, b): - return self.in_balancer(nn.functional.linear(x, w, b)) - if torch.equal(query, key) and torch.equal(key, value): - # self-attention - q, k, v = linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + # self-attention + q, k, v = x.chunk(3, dim=-1) - elif torch.equal(key, value): - # encoder-decoder attention - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - k, v = linear(key, _w, _b).chunk(2, dim=-1) - - else: - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = embed_dim * 2 - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - k = linear(key, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim * 2 - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - v = linear(value, _w, _b) if attn_mask is not None: assert ( @@ -732,15 +666,15 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + if list(attn_mask.size()) != [1, tgt_len, tgt_len]: raise RuntimeError( "The size of the 2D attn_mask is not correct." ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, - query.size(0), - key.size(0), + tgt_len, + tgt_len, ]: raise RuntimeError( "The size of the 3D attn_mask is not correct." 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 257312f9a..147b98a8f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -254,7 +254,7 @@ class ScaledAdam(Optimizer): if ans < 1.0: state["num_clipped"] += 1 if ans < 0.1: - logging.warn("Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}") + logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}") return ans diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 601426318..3fe71698b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -145,71 +145,6 @@ def find_direction_coeffs(x: Tensor, return cur_direction, coeffs -def get_max_eig_proportion(x: Tensor, - prev_direction: Tensor, - subtract_mean: bool) -> Tuple[Tensor, Tensor]: - """ - Figure out (an approximation to) the proportion of the variance of a set of - feature vectors that can be attributed to the top eigen-direction. - Args: - x: a Tensor of shape (*, num_channels). There must be more than one frame, - i.e. x.numel() // num_channels > 1. - prev_direction: a Tensor of shape (num_channels,), that is our previous estimate - of the top eigen-direction, or a random direction if this is the first - iteration. Expected to be without gradient. Does not have to be - normalized. - subtract_mean: if True, we will first subtract the mean of x, over the - frames. Suggest to make this true in most circumstances. - - Returns: (cur_direction, max_proportion), where: - cur_direction: a Tensor of shape (num_channels,) that is the current - estimate of the top eigen-direction. Detached / not intended to be - differentiable. - proportion: a scalar Tensor containing the proportion of the variance - of the input that is in direction `cur_direction`. This is with - gradient, that can be propagated back to x. - """ - num_channels = x.shape[-1] - assert prev_direction.shape == (num_channels,) - x = x.reshape(-1, num_channels) - if subtract_mean: - x = x - x.mean(dim=0) - - with torch.no_grad(): - cur_norm = prev_direction.norm() - - prev_direction = prev_direction / cur_norm - is_ok = (cur_norm / cur_norm == 1.0) - # if there was a problem like NaN or inf, restart. this should be very rare. - prev_direction = torch.where(is_ok.unsqueeze(-1).expand(prev_direction.shape), - prev_direction, - torch.randn_like(prev_direction) * (num_channels ** -0.5)) - - # `coeffs` are the coefficients of `prev_direction` in x. - coeffs = (x * prev_direction).sum(dim=1, keepdim=True) - - x_norm = x.norm() - x_coeffs1_norm = (x - coeffs * prev_direction).norm() - - with torch.no_grad(): - cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20) - - x_coeffs2_norm = (x - coeffs * cur_direction).norm() - - # for the returned direction interpolate with prev_direction so that - # even if x == 0, we get a nonzero new direction. 
- ans_direction = 0.5 * (prev_direction + cur_direction) - - x_sumsq = (x**2).sum() + 1.0e-20 - x_remaining_sumsq = ((x - coeffs * cur_direction) ** 2).sum() + 1.0e-20 - - proportion = (x - x_remaining_sumsq) / x_sumsq - - return (ans_direction, proportion) - - print(f"x_norm={x_norm}, x_coeffs1_norm={x_coeffs1_norm}, x_coeffs2_norm={x_coeffs2_norm}") - - class MaxEigLimiterFunction(torch.autograd.Function): @@ -233,17 +168,18 @@ class MaxEigLimiterFunction(torch.autograd.Function): if subtract_mean: x = x - x.mean(dim=0) new_direction, coeffs = find_direction_coeffs(x, direction) - x_var = (x**2).sum() + x_var = (x**2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual**2).sum() + x_residual_var = (x_residual**2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. - variance_proportion = (x_var - x_residual_var) / x_var + variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) ans_direction = direction + new_direction # ensure nonzero even if x == 0 ans_direction = ans_direction / ans_direction.norm() - logging.info(f"variance_proportion = {variance_proportion.item()}") + if random.random() < 0.01: + logging.info(f"variance_proportion = {variance_proportion.item()}") # Caution: this causes a CUDA sync, which is not ideal. if variance_proportion >= max_variance_proportion: @@ -262,7 +198,6 @@ class MaxEigLimiterFunction(torch.autograd.Function): if not hasattr(ctx, 'channel_dim'): # the top eig's proportion of the variance was below the threshold. return x_grad, None, None, None, None, None, None - with torch.enable_grad(): (x_orig, coeffs, new_direction) = ctx.saved_tensors x_orig.requires_grad = True @@ -271,16 +206,16 @@ class MaxEigLimiterFunction(torch.autograd.Function): new_direction.requires_grad = False if ctx.subtract_mean: x = x - x.mean(dim=0) - x_var = (x**2).sum() + x_var = (x ** 2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual**2).sum() + x_residual_var = (x_residual ** 2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. This is to be minimized. - variance_proportion = (x_var - x_residual_var) / x_var + variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) variance_proportion.backward() - x_orig_grad = x_orig.grad - x_extra_grad = x_orig.grad * x_orig.grad.norm() / (x_orig_grad.norm() + 1.0e-20) - return x_grad + x_extra_grad, None, None, None, None, None, None + x_orig_grad = x_orig.grad + x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) + return x_grad + x_extra_grad.detach(), None, None, None, None, None, None class BasicNorm(torch.nn.Module): @@ -448,7 +383,9 @@ class ActivationBalancer(torch.nn.Module): self.max_var_per_eig = max_var_per_eig if max_var_per_eig > 0.0: with torch.no_grad(): - direction = torch.randn(num_channels) + # arbitrary.. 
would use randn() but want to leave the rest of the model's + # random parameters unchanged for comparison + direction = torch.arange(num_channels).to(torch.float) direction = direction / direction.norm() self.register_buffer('max_eig_direction', direction) else: @@ -460,15 +397,16 @@ class ActivationBalancer(torch.nn.Module): return x if self.max_var_per_eig > 0: - x, new_direction = MaxEigLimiterFunction.apply( - x, self.max_eig_direction, - self.channel_dim, - 0.1, # prob - True, # subtract_mean - self.max_var_per_eig, - self.max_factor, - ) - self.max_eig_direction[:] = new_direction + with torch.cuda.amp.autocast(enabled=False): + x, new_direction = MaxEigLimiterFunction.apply( + x, self.max_eig_direction, + self.channel_dim, + 0.25, # prob + True, # subtract_mean + self.max_var_per_eig, + self.max_factor, + ) + self.max_eig_direction[:] = new_direction.detach() return ActivationBalancerFunction.apply( x, @@ -628,17 +566,12 @@ def _test_double_swish_deriv(): torch.autograd.gradcheck(m, x) -def _test_get_max_eig_proportion(): - x = torch.randn(100, 128) - d = torch.randn(128) * (128 ** -0.5) - get_max_eig_proportion(x, d, True) if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) torch.set_num_interop_threads(1) _test_max_eig_limiter() - _test_get_max_eig_proportion() _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() From db1f4ccdd195d088313ef3d80a05703e368dc724 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 20 Sep 2022 14:20:13 +0800 Subject: [PATCH 3/7] 4x scale on max-eig constraint --- egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 3fe71698b..96906b726 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -397,14 +397,15 @@ class ActivationBalancer(torch.nn.Module): return x if self.max_var_per_eig > 0: + max_eig_prob = 0.25 with torch.cuda.amp.autocast(enabled=False): x, new_direction = MaxEigLimiterFunction.apply( x, self.max_eig_direction, self.channel_dim, - 0.25, # prob + max_eig_prob, True, # subtract_mean self.max_var_per_eig, - self.max_factor, + self.max_factor / max_eig_prob, ) self.max_eig_direction[:] = new_direction.detach() From cd5ac76a053e3e575c8af501c9071a37cc610a50 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 20 Sep 2022 14:22:07 +0800 Subject: [PATCH 4/7] Add max-var-per-eig in encoder layers --- .../ASR/pruned_transducer_stateless7/conformer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 77b786a91..7d785a369 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -198,7 +198,10 @@ class ConformerEncoderLayer(nn.Module): # try to ensure the output is close to zero-mean (or at least, zero-median). 
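        # In addition to the existing mean/median constraint, this balancer now also
        # limits the share of the activations' variance that any single eigen-direction
        # may take (max_var_per_eig; the threshold is re-tuned in the later commits).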
self.balancer = ActivationBalancer( - d_model, channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + d_model, channel_dim=-1, + min_positive=0.45, max_positive=0.55, + max_abs=6.0, + max_var_per_eig=0.1, ) self.dropout = nn.Dropout(dropout) From 6eb9a0bc9bd62307a73fc52a52147f282f540b31 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 20 Sep 2022 14:39:17 +0800 Subject: [PATCH 5/7] Halve max_var_per_eig to 0.05 --- egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 7d785a369..328cb4434 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -201,7 +201,7 @@ class ConformerEncoderLayer(nn.Module): d_model, channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0, - max_var_per_eig=0.1, + max_var_per_eig=0.05, ) self.dropout = nn.Dropout(dropout) @@ -469,7 +469,7 @@ class RelPositionMultiheadAttention(nn.Module): self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) self.in_balancer = ActivationBalancer(3 * embed_dim, channel_dim=-1, max_abs=5.0, - max_var_per_eig=0.1) + max_var_per_eig=0.05) self.proj_balancer = ActivationBalancer(embed_dim, channel_dim=-1, max_abs=10.0, min_positive=0.0, max_positive=1.0) From 1d20c12bc02f544de828906956b4d29064673bc8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 22 Sep 2022 12:28:35 +0800 Subject: [PATCH 6/7] Increase max_var_per_eig to 0.2 --- egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py index 328cb4434..182b78eee 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py @@ -201,7 +201,7 @@ class ConformerEncoderLayer(nn.Module): d_model, channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0, - max_var_per_eig=0.05, + max_var_per_eig=0.2, ) self.dropout = nn.Dropout(dropout) @@ -469,7 +469,7 @@ class RelPositionMultiheadAttention(nn.Module): self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) self.in_balancer = ActivationBalancer(3 * embed_dim, channel_dim=-1, max_abs=5.0, - max_var_per_eig=0.05) + max_var_per_eig=0.2) self.proj_balancer = ActivationBalancer(embed_dim, channel_dim=-1, max_abs=10.0, min_positive=0.0, max_positive=1.0) From ceadfad48dc2e8cd2186a39591e8c69804d322de Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 22 Sep 2022 12:30:49 +0800 Subject: [PATCH 7/7] Reduce debug freq --- egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 96906b726..374a260e5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -178,7 +178,7 @@ class MaxEigLimiterFunction(torch.autograd.Function): ans_direction = direction + new_direction # ensure nonzero even if x == 0 ans_direction = ans_direction / ans_direction.norm() - if random.random() < 0.01: + if random.random() < 0.001: logging.info(f"variance_proportion = 
{variance_proportion.item()}") # Caution: this causes a CUDA sync, which is not ideal.
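Taken together, by the end of the series the constraint is active in two places, both with a threshold of 0.2. A condensed sketch of the resulting configuration (d_model is an illustrative value; the keyword arguments mirror the hunks above):

    from scaling import ActivationBalancer  # this recipe's scaling.py

    d_model = 512  # illustrative

    # ConformerEncoderLayer output balancer (threshold 0.1 -> 0.05 -> 0.2 over patches 4-6):
    balancer = ActivationBalancer(d_model, channel_dim=-1,
                                  min_positive=0.45, max_positive=0.55,
                                  max_abs=6.0, max_var_per_eig=0.2)

    # RelPositionMultiheadAttention input balancer, acting on the concatenated
    # q/k/v projection, hence 3 * d_model channels:
    in_balancer = ActivationBalancer(3 * d_model, channel_dim=-1,
                                     max_abs=5.0, max_var_per_eig=0.2)

Inside ActivationBalancer.forward the limiter now fires with probability 0.25 and its gradient is scaled by max_factor / 0.25, presumably so that the expected strength of the constraint does not depend on how often it is applied, and the diagnostic logging is reduced to about one in a thousand calls.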