diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index f8ba1dc54..0f7307c56 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -198,120 +198,6 @@ FloatLike = Union[float, ScheduledFloat]
 
-class ActivationBalancerFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx,
-        x: Tensor,
-        scale_factor: Tensor,
-        mean: Tensor,
-        sign_factor: Optional[Tensor],
-        channel_dim: int,
-    ) -> Tensor:
-        if channel_dim < 0:
-            channel_dim += x.ndim
-        ctx.channel_dim = channel_dim
-        xgtmean = (x > mean)
-        if sign_factor is None:
-            ctx.save_for_backward(xgtmean, scale_factor)
-        else:
-            ctx.save_for_backward(xgtmean, scale_factor, sign_factor)
-        return x
-
-    @staticmethod
-    def backward(
-        ctx, x_grad: Tensor
-    ) -> Tuple[Tensor, None, None, None, None]:
-        if len(ctx.saved_tensors) == 3:
-            xgtmean, scale_factor, sign_factor = ctx.saved_tensors
-            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
-                scale_factor = scale_factor.unsqueeze(-1)
-                sign_factor = sign_factor.unsqueeze(-1)
-            factor = sign_factor + scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
-        else:
-            xgtmean, scale_factor = ctx.saved_tensors
-            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
-                scale_factor = scale_factor.unsqueeze(-1)
-            factor = scale_factor * (xgtmean.to(x_grad.dtype) - 0.5)
-        neg_delta_grad = x_grad.abs() * factor
-        return x_grad - neg_delta_grad, None, None, None, None
-
-
-def _compute_scale_factor(x: Tensor,
-                          channel_dim: int,
-                          min_abs: float,
-                          max_abs: float,
-                          gain_factor: float,
-                          max_factor: float) -> Tuple[Tensor, Tensor]:
-    """
-    Computes a factor used in ActivationBalancer that dictates how much we penalize (or anti-penalize)
-    the scale on the features.
-
-    Returns: (scale_factor, mean)
-      scale_factor: can be positive or negative, between -max_factor and max_factor; dictates
-        penalty or anti-penalty.  It is of shape (num_channels,).
-      mean: mean per channel that we use for purposes of scale_factor; it is clamped to
-        -min_abs..min_abs.  Its shape is like (1, num_channels, 1, 1), depending on the shape
-        of x and on channel_dim.
-    """
-    if channel_dim < 0:
-        channel_dim += x.ndim
-    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-
-    x_mean = torch.mean(x, dim=sum_dims, keepdim=True).to(torch.float32)
-    # the idea is that for purposes of applying max_abs, we regress effectively
-    # toward zero (assuming min_abs is much less than max_abs).
-    x_mean = x_mean.clamp(min=-min_abs, max=min_abs)
-    x_abs_mean = torch.mean((x - x_mean).abs(), dim=sum_dims).to(torch.float32)
-
-    if min_abs == 0.0:
-        below_threshold = 0.0
-    else:
-        # below_threshold is 0 if x_abs_mean > min_abs; it can be at most max_factor if
-        # x_abs_mean < min_abs.
-        below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)
-
-    above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)
-
-    return below_threshold - above_threshold, x_mean
-
-
-def _compute_sign_factor(x: Tensor,
-                         channel_dim: int,
-                         min_positive: float,
-                         max_positive: float,
-                         gain_factor: float,
-                         max_factor: float) -> Tensor:
-    if channel_dim < 0:
-        channel_dim += x.ndim
-    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
-    proportion_positive = torch.mean((x > 0).to(torch.float32),
-                                     dim=sum_dims)
-    if min_positive == 0.0:
-        factor1 = 0.0
-    else:
-        # 0 if proportion_positive >= min_positive, else can be
-        # as large as max_factor.
-        factor1 = ((min_positive - proportion_positive) *
-                   (gain_factor / min_positive)).clamp_(min=0, max=max_factor)
-
-    if max_positive == 1.0:
-        factor2 = 0.0
-    else:
-        # 0 if proportion_positive <= max_positive, else can be
-        # as large as max_factor.
-        factor2 = ((proportion_positive - max_positive) *
-                   (gain_factor / (1.0 - max_positive))).clamp_(min=0, max=max_factor)
-    sign_factor = factor1 - factor2
-    # require min_positive != 0 or max_positive != 1:
-    assert not isinstance(sign_factor, float)
-    return sign_factor
-
-
 
 def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
@@ -364,167 +250,6 @@ class CutoffEstimator:
 
 
-class CachingEvalFunction(torch.autograd.Function):
-    # @custom_fwd and @custom_bwd are related to automatic mixed precision (amp) and ensure
-    # that the backward pass runs with the same autocast context as the forward pass.
-    @staticmethod
-    @custom_fwd
-    def forward(ctx, *args):
-        """
-        m might be an nn.Module
-        """
-        tot_num_args = len(args)
-        orig_num_args = args[0]
-        orig_args = args[1:1+orig_num_args]
-
-        ctx.num_dummy_args = tot_num_args - orig_num_args - 1
-
-        tensor_args = []
-        non_tensor_args = []
-        is_tensor_arg = []
-        tensor_requires_grad = []
-        for i in range(orig_num_args):
-            arg = args[1 + i]
-            is_tensor = isinstance(arg, torch.Tensor)
-            if is_tensor:
-                t = arg.detach()
-                tensor_requires_grad.append(arg.requires_grad)
-                tensor_args.append(t)
-            else:
-                non_tensor_args.append(arg)
-            is_tensor_arg.append(is_tensor)
-        ctx.is_tensor_arg = is_tensor_arg  # list of bool
-        ctx.non_tensor_args = non_tensor_args
-        ctx.tensor_requires_grad = tensor_requires_grad
-        ctx.save_for_backward(*tensor_args)
-
-        # m is a module, function or lambda.
-        m = orig_args[0]
-        # call m with the remaining elements of orig_args
-        ans = m(*orig_args[1:])
-
-        return ans
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx, *grads):
-        with torch.enable_grad():
-            tensor_args = ctx.saved_tensors
-            tensor_requires_grad = ctx.tensor_requires_grad
-            non_tensor_args = ctx.non_tensor_args
-            is_tensor_arg = ctx.is_tensor_arg
-            args = []
-            tensor_idx = 0
-            non_tensor_idx = 0
-            for b in ctx.is_tensor_arg:
-                if b:
-                    t = tensor_args[tensor_idx]
-                    t.requires_grad = tensor_requires_grad[tensor_idx]
-                    args.append(t)
-                    tensor_idx += 1
-                else:
-                    args.append(non_tensor_args[non_tensor_idx])
-                    non_tensor_idx += 1
-            m = args[0]
-            # ans should be the same as the original ans.
-            ans = m(*args[1:])
-            if isinstance(ans, Tensor):
-                ans = [ans]
-            # keep only the tensors from ans.
-            filtered_grads = []
-            filtered_ans = []
-            assert len(ans) == len(grads)
-            for i, a in enumerate(ans):
-                if isinstance(a, Tensor):
-                    filtered_ans.append(a)
-                    filtered_grads.append(grads[i])
-                else:
-                    assert grads[i] is None
-
-            torch.autograd.backward(filtered_ans, filtered_grads)
-
-        returned_grads = [ a.grad if isinstance(a, Tensor) else None for a in args ]
-
-        return tuple([None] + returned_grads + [None] * ctx.num_dummy_args)
-
-
-def caching_eval(*args):
-    """
-    A memory-efficient way to evaluate an nn.Module (or function or lambda), that
-    recomputes the forward pass during the backward pass so we don't have to
-    store intermediate quantities in the graph.
-
-    Example:
-       m = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 10))
-       y = caching_eval(m, x)
-    This function will treat the first arg as a function and give it the
-    remaining args.  If m is a lambda, you should not capture
-    any nn.Module or Tensor arguments in the lambda; instead you should make
-    them arguments to the lambda.
-
-    m must return a single item or a tuple of items; the items may be Tensors
-    or other types (Tensors will be treated specially).
-
-    This function returns a single element (probably a Tensor) if m returned
-    a single element; otherwise it returns a tuple.
-    """
-    dummy_args = []
-    for arg in args:
-        # these dummy_args, the list of parameters, are not going to be
-        # used directly and no grad will be returned for them; the purpose of
-        # adding them is to make PyTorch think that a grad might be returned,
-        # so it doesn't assign a zero grad if the training loop does backward()
-        # with find_unused_parameters=True.
-        if isinstance(arg, nn.Module):
-            dummy_args = dummy_args + list(arg.parameters())
-    orig_num_args = len(args)
-
-    # args we give to the function:
-    function_args = [ orig_num_args ] + list(args) + dummy_args
-    # This function returns a single element (probably a Tensor) or a tuple;
-    # it returns whatever m returned.
-    return CachingEvalFunction.apply(*function_args)
-
-
-
-class RandomGradFunction(torch.autograd.Function):
-    """
-    Does nothing in forward pass; in backward pass, gets rid of very small grads using
-    randomized approach that preserves expectations (intended to reduce roundoff).
-    """
-    @staticmethod
-    def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
-        ctx.min_abs = min_abs
-        return x
-
-    @staticmethod
-    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
-        if ans_grad.dtype == torch.float16:
-            return random_cast_to_half(ans_grad.to(torch.float32),
-                                       min_abs=ctx.min_abs), None
-        else:
-            return ans_grad, None
-
-class RandomGrad(torch.nn.Module):
-    """
-    Gets rid of very small gradients using an expectation-preserving method, intended to increase
-    accuracy of training when using amp (automatic mixed precision)
-    """
-    def __init__(self,
-                 min_abs: float = 5.0e-06):
-        super(RandomGrad, self).__init__()
-        self.min_abs = min_abs
-
-    def forward(self,
-                x: Tensor):
-        if torch.jit.is_scripting() or not self.training:
-            return x
-        else:
-            return RandomGradFunction.apply(x, self.min_abs)
-
-
 class SoftmaxFunction(torch.autograd.Function):
     """
     Tries to handle half-precision derivatives in a randomized way that should
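[Editor's note] The caching machinery removed in the hunk above (CachingEvalFunction / caching_eval) trades compute for memory by re-running the wrapped module's forward pass inside backward instead of storing intermediate activations. A minimal standalone sketch of the same idea using PyTorch's built-in torch.utils.checkpoint; this is an illustration only, not part of the patch, and the module and shapes are made up:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    m = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 10))
    x = torch.randn(32, 10, requires_grad=True)

    # Equivalent in spirit to `y = caching_eval(m, x)`: activations inside `m`
    # are not kept in the graph; they are recomputed when backward() runs.
    # (The use_reentrant kwarg exists in newer PyTorch; drop it on older versions.)
    y = checkpoint(m, x, use_reentrant=False)
    y.sum().backward()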
@@ -1164,110 +889,6 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module):
 
 
-class ActivationBalancer(torch.nn.Module):
-    """
-    Modifies the backpropped derivatives of a function to try to encourage, for
-    each channel, that it is positive at least a proportion `threshold` of the
-    time.  It does this by multiplying negative derivative values by up to
-    (1+max_factor), and positive derivative values by up to (1-max_factor),
-    interpolated from 1 at the threshold to those extremal values when none
-    of the inputs are positive.
-
-    Args:
-           num_channels: the number of channels
-           channel_dim: the dimension/axis corresponding to the channel, e.g.
-               -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
-           min_positive: the minimum, per channel, of the proportion of the time
-               that (x > 0), below which we start to modify the derivatives.
-           max_positive: the maximum, per channel, of the proportion of the time
-               that (x > 0), above which we start to modify the derivatives.
-           max_factor: the maximum factor by which we modify the derivatives for
-              either the sign constraint or the magnitude constraint;
-              e.g. with max_factor=0.02, the derivatives would be multiplied by
-              values in the range [0.98..1.02].
-           sign_gain_factor: determines the 'gain' with which we increase the
-              change in gradient once the constraints on min_positive and max_positive
-              are violated.
-           scale_gain_factor: determines the 'gain' with which we increase the
-              change in gradient once the constraints on min_abs and max_abs
-              are violated.
-           min_abs: the minimum average-absolute-value difference from the mean
-               value per channel, which we allow, before we start to modify
-               the derivatives to prevent this.
-           max_abs: the maximum average-absolute-value difference from the mean
-               value per channel, which we allow, before we start to modify
-               the derivatives to prevent this.
-           prob: determines the minimum probability with which we modify the
-              gradients for the {min,max}_positive and {min,max}_abs constraints,
-              on each forward().  This is done randomly to prevent all layers
-              from doing it at the same time.
-    """
-    def __init__(
-        self,
-        num_channels: int,
-        channel_dim: int,
-        min_positive: FloatLike = 0.05,
-        max_positive: FloatLike = 0.95,
-        max_factor: FloatLike = 0.04,
-        sign_gain_factor: FloatLike = 0.04,
-        scale_gain_factor: FloatLike = 0.04,
-        min_abs: FloatLike = 0.2,
-        max_abs: FloatLike = 100.0,
-        prob: Optional[FloatLike] = None,
-    ):
-        super(ActivationBalancer, self).__init__()
-
-        if prob is None:
-            prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
-        self.prob = prob
-        # 5% of the time we will return and do nothing because memory usage is
-        # too high.
-        self.mem_cutoff = CutoffEstimator(0.05)
-
-        # actually self.num_channels is no longer needed except for an assertion.
-        self.num_channels = num_channels
-        self.channel_dim = channel_dim
-        self.min_positive = min_positive
-        self.max_positive = max_positive
-        self.max_factor = max_factor
-        self.min_abs = min_abs
-        self.max_abs = max_abs
-        self.sign_gain_factor = sign_gain_factor
-        self.scale_gain_factor = scale_gain_factor
-
-    def forward(self, x: Tensor) -> Tensor:
-        if (torch.jit.is_scripting() or not x.requires_grad or
-            (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
-            return _no_op(x)
-
-        prob = float(self.prob)
-
-        if random.random() < prob:
-            assert x.shape[self.channel_dim] == self.num_channels
-            sign_gain_factor = 0.5
-            if float(self.min_positive) != 0.0 or float(self.max_positive) != 1.0:
-                sign_factor = _compute_sign_factor(x.detach(), self.channel_dim,
-                                                   float(self.min_positive),
-                                                   float(self.max_positive),
-                                                   gain_factor=float(self.sign_gain_factor) / prob,
-                                                   max_factor=float(self.max_factor))
-            else:
-                sign_factor = None
-
-            scale_factor, mean = _compute_scale_factor(x.detach(), self.channel_dim,
-                                                       min_abs=float(self.min_abs),
-                                                       max_abs=float(self.max_abs),
-                                                       gain_factor=float(self.scale_gain_factor) / prob,
-                                                       max_factor=float(self.max_factor))
-            return ActivationBalancerFunction.apply(
-                x, scale_factor, mean, sign_factor, self.channel_dim,
-            )
-        else:
-            return _no_op(x)
-
 
 class BalancerFunction(torch.autograd.Function):
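[Editor's note] The ActivationBalancer removed in the hunk above never changes the forward output; it only nudges gradients in the backward pass. A toy, self-contained illustration of the gradient update performed by ActivationBalancerFunction.backward; all numbers are made up and the per-channel factors would normally come from _compute_sign_factor and _compute_scale_factor:

    import torch

    # One channel, four frames; pretend the balancer decided on these factors.
    x_grad = torch.tensor([0.5, -1.0, 2.0, -0.2])
    xgtmean = torch.tensor([1.0, 0.0, 1.0, 0.0])   # (x > mean), as float
    sign_factor = torch.tensor(0.01)               # hypothetical sign-constraint penalty
    scale_factor = torch.tensor(-0.02)             # hypothetical scale-constraint penalty

    factor = sign_factor + scale_factor * (xgtmean - 0.5)
    new_grad = x_grad - x_grad.abs() * factor
    # A positive gradient is multiplied by roughly (1 - factor),
    # a negative one by roughly (1 + factor), matching the docstring above.
    print(new_grad)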
@@ -1715,145 +1336,6 @@ class Identity(torch.nn.Module):
     def forward(self, x):
         return _no_op(x)
 
-class MaxEig(torch.nn.Module):
-    """
-    Modifies the backpropped derivatives of a function to try to discourage
-    that any given direction in activation space accounts for more than
-    a specified proportion of the covariance (e.g. 0.2).
-
-    Args:
-           num_channels: the number of channels
-           channel_dim: the dimension/axis corresponding to the channel, e.g.
-               -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
-           max_var_per_eig: the maximum proportion of the variance of the
-               features/channels, after mean subtraction, that can come from
-               any given eigenvalue.
-           min_prob: the minimum probability with which we apply this during any invocation
-               of forward(), assuming last time we applied the constraint it was
-               not active; supplied for speed.
-           scale: determines the scale with which we modify the gradients, relative
-               to the existing / unmodified gradients
-    """
-    def __init__(
-        self,
-        num_channels: int,
-        channel_dim: int,
-        max_var_per_eig: float = 0.2,
-        min_prob: float = 0.01,
-        scale: float = 0.01,
-    ):
-        super(MaxEig, self).__init__()
-        self.num_channels = num_channels
-        self.channel_dim = channel_dim
-        self.scale = scale
-        assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
-        self.max_var_per_eig = max_var_per_eig
-
-        # we figure out the dominant direction using the power method: starting with
-        # a random vector, keep multiplying by the covariance and renormalizing.
-        with torch.no_grad():
-            # arbitrary.. would use randn() but want to leave the rest of the model's
-            # random parameters unchanged for comparison
-            direction = torch.arange(num_channels).to(torch.float)
-            direction = direction / direction.norm()
-            self.register_buffer('max_eig_direction', direction)
-
-        self.min_prob = min_prob
-        # cur_prob is the current probability with which we'll apply the constraint.
-        # We'll regress this towards min_prob, each time we try to apply it and it is not
-        # active.
-        self.cur_prob = 1.0
-
-    def forward(self, x: Tensor) -> Tensor:
-        if (torch.jit.is_scripting() or
-            self.max_var_per_eig <= 0 or
-            random.random() > self.cur_prob):
-            return _no_op(x)
-
-        with torch.cuda.amp.autocast(enabled=False):
-            eps = 1.0e-20
-            orig_x = x
-            x = x.to(torch.float32)
-            with torch.no_grad():
-                x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels)
-                x = x - x.mean(dim=0)
-                new_direction, coeffs = self._find_direction_coeffs(x, self.max_eig_direction)
-                x_var = (x**2).mean()
-                x_residual = x - coeffs * new_direction
-                x_residual_var = (x_residual**2).mean()
-
-                # `variance_proportion` is the proportion of the variance accounted for
-                # by the top eigen-direction.
-                variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
-
-                # ensure the new direction is nonzero even if x == 0, by including `direction`.
-                self._set_direction(0.1 * self.max_eig_direction + new_direction)
-
-            if random.random() < 0.01 or __name__ == "__main__":
-                logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}")
-
-            if variance_proportion >= self.max_var_per_eig:
-                # The constraint is active.  Note, we should quite rarely reach here:
-                # only near the beginning of training, if we are starting to diverge,
-                # should this constraint be active.
-                cur_prob = self.cur_prob
-                self.cur_prob = 1.0  # next time, do the update with probability 1.0.
-                return MaxEigLimiterFunction.apply(orig_x, coeffs, new_direction,
-                                                   self.channel_dim, self.scale)
-            else:
-                # let self.cur_prob exponentially approach self.min_prob, as
-                # long as the constraint is inactive.
-                self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
-                return orig_x
-
-    def _set_direction(self,
-                       direction: Tensor):
-        """
-        Sets self.max_eig_direction to a normalized version of `direction`
-        """
-        direction = direction.detach()
-        direction = direction / direction.norm()
-        direction_sum = direction.sum().item()
-        if direction_sum - direction_sum == 0:  # no inf/nan
-            self.max_eig_direction[:] = direction
-        else:
-            logging.info(f"Warning: sum of direction in MaxEig is {direction_sum}, "
-                         f"num_channels={self.num_channels}, channel_dim={self.channel_dim}")
-
-    def _find_direction_coeffs(self,
-                               x: Tensor,
-                               prev_direction: Tensor) -> Tuple[Tensor, Tensor]:
-        """
-        Figure out (an approximation to) the proportion of the variance of a set of
-        feature vectors that can be attributed to the top eigen-direction.
-        Args:
-            x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
-            prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
-                of the top eigen-direction, or a random direction if this is the first
-                iteration.  Does not have to be normalized, but should be nonzero.
-
-        Returns: (cur_direction, coeffs), where:
-            cur_direction: a Tensor of shape (num_channels,) that is the current
-                estimate of the top eigen-direction.
-            coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
-                approximately minimizes, (x - coeffs * cur_direction).norm()
-        """
-        (num_frames, num_channels) = x.shape
-        assert num_channels > 1 and num_frames > 1
-        assert prev_direction.shape == (num_channels,)
-        # `coeffs` are the coefficients of `prev_direction` in x; they actually
-        # represent the coeffs up to a constant positive factor.
-        coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
-        cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20)
-        return cur_direction, coeffs
-
-
 class DoubleSwishFunction(torch.autograd.Function):
     """
@@ -1928,67 +1410,6 @@
         return DoubleSwishFunction.apply(x)
 
 
-class TanSwishFunction(torch.autograd.Function):
-    """
-    tan_swish(x) = tanh(x) * torch.sigmoid(x-1)
-
-    Entering:  d/dx(tanh(x) * sigmoid(x-1))
-    into Wolfram Alpha, I see that the range of this derivative is
-          -0.0498087 <= y <= 0.417894
-    let's widen it slightly (as we don't know how this was rounded):
-          -0.0498088 <= y <= 0.417895
-    """
-
-    @staticmethod
-    def forward(ctx, x: Tensor) -> Tensor:
-        requires_grad = x.requires_grad
-        if not requires_grad:
-            return torch.tanh(x) * torch.sigmoid(x - 1.0)
-
-        x_dtype = x.dtype
-        if x.dtype == torch.float16:
-            x = x.to(torch.float32)
-
-        with torch.cuda.amp.autocast(enabled=False):
-            with torch.enable_grad():
-                x = x.detach()
-                x.requires_grad = True
-                y = torch.tanh(x) * torch.sigmoid(x - 1.0)
-                y.backward(gradient=torch.ones_like(y))
-                grad = x.grad
-            floor = -0.0498088
-            ceil = 0.417895
-            d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad))
-            if __name__ == "__main__":
-                # for self-testing only.
-                assert d_scaled.min() >= 0.0
-                assert d_scaled.max() < 256.0
-
-            d_int = d_scaled.to(torch.uint8)
-            ctx.save_for_backward(d_int)
-            if x_dtype == torch.float16 or torch.is_autocast_enabled():
-                y = y.to(torch.float16)
-            return y
-
-    @staticmethod
-    def backward(ctx, y_grad: Tensor) -> Tensor:
-        d, = ctx.saved_tensors
-        # the same constants as used in the forward pass.
-        floor = -0.0498088
-        ceil = 0.417895
-        d = (d * ((ceil - floor) / 255.0) + floor)
-        return (y_grad * d)
-
-
-class TanSwish(torch.nn.Module):
-    def forward(self, x: Tensor) -> Tensor:
-        """Return the tan-swish activation, which is tanh(x) * sigmoid(x-1).
-        """
-        if torch.jit.is_scripting():
-            return x.tanh() * torch.sigmoid(x - 1.0)
-        return TanSwishFunction.apply(x)
-
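[Editor's note] The removed TanSwishFunction (like the DoubleSwishFunction that stays) saves memory by storing the activation's derivative as uint8: the range [floor, ceil] is mapped onto 0..255 with random dithering so that the rounding is unbiased in expectation. A standalone sketch of that round trip, reusing the floor/ceil constants quoted above; this is an illustration, not code from the patch:

    import torch

    floor, ceil = -0.0498088, 0.417895   # derivative range quoted above

    def encode(d: torch.Tensor) -> torch.Tensor:
        # map [floor, ceil] to [0, 255]; the added uniform noise makes the
        # truncation to uint8 unbiased in expectation (random dithering).
        return ((d - floor) * (255.0 / (ceil - floor)) + torch.rand_like(d)).to(torch.uint8)

    def decode(d_int: torch.Tensor) -> torch.Tensor:
        return d_int * ((ceil - floor) / 255.0) + floor

    x = torch.randn(8, requires_grad=True)
    y = torch.tanh(x) * torch.sigmoid(x - 1.0)
    y.sum().backward()
    deriv = x.grad                       # true elementwise derivative

    # reconstruction error is at most about one quantization step, ~0.0018
    print((decode(encode(deriv)) - deriv).abs().max())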
 # Dropout2 is just like normal dropout, except it supports schedules on the dropout rates.
 class Dropout2(nn.Module):
@@ -2177,35 +1598,6 @@ def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
 
 
-def _test_max_eig():
-    for proportion in [0.1, 0.5, 10.0]:
-        logging.info(f"proportion = {proportion}")
-        x = torch.randn(100, 128)
-        direction = torch.randn(128)
-        coeffs = torch.randn(100, 1)
-        x += proportion * direction * coeffs
-
-        x.requires_grad = True
-
-        num_channels = 128
-        m = MaxEig(num_channels,
-                   1,    # channel_dim
-                   0.5,  # max_var_per_eig
-                   scale=0.1)  # grad_scale
-
-        for _ in range(4):
-            y = m(x)
-
-            y_grad = torch.randn_like(x)
-            y.backward(gradient=y_grad)
-
-            if proportion < 0.2:
-                assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
-            elif proportion > 1.0:
-                assert not torch.allclose(x.grad, y_grad)
-
-
 def _test_whiten():
     for proportion in [0.1, 0.5, 10.0]:
         logging.info(f"_test_whiten(): proportion = {proportion}")
@@ -2236,30 +1628,6 @@ def _test_whiten():
 
 
-def _test_activation_balancer_sign():
-    probs = torch.arange(0, 1, 0.01)
-    N = 1000
-    x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
-    x = x.detach()
-    x.requires_grad = True
-    m = ActivationBalancer(
-        probs.numel(),
-        channel_dim=0,
-        min_positive=0.05,
-        max_positive=0.95,
-        max_factor=0.2,
-        min_abs=0.0,
-        prob=1.0,
-    )
-
-    y_grad = torch.sign(torch.randn(probs.numel(), N))
-
-    y = m(x)
-    y.backward(gradient=y_grad)
-    print("_test_activation_balancer_sign: x = ", x)
-    print("_test_activation_balancer_sign: y grad = ", y_grad)
-    print("_test_activation_balancer_sign: x grad = ", x.grad)
-
 
 def _test_balancer_sign():
     probs = torch.arange(0, 1, 0.01)
@@ -2286,33 +1654,6 @@ def _test_balancer_sign():
 
 
-def _test_activation_balancer_magnitude():
-    magnitudes = torch.arange(0, 1, 0.01)
-    N = 1000
-    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
-        -1
-    )
-    x = x.detach()
-    x.requires_grad = True
-    m = ActivationBalancer(
-        magnitudes.numel(),
-        channel_dim=0,
-        min_positive=0.0,
-        max_positive=1.0,
-        max_factor=0.2,
-        min_abs=0.2,
-        max_abs=0.8,
-        prob=1.0,
-    )
-
-    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
-
-    y = m(x)
-    y.backward(gradient=y_grad)
-    print("_test_activation_balancer_magnitude: x = ", x)
-    print("_test_activation_balancer_magnitude: y grad = ", y_grad)
-    print("_test_activation_balancer_magnitude: x grad = ", x.grad)
-
 
 def _test_balancer_magnitude():
     magnitudes = torch.arange(0, 1, 0.01)
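[Editor's note] The removed _test_max_eig built its inputs by adding a rank-one component to Gaussian noise, so the share of variance explained by the top eigen-direction is controlled by `proportion`. A self-contained sketch of that construction plus a crude SVD-based check; the dimensions are copied from the test, but the check itself is not from the patch:

    import torch

    proportion = 0.5
    x = torch.randn(100, 128)                 # noise
    direction = torch.randn(128)
    coeffs = torch.randn(100, 1)
    x = x + proportion * direction * coeffs   # add a rank-one component

    # share of total variance carried by the dominant direction; grows with `proportion`
    _, s, _ = torch.linalg.svd(x - x.mean(dim=0), full_matrices=False)
    print((s[0] ** 2 / (s ** 2).sum()).item())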
@@ -2366,20 +1707,6 @@ def _test_double_swish_deriv():
     torch.autograd.gradcheck(m, x, atol=tol)
 
-
-    # for self-test.
-    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
-    x.requires_grad = True
-    y = m(x)
-
-def _test_tan_swish_deriv():
-    x = torch.randn(10, 12, dtype=torch.double) * 3.0
-    x.requires_grad = True
-    m = TanSwish()
-
-    tol = ((1.2-(-0.043637))/255.0)
-    torch.autograd.gradcheck(m, x, atol=tol)
-
-    # for self-test.
     x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
     x.requires_grad = True
@@ -2487,16 +1814,11 @@ if __name__ == "__main__":
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
     _test_piecewise_linear()
-    _test_caching_eval()
     _test_softmax()
     _test_whiten()
-    _test_max_eig()
-    _test_activation_balancer_sign()
     _test_balancer_sign()
-    _test_activation_balancer_magnitude()
     _test_balancer_magnitude()
     _test_basic_norm()
     _test_double_swish_deriv()
-    _test_tan_swish_deriv()
     _test_swooshr_deriv()
     _test_swooshl_deriv()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index c49c012ba..fe8ebb8b8 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -25,18 +25,14 @@ import torch
 import random
 from encoder_interface import EncoderInterface
 from scaling import (
-    ActivationBalancer,
     Balancer,
     BasicNorm,
     ConvNorm1d,
     ConvNorm2d,
     Dropout2,
     Dropout3,
-    MaxEig,
-    DoubleSwish,
     SwooshL,
     SwooshR,
-    TanSwish,
     ChunkCausalDepthwiseConv1d,
     ScaledConv1d,
     ScaledConv2d,
@@ -45,7 +41,6 @@ from scaling import (
     Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
     penalize_abs_values_gt,
     softmax,
-    caching_eval,
     ScheduledFloat,
     FloatLike,
     limit_param_value,
@@ -160,7 +155,6 @@ class Zipformer(EncoderInterface):
         self.num_features = num_features  # int
         self.output_downsampling_factor = output_downsampling_factor  # int
         self.downsampling_factor = downsampling_factor  # tuple
-        self.downsampling_factor_gcd = next(n for n in range(1, 10000) if all(n % d == 0 for d in downsampling_factor))
         self.encoder_dim = encoder_dim = _to_tuple(encoder_dim)  # tuple
         self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim)  # tuple
         num_encoder_layers = _to_tuple(num_encoder_layers)
@@ -1506,105 +1500,6 @@ class SelfAttention(nn.Module):
         return x
 
-class AttentionSqueeze(nn.Module):
-    """
-    A modified version of Squeeze-and-Excite, where the nonlinearity happens in the full dim and
-    we just project to a small bottleneck dimension.
-    """
-    def __init__(self,
-                 embed_dim: int,
-                 hidden_dim: int,
-                 bottleneck_dim: int = 16):
-        super().__init__()
-
-        self.lr_scale = 0.9
-
-        self.bottleneck_dim = bottleneck_dim
-
-        self.in_proj = nn.Linear(embed_dim, hidden_dim,
-                                 bias=False)
-
-        self.to_bottleneck_proj = nn.Linear(embed_dim,
-                                            bottleneck_dim)
-
-        # bottleneck_balancer is before the activation.  Mostly, for well-trained
-        # instances of this module, the mean absolute values per channel are in
-        # the range 0.1 to 0.4.  We apply the upper limit of 0.4 at the
-        # beginning, and make it looser over time.
-        self.bottleneck_balancer = Balancer(
-            bottleneck_dim, channel_dim=-1,
-            min_positive=0.2, max_positive=0.8,
-            min_abs=0.05,
-            max_abs=ScheduledFloat((0.0, 0.5), (4000.0, 1.0), default=1.0),
-        )
-        self.bottleneck_activation = TanSwish()  # in bottleneck
-        self.activation = Identity()  # for diagnostics
-
-        # the reason for the min_abs and max_abs limits on the next two
-        # balancers is only to stop parameter-magnitude 'drift': we have too
-        # many degrees of freedom for the scales of the various activations.
-        # Make them run with very low probability, since only a small
-        # application of these balancers should be enough to stop such "drift".
-        self.scale_balancer = Balancer(
-            hidden_dim, channel_dim=-1,
-            min_positive=0.2, max_positive=0.8,
-            min_abs=0.2, max_abs=1.0,
-            prob=_balancer_schedule(0.05),
-        )
-        self.activation_balancer = Balancer(
-            hidden_dim, channel_dim=-1,
-            min_positive=0.2, max_positive=0.8,
-            min_abs=0.2, max_abs=1.0,
-            prob=_balancer_schedule(0.05),
-        )
-        self.activation_whiten = Whiten(num_groups=1,
-                                        whitening_limit=_whitening_schedule(4.0, ratio=3.0),
-                                        prob=(0.025, 0.25),
-                                        grad_scale=0.01)
-
-        self.from_bottleneck_proj = nn.Linear(bottleneck_dim, hidden_dim)
-
-        self.out_proj = ScaledLinear(hidden_dim, embed_dim,
-                                     bias=False, initial_scale=0.05)
-
-    def forward(self,
-                x: Tensor,
-                attn_weights: Tensor):
-        """
-        Args:
-            x: a Tensor of shape (seq_len, batch_size, num_channels)
-            attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
-        Returns:
-            a Tensor with the same shape as x
-        """
-        num_heads = attn_weights.shape[0]
-        bottleneck = self.to_bottleneck_proj(x)  # (seq_len, batch_size, bottleneck_dim)
-        (seq_len, batch_size, bottleneck_dim) = bottleneck.shape
-        head_dim = bottleneck_dim // num_heads
-        bottleneck = bottleneck.reshape(seq_len, batch_size, num_heads, head_dim).permute(
-            2, 1, 0, 3)  # (num_heads, batch_size, seq_len, head_dim)
-
-        # (num_heads, batch_size, seq_len, seq_len) x (num_heads, batch_size, seq_len, head_dim)
-        # -> (num_heads, batch_size, seq_len, head_dim)
-        bottleneck = torch.matmul(attn_weights, bottleneck)
-
-        bottleneck = bottleneck.permute(2, 1, 0, 3)  # (seq_len, batch_size, num_heads, head_dim)
-        bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim)
-
-        bottleneck = self.bottleneck_balancer(bottleneck)
-        bottleneck = self.bottleneck_activation(bottleneck)
-        scales = self.from_bottleneck_proj(bottleneck)
-
-        x = self.in_proj(x)
-        x = self.activation_balancer(x)
-        x = self.activation_whiten(x)
-        scales = self.scale_balancer(scales)
-        x = x * scales
-        x = self.activation(x)  # Identity only.  For diagnostics.
-        x = self.out_proj(x)
-        return x
-
 class FeedforwardModule(nn.Module):
     """Feedforward module in Zipformer model.
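[Editor's note] The AttentionSqueeze module removed above is a squeeze-and-excite variant in which the "squeeze" is an attention-weighted pooling over time, reusing the attention weights shared with the self-attention module, and the excitation scales come from a small bottleneck. A shape-level sketch of the bottleneck path; all dimensions are illustrative and the projections/balancers are replaced by random tensors, so this is not code from the patch:

    import torch

    seq_len, batch, num_heads, bottleneck_dim = 100, 8, 4, 16
    head_dim = bottleneck_dim // num_heads

    attn_weights = torch.softmax(torch.randn(num_heads, batch, seq_len, seq_len), dim=-1)
    bottleneck = torch.randn(seq_len, batch, bottleneck_dim)   # stands in for to_bottleneck_proj(x)

    b = bottleneck.reshape(seq_len, batch, num_heads, head_dim).permute(2, 1, 0, 3)
    b = torch.matmul(attn_weights, b)                          # attention-weighted pooling over time
    b = b.permute(2, 1, 0, 3).reshape(seq_len, batch, bottleneck_dim)
    # b then passes through a nonlinearity and from_bottleneck_proj to produce
    # per-channel scales that multiply the hidden representation of x.
    print(b.shape)  # torch.Size([100, 8, 16])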