From 12323025d731d701186c419e47f2acd2a166d218 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 12 Oct 2022 18:44:52 +0800
Subject: [PATCH] Make ActivationBalancer and MaxEig more efficient.

---
 .../pruned_transducer_stateless7/conformer.py |  15 +-
 .../pruned_transducer_stateless7/scaling.py   | 491 +++++++++++-------
 2 files changed, 309 insertions(+), 197 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index cef8d1b18..185f7c98d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -27,6 +27,7 @@ from encoder_interface import EncoderInterface
 from scaling import (
     ActivationBalancer,
     BasicNorm,
+    MaxEig,
     DoubleSwish,
     ScaledConv1d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
@@ -293,8 +294,11 @@ class ConformerEncoderLayer(nn.Module):
             d_model, channel_dim=-1,
             min_positive=0.45, max_positive=0.55,
             max_abs=6.0,
-            max_var_per_eig=0.2,
         )
+        self.max_eig = MaxEig(
+            d_model, channel_dim=-1,
+        )
+
 
     def forward(
@@ -350,7 +354,7 @@ class ConformerEncoderLayer(nn.Module):
 
         src = src + self.feed_forward3(src)
 
-        src = self.norm_final(self.balancer(src))
+        src = self.norm_final(self.max_eig(self.balancer(src)))
 
         delta = src - src_orig
         bypass_scale = self.bypass_scale
@@ -838,8 +842,9 @@ class RelPositionMultiheadAttention(nn.Module):
         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
         self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
-                                              channel_dim=-1, max_abs=5.0,
-                                              max_var_per_eig=0.2)
+                                              channel_dim=-1, max_abs=5.0)
+        self.in_max_eig = MaxEig(3 * embed_dim // 2,
+                                 channel_dim=-1)
         self.proj_balancer = ActivationBalancer(embed_dim // 2,
                                                 channel_dim=-1, max_abs=10.0,
                                                 min_positive=0.0, max_positive=1.0)
@@ -915,7 +920,7 @@ class RelPositionMultiheadAttention(nn.Module):
             before softmax.
         """
         x, weights, scores = self.multi_head_attention_forward(
-            self.in_balancer(self.in_proj(x)),
+            self.in_max_eig(self.in_balancer(self.in_proj(x))),
             pos_emb,
             None if attn_scores_in is None else torch.matmul(attn_scores_in, self.attn_scores_proj_in),
             self.embed_dim,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 24ddf892f..5d63137ff 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -29,120 +29,35 @@ from torch import Tensor
 from torch.nn import Embedding as ScaledEmbedding
 
 
-def _ntuple(n):
-    def parse(x):
-        if isinstance(x, collections.Iterable):
-            return x
-        return tuple(repeat(x, n))
-
-    return parse
-
-
-_single = _ntuple(1)
-_pair = _ntuple(2)
-
-
 class ActivationBalancerFunction(torch.autograd.Function):
     @staticmethod
     def forward(
-        ctx,
-        x: Tensor,
-        channel_dim: int,
-        min_positive: float,  # e.g. 0.05
-        max_positive: float,  # e.g. 0.95
-        max_factor: float,  # e.g. 0.01
-        min_abs: float,  # e.g. 0.2
-        max_abs: float,  # e.g. 100.0
+        ctx,
+        x: Tensor,
+        sign_factor: Tensor,
+        scale_factor: Tensor,
+        channel_dim: int,
     ) -> Tensor:
-        if x.requires_grad:
-            if channel_dim < 0:
-                channel_dim += x.ndim
-            sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-            x_normalized = x - torch.mean(x, dim=sum_dims, keepdim=True)
-            xgtmean = (x_normalized > 0)
-            proportion_positive = torch.mean(
-                (x > 0).to(x.dtype), dim=sum_dims, keepdim=True
-            )
-            factor1 = (
-                (min_positive - proportion_positive).relu()
-                * (max_factor / min_positive)
-                if min_positive != 0.0
-                else 0.0
-            )
-            factor2 = (
-                (proportion_positive - max_positive).relu()
-                * (max_factor / (max_positive - 1.0))
-                if max_positive != 1.0
-                else 0.0
-            )
-            # `factor` is a tensor of shape something like (1, 1, num_channels,
-            # 1), containing elements between -1 and 1 that are zero if the
-            # proportion of positive features is between min_positive and
-            # max_positive, max_factor if proportion==0.0 (all features are negative),
-            # and -max_factor if proportion==1.0 (all features are positive).  It is
-            # an amount per channel by which we'll modify the gradient; the sign
-            # of modifying the gradient will depend on the sign of the gradient.
-
-            factor = factor1 + factor2
-            if isinstance(factor, float):
-                factor = torch.zeros_like(proportion_positive)
-
-            mean_abs = torch.mean(x_normalized.abs(), dim=sum_dims, keepdim=True)
-            below_threshold = mean_abs < min_abs
-            above_threshold = mean_abs > max_abs
-
-            ctx.save_for_backward(
-                factor, xgtmean, below_threshold, above_threshold
-            )
-            ctx.max_factor = max_factor
-            ctx.sum_dims = sum_dims
+        if channel_dim < 0:
+            channel_dim += x.ndim
+        ctx.channel_dim = channel_dim
+        xgt0 = (x > 0)
+        ctx.save_for_backward(xgt0, sign_factor, scale_factor)
         return x
 
+
     @staticmethod
     def backward(
         ctx, x_grad: Tensor
-    ) -> Tuple[Tensor, None, None, None, None, None, None]:
-        factor, xgtmean, below_threshold, above_threshold = ctx.saved_tensors
-        dtype = x_grad.dtype
-        scale_factor = (
-            (below_threshold.to(dtype) - above_threshold.to(dtype))
-            * (xgtmean.to(dtype) - 0.5)
-            * (ctx.max_factor * 2.0)
-        )
+    ) -> Tuple[Tensor, None, None, None]:
+        xgt0, sign_factor, scale_factor = ctx.saved_tensors
+        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
+            sign_factor = sign_factor.unsqueeze(-1)
+            scale_factor = scale_factor.unsqueeze(-1)
 
-        neg_delta_grad = x_grad.abs() * (factor + scale_factor)
-        return x_grad - neg_delta_grad, None, None, None, None, None, None
-
-
-def find_direction_coeffs(x: Tensor,
-                          prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
-    """
-    Figure out (an approximation to) the proportion of the variance of a set of
-    feature vectors that can be attributed to the top eigen-direction.
-    Args:
-      x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
-      prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
-           of the top eigen-direction, or a random direction if this is the first
-           iteration.  Does not have to be normalized, but should be nonzero.
-
-    Returns: (cur_direction, coeffs), where:
-         cur_direction: a Tensor of shape (num_channels,) that is the current
-            estimate of the top eigen-direction.
-         coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
-            approximately minimizes, (x - coeffs * cur_direction).norm()
-    """
-    (num_frames, num_channels) = x.shape
-    assert num_channels > 1 and num_frames > 1
-
-    assert prev_direction.shape == (num_channels,)
-
-    # `coeffs` are the coefficients of `prev_direction` in x.
-    # actually represent the coeffs up to a constant positive factor.
-    coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
-
-    cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
-
-    return cur_direction, coeffs
+        factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+        neg_delta_grad = x_grad.abs() * factor
+        return x_grad - neg_delta_grad, None, None, None,
 
 
@@ -152,57 +67,27 @@ class MaxEigLimiterFunction(torch.autograd.Function):
     def forward(
             ctx,
             x: Tensor,
+            coeffs: Tensor,
             direction: Tensor,
             channel_dim: int,
-            subtract_mean: bool,
-            max_variance_proportion: float,
-            grad_scale: float) -> Tuple[Tensor, Tensor]:
-        eps = 1.0e-20
-        num_channels = x.shape[channel_dim]
-        assert max_variance_proportion > 1.0 / num_channels
-        orig_x = x
-        x = x.transpose(channel_dim, -1).reshape(-1, num_channels)
-        if subtract_mean:
-            x = x - x.mean(dim=0)
-        new_direction, coeffs = find_direction_coeffs(x, direction)
-        x_var = (x**2).mean()
-        x_residual = x - coeffs * new_direction
-        x_residual_var = (x_residual**2).mean()
-        # `variance_proportion` is the proportion of the variance accounted for
-        # by the top eigen-direction.
-        variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
+            grad_scale: float) -> Tensor:
+        ctx.channel_dim = channel_dim
+        ctx.grad_scale = grad_scale
+        ctx.save_for_backward(x.detach(),
+                              coeffs.detach(),
+                              direction.detach())
+        return x
 
-        ans_direction = direction + new_direction  # ensure nonzero even if x == 0
-        ans_direction = ans_direction / ans_direction.norm()
-
-        if random.random() < 0.0005:
-            logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(x.shape)}")
-
-        # Caution: this causes a CUDA sync, which is not ideal.
-        if variance_proportion >= max_variance_proportion:
-            ctx.channel_dim = channel_dim
-            ctx.subtract_mean = subtract_mean
-            ctx.grad_scale = grad_scale
-            ctx.save_for_backward(orig_x.detach(),
-                                  coeffs.detach(),
-                                  new_direction.detach())
-
-        return orig_x, ans_direction
 
     @staticmethod
     def backward(ctx, x_grad, *args):
-        # the *args is all the other derivs, which should be None or zero.
-        if not hasattr(ctx, 'channel_dim'):
-            # the top eig's proportion of the variance was below the threshold.
-            return x_grad, None, None, None, None, None, None
         with torch.enable_grad():
            (x_orig, coeffs, new_direction) = ctx.saved_tensors
            x_orig.requires_grad = True
            num_channels = x_orig.shape[ctx.channel_dim]
            x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
            new_direction.requires_grad = False
-           if ctx.subtract_mean:
-               x = x - x.mean(dim=0)
+           x = x - x.mean(dim=0)
            x_var = (x ** 2).mean()
            x_residual = x - coeffs * new_direction
            x_residual_var = (x_residual ** 2).mean()
@@ -212,7 +97,7 @@ class MaxEigLimiterFunction(torch.autograd.Function):
            variance_proportion.backward()
        x_orig_grad = x_orig.grad
        x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20)
-       return x_grad + x_extra_grad.detach(), None, None, None, None, None, None
+       return x_grad + x_extra_grad.detach(), None, None, None, None
 
 
 class BasicNorm(torch.nn.Module):
@@ -352,11 +237,15 @@ class ActivationBalancer(torch.nn.Module):
         max_abs: the maximum average-absolute-value difference from the mean value per
            channel, which we allow, before we start to modify the derivatives to prevent
            this.
-        max_var_per_eig: the maximum proportion of the variance of the
-           features/channels, after mean subtraction, that can come from
-           any given eigenvalue.
+        beta: a constant used in decaying stats for the {min,max}_positive and
+           {min,max}_abs constraints.  Likely not critical.
+        prob: determines the probability with which we modify the
+           gradients for the {min,max}_positive and {min,max}_abs constraints,
+           on each forward().  This is done randomly to prevent all layers
+           from doing it at the same time.
+        stats_period: the periodicity with which we update the statistics on
+           the activations.
     """
-
     def __init__(
         self,
        num_channels: int,
@@ -367,6 +256,9 @@ class ActivationBalancer(torch.nn.Module):
         min_abs: float = 0.2,
         max_abs: float = 100.0,
         max_var_per_eig: float = 0.0,
+        beta: float = 0.75,
+        prob: float = 0.25,
+        stats_period: int = 10,
     ):
         super(ActivationBalancer, self).__init__()
         self.num_channels = num_channels
@@ -376,49 +268,261 @@ class ActivationBalancer(torch.nn.Module):
         self.max_factor = max_factor
         self.min_abs = min_abs
         self.max_abs = max_abs
-        assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
-        self.max_var_per_eig = max_var_per_eig
-        if max_var_per_eig > 0.0:
-            with torch.no_grad():
-                # arbitrary.. would use randn() but want to leave the rest of the model's
-                # random parameters unchanged for comparison
-                direction = torch.arange(num_channels).to(torch.float)
-                direction = direction / direction.norm()
-                self.register_buffer('max_eig_direction', direction)
-        else:
-            self.max_eig_direction = None
+        self.beta = beta
+        self.prob = prob
+        self.stats_period = stats_period
+
+        # count measures how many times the forward() function has been called.
+        self.count = 0
+
+        # the mean of the absolute value of the data per channel
+        self.register_buffer('abs_mean', torch.zeros(num_channels))
+
+        # the proportion of activations that are positive, per channel.
+        self.register_buffer('proportion_positive', torch.zeros(num_channels))
+
+        # `factors` is a single buffer holding two vectors of shape (num_channels,).
+        # `sign_factor` is an expression that will be used to scale the
+        # gradients in backprop; it will be 0 if the max_positive and min_positive
+        # constraints are satisfied.
+        # `scale_factor` is an expression that will be used to encourage the
+        # data to satisfy our min_abs and max_abs constraints; it will be zero if
+        # all constraints are satisfied.
+        self.register_buffer('factors', torch.zeros(2, num_channels))
 
     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or not x.requires_grad:
             return x
 
-        max_eig_prob = 0.25
-        if self.max_var_per_eig > 0 and random.random() < max_eig_prob:
-            with torch.cuda.amp.autocast(enabled=False):
-                x, new_direction = MaxEigLimiterFunction.apply(
-                    x, self.max_eig_direction,
-                    self.channel_dim,
-                    True,  # subtract_mean
-                    self.max_var_per_eig,
-                    self.max_factor / max_eig_prob,
-                )
-                self.max_eig_direction[:] = new_direction.detach()
+        count = self.count
+        self.count += 1
 
-        balance_prob = 0.25
-        if random.random() < balance_prob:
+        if count % self.stats_period == 0:
+            self._update_stats(x, count)
+
+        if random.random() < self.prob:
+            # The .clone() is in case the forward() gets called multiple times
+            # before self.factors is next updated in-place.
+            factors = self.factors.clone()
+            sign_factor = factors[0]
+            scale_factor = factors[1]
             return ActivationBalancerFunction.apply(
-                x,
-                self.channel_dim,
-                self.min_positive,
-                self.max_positive,
-                self.max_factor / balance_prob,
-                self.min_abs,
-                self.max_abs,
+                x, sign_factor, scale_factor, self.channel_dim,
             )
         else:
             return x
 
+    def _update_stats(self,
+                      x: Tensor,
+                      count: int):
+        """
+        Updates some statistics that we maintain, describing the average activations per
+        channel.
+        """
+        with torch.no_grad():
+            sum_dims = [d for d in range(x.ndim) if d != self.channel_dim]
+
+            x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
+            # the random.random() thing is to split the difference if x is zero,
+            # between treating it as positive or negative
+            proportion_positive = torch.mean(
+                ((x > 0) if random.random() < 0.5 else (x >= 0)).to(torch.float32), dim=sum_dims,
+            )
+
+            def filter_inf_nan(y):
+                mask = (y - y != 0)
+                y.masked_fill_(mask, 0.0)
+
+            filter_inf_nan(x_abs_mean)
+
+            beta = self.beta if count > 0 else 0.0
+            self.abs_mean.mul_(beta).add_(x_abs_mean, alpha=(1-beta))
+            self.proportion_positive.mul_(beta).add_(proportion_positive, alpha=(1-beta))
+
+            max_factor = self.max_factor / self.prob
+            min_positive = self.min_positive
+            max_positive = self.max_positive
+
+            if min_positive == 0.0:
+                factor1 = 0.0
+            else:
+                # 0 if self.proportion_positive >= min_positive, else can be
+                # as large as max_factor.
+                factor1 = ((min_positive - self.proportion_positive).relu() *
+                           (max_factor / min_positive))
+            if max_positive == 1.0:
+                factor2 = 0.0
+            else:
+                # 0 if self.proportion_positive <= max_positive, else can be
+                # as negative as -max_factor.
+                factor2 = ((self.proportion_positive - max_positive).relu()
+                           * (max_factor / (max_positive - 1.0)))
+            sign_factor = self.factors[0]
+            scale_factor = self.factors[1]
+            sign_factor[:] = factor1 + factor2
+
+            # the factor of 2.0 below is just to cancel out a factor of 0.5 that gets introduced when, in
+            # the backprop, we do (xgt0.to(dtype) - 0.5).
+            #
+            # scale_factor_scale, on the other hand, is a heuristically chosen value between 0 and 1,
+            # that we use to make the gradient changes from the 'scale' constraints (min_abs/max_abs)
+            # less strong than those from the sign constraints.
+            #
+            # This is to get rid of a pathology that can happen if, for instance, a
+            # channel is always positive but is too small (max_positive and min_abs constraints both
+            # violated).  If scale_factor_scale were equal to 1.0, then the gradient changes from the
+            # min_positive constraint (trying to make the activation more negative) and from the
+            # min_abs constraint (trying to make the activation more positive) would exactly cancel.
+            # Instead we make the min_positive constraint stronger, so it first makes the value
+            # sometimes negative, and only when that is satisfied, can deal with the absolute-value
+            # constraint.
+            scale_factor_scale = 0.5
+            below_threshold = (self.abs_mean < self.min_abs)
+            above_threshold = (self.abs_mean > self.max_abs)
+            scale_factor[:] = ((below_threshold.to(torch.float32) -
+                                above_threshold.to(torch.float32))
+                               * (max_factor * (2.0 * scale_factor_scale)))
+
+
+class MaxEig(torch.nn.Module):
+    """
+    Modifies the backpropped derivatives of a function to discourage
+    any single direction in activation space from accounting for more than
+    a specified proportion of the variance (e.g. 0.2).
+
+
+    Args:
+      num_channels: the number of channels
+      channel_dim: the dimension/axis corresponding to the channel, e.g.
+           -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
+      max_var_per_eig: the maximum proportion of the variance of the
+           features/channels, after mean subtraction, that can come from
+           any given eigenvalue.
+      min_prob: the minimum probability with which we apply this during any invocation
+           of forward(), assuming last time we applied the constraint it was
+           not active; supplied for speed.
+      scale: determines the scale with which we modify the gradients, relative
+           to the existing / unmodified gradients
+    """
+    def __init__(
+        self,
+        num_channels: int,
+        channel_dim: int,
+        max_var_per_eig: float = 0.2,
+        min_prob: float = 0.01,
+        scale: float = 0.01,
+    ):
+        super(MaxEig, self).__init__()
+        self.num_channels = num_channels
+        self.channel_dim = channel_dim
+        self.scale = scale
+        assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
+        self.max_var_per_eig = max_var_per_eig
+
+        # we figure out the dominant direction using the power method: starting with
+        # a random vector, keep multiplying by the covariance and renormalizing.
+        with torch.no_grad():
+            # arbitrary.. would use randn() but want to leave the rest of the model's
+            # random parameters unchanged for comparison
+            direction = torch.arange(num_channels).to(torch.float)
+            direction = direction / direction.norm()
+            self.register_buffer('max_eig_direction', direction)
+
+        self.min_prob = min_prob
+        # cur_prob is the current probability with which we'll apply the MaxEig constraint.
+        # We'll regress this towards min_prob each time we try to apply it and it is not
+        # active.
+        self.cur_prob = 1.0
+
+
+
+    def forward(self, x: Tensor) -> Tensor:
+        if (torch.jit.is_scripting() or
+            self.max_var_per_eig <= 0 or
+            random.random() > self.cur_prob):
+            return x
+
+        with torch.cuda.amp.autocast(enabled=False):
+            eps = 1.0e-20
+            assert x.dtype != torch.float16
+            orig_x = x
+            with torch.no_grad():
+                x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels)
+                x = x - x.mean(dim=0)
+                new_direction, coeffs = self._find_direction_coeffs(x, self.max_eig_direction)
+                x_var = (x**2).mean()
+                x_residual = x - coeffs * new_direction
+                x_residual_var = (x_residual**2).mean()
+
+                # `variance_proportion` is the proportion of the variance accounted for
+                # by the top eigen-direction.
+                variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
+
+                # ensure the new direction is nonzero even if x == 0, by including `max_eig_direction`.
+                self._set_direction(0.1 * self.max_eig_direction + new_direction)
+
+            if random.random() < 0.0005 or __name__ == "__main__":
+                logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}")
+
+            if variance_proportion >= self.max_var_per_eig:
+                # The constraint is active.  Note: we should quite rarely reach here,
+                # typically only near the beginning of training, if the model is
+                # starting to diverge.
+                cur_prob = self.cur_prob
+                self.cur_prob = 1.0  # next time, do the update with probability 1.0.
+                return MaxEigLimiterFunction.apply(orig_x, coeffs, new_direction,
+                                                   self.channel_dim, self.scale)
+            else:
+                # let self.cur_prob exponentially approach self.min_prob, as
+                # long as the constraint is inactive.
+                self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
+                return orig_x
+
+
+    def _set_direction(self,
+                       direction: Tensor):
+        """
+        Sets self.max_eig_direction to a normalized version of `direction`.
+        """
+        direction = direction.detach()
+        direction = direction / direction.norm()
+        direction_sum = direction.sum().item()
+        if direction_sum - direction_sum == 0:  # no inf/nan
+            self.max_eig_direction[:] = direction
+        else:
+            logging.info(f"Warning: sum of direction in MaxEig is {direction_sum}, "
+                         f"num_channels={self.num_channels}, channel_dim={self.channel_dim}")
+
+
+    def _find_direction_coeffs(self,
+                               x: Tensor,
+                               prev_direction: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Figure out (an approximation to) the top eigen-direction of a set of
+        feature vectors, and the coefficient of each frame on that direction.
+        Args:
+          x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
+          prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
+               of the top eigen-direction, or a random direction if this is the first
+               iteration.  Does not have to be normalized, but should be nonzero.
+
+        Returns: (cur_direction, coeffs), where:
+             cur_direction: a Tensor of shape (num_channels,) that is the current
+                estimate of the top eigen-direction.
+             coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
+                approximately minimizes, (x - coeffs * cur_direction).norm()
+        """
+        (num_frames, num_channels) = x.shape
+        assert num_channels > 1 and num_frames > 1
+        assert prev_direction.shape == (num_channels,)
+        # `coeffs` are the coefficients of `prev_direction` in x; they only
+        # represent the coeffs up to a constant positive factor.
+        coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
+        cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
+        return cur_direction, coeffs
+
+
 
 class DoubleSwishFunction(torch.autograd.Function):
     """
@@ -460,7 +564,8 @@ class DoubleSwish(torch.nn.Module):
         return DoubleSwishFunction.apply(x)
 
 
-def _test_max_eig_limiter():
+
+def _test_max_eig():
     for proportion in [0.1, 0.5, 10.0]:
         logging.info(f"proportion = {proportion}")
 
         x.requires_grad = True
 
-        y, new_direction = MaxEigLimiterFunction.apply(x, direction,
-                                                       1,  # channel_dim
-                                                       True,  # subtract_mean
-                                                       0.5,  # max_variance_proportion
-                                                       0.1,  # grad_scale
-        )
+        num_channels = 128
+        m = MaxEig(num_channels,
+                   1,  # channel_dim
+                   0.5,  # max_var_per_eig
+                   scale=0.1)  # grad_scale
 
-        cosine = (new_direction * direction).sum() / (new_direction.norm() * direction.norm())
-        logging.info(f"Direction cosine = {cosine}")
+
+        for _ in range(4):
+            y = m(x)
 
         y_grad = torch.randn_like(x)
         y.backward(gradient=y_grad)
@@ -494,16 +599,17 @@ def _test_activation_balancer_sign():
     probs = torch.arange(0, 1, 0.01)
     N = 1000
-    x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))
+    x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
     x = x.detach()
     x.requires_grad = True
     m = ActivationBalancer(
         probs.numel(),
         channel_dim=0,
         min_positive=0.05,
-        max_positive=0.98,
+        max_positive=0.95,
         max_factor=0.2,
         min_abs=0.0,
+        prob=1.0,
     )
 
     y_grad = torch.sign(torch.randn(probs.numel(), N))
@@ -531,6 +637,7 @@ def _test_activation_balancer_magnitude():
         max_factor=0.2,
         min_abs=0.2,
         max_abs=0.8,
+        prob=1.0,
     )
 
     y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
@@ -571,7 +678,7 @@ if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
-    _test_max_eig_limiter()
+    _test_max_eig()
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
     _test_basic_norm()
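
For reference, the following standalone sketch (the helper name power_method_step and the toy data are illustrative assumptions, not code from the patch) shows the power-method update that MaxEig._find_direction_coeffs performs, and the variance_proportion statistic that MaxEig.forward compares against max_var_per_eig:

# Illustrative only: one step of the power method on the (mean-subtracted)
# feature covariance, as maintained by MaxEig, plus the proportion of the
# variance explained by the current direction estimate.
import torch


def power_method_step(x: torch.Tensor, prev_direction: torch.Tensor):
    # x: (num_frames, num_channels), already mean-subtracted.
    # coeffs: projection of each frame onto the previous direction estimate,
    # only defined up to a constant positive factor.
    coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
    # Least-squares fit of x by coeffs * direction gives the new (unnormalized)
    # direction estimate; this is one power-iteration step on the covariance.
    cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
    return cur_direction, coeffs


torch.manual_seed(0)
num_frames, num_channels = 1000, 128
x = torch.randn(num_frames, num_channels)
x[:, 0] *= 8.0  # make one direction dominate the covariance
x = x - x.mean(dim=0)

direction = torch.arange(num_channels, dtype=torch.float)
direction = direction / direction.norm()

for i in range(4):
    new_direction, coeffs = power_method_step(x, direction)
    x_var = (x ** 2).mean()
    residual_var = ((x - coeffs * new_direction) ** 2).mean()
    variance_proportion = (x_var - residual_var) / (x_var + 1.0e-20)
    print(f"iter {i}: variance_proportion = {variance_proportion.item():.3f}")
    direction = new_direction / new_direction.norm()
# With max_var_per_eig = 0.2 (the MaxEig default), a value like the one printed
# above would trigger MaxEigLimiterFunction, which nudges the gradients of x so
# as to reduce the dominant direction's share of the variance.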