From b3527fe4ac26d3ba6dfdb7207f88ef55d6179277 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 7 Jan 2023 17:31:20 +0800 Subject: [PATCH] Implement caching evaluation for ConvNeXt --- .../pruned_transducer_stateless7/scaling.py | 172 ++++++++++++++---- .../pruned_transducer_stateless7/zipformer.py | 23 ++- 2 files changed, 151 insertions(+), 44 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index ef998eb4a..82d166c50 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -264,60 +264,129 @@ class CutoffEstimator: return ans + class CachingEvalFunction(torch.autograd.Function): # @custom_fwd and @custom_bwd related to automatic mixed precision (amp) an ensure # that the backward path runs with the same autocast context as the forward pass. @staticmethod @custom_fwd - def forward(ctx, x: Tensor, m) -> Tensor: + def forward(ctx, *args): """ m might be an nn.Module """ - ctx.x_requires_grad = x.requires_grad - ctx.m = m - # we need any random numbers used in this evaluation and the next evaluation to be identical. - # Caution: this assumes you are not going to use any random numbers from torch (for any purpose - # that matters in the forward pass), e.g. there should be no dropout. - ctx.random_state = random.getstate() - # we are inside torch.no_grad() here, so the following won't create the computation graph. - with torch.no_grad(): - y = m(x) - ctx.save_for_backward(x, y) - return y + tot_num_args = len(args) + orig_num_args = args[0] + orig_args = args[1:1+orig_num_args] + + ctx.num_dummy_args = tot_num_args - orig_num_args - 1 + + tensor_args = [] + non_tensor_args = [] + is_tensor_arg = [] + tensor_requires_grad = [] + for i in range(orig_num_args): + arg = args[1 + i] + is_tensor = isinstance(arg, torch.Tensor) + if is_tensor: + t = arg.detach() + tensor_requires_grad.append(arg.requires_grad) + tensor_args.append(t) + + else: + non_tensor_args.append(arg) + is_tensor_arg.append(is_tensor) + ctx.is_tensor_arg = is_tensor_arg # list of bool + ctx.non_tensor_args = non_tensor_args + ctx.tensor_requires_grad = tensor_requires_grad + ctx.save_for_backward(*tensor_args) + + # m is module, function or lambda. + m = orig_args[0] + # call m with the remaining elements of orig_args + ans = m(*orig_args[1:]) + + return ans @staticmethod @custom_bwd - def backward(ctx, y_grad: Tensor): - x, y = ctx.saved_tensors - x = x.detach() - x.requires_grad = ctx.x_requires_grad - m = ctx.m # e.g. a nn.Module - - random_state = random.getstate() - # set the state to what we used in the 1st forward pass. - random.setstate(ctx.random_state) + def backward(ctx, *grads): with torch.enable_grad(): - y2 = m(x) - assert torch.allclose(y, y2, atol=1.0e-02) - # this call to backward() should create grads in the module's parameters - y2.backward(gradient=y_grad) + tensor_args = ctx.saved_tensors + tensor_requires_grad = ctx.tensor_requires_grad + non_tensor_args = ctx.non_tensor_args + is_tensor_arg = ctx.is_tensor_arg + args = [] + tensor_idx = 0 + non_tensor_idx = 0 + for b in ctx.is_tensor_arg: + if b: + t = tensor_args[tensor_idx] + t.requires_grad = tensor_requires_grad[tensor_idx] + args.append(t) + tensor_idx += 1 + else: + args.append(non_tensor_args[non_tensor_idx]) + non_tensor_idx += 1 + m = args[0] + # ans should the same as the original ans. + ans = m(*args[1:]) + if isinstance(ans, Tensor): + ans = [ans] + # keep only the tensors from ans. + filtered_grads = [] + filtered_ans = [] + assert len(ans) == len(grads) + for i, a in enumerate(ans): + if isinstance(a, Tensor): + filtered_ans.append(a) + filtered_grads.append(grads[i]) + else: + assert grads[i] is None - # restore the state from before we entered this function - random.setstate(random_state) + torch.autograd.backward(filtered_ans, filtered_grads) - return x.grad, None # x.grad will be None if x.requires_grad is False. + returned_grads = [ a.grad if isinstance(a, Tensor) else None for a in args ] + + return tuple([None] + returned_grads + [None] * ctx.num_dummy_args) -def caching_eval(x: Tensor, m: nn.Module) -> Tensor: - if m.training: - # The purpose of this code is to make all parameters of m reachable in - # the computation graph, so that if we give find_unused_parameters=True - # to PyTorch's autograd code it does not assign them zero gradient. - tot = 0.0 - for p in m.parameters(): - tot = tot + 0.0 * p.flatten()[0] - x = x + tot # tot will be 0, this does nothing. - return CachingEvalFunction.apply(x, m) +def caching_eval(*args): + """ + A memory-efficient way to evaluate a nn.Module (or function or lambda), that + recomputes the forward pass during the backward pass so we don't have to + store intermediate quantities in the graph. + + Example: + m = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 10)) + y = caching_eval(m, x) + This function will treat the first arg as a function and give it the + remaining args. If m is a lambda, you should not capture + any nn.Module or Tensor arguments in the lambda; instead you should make + them arguments to the lambda. + + m must return a single item or a tuple of items; the items may be Tensors + or other types (Tensor will be treated specially). + + This function returns a single element (probably a Tensor) if m returned + a single element; otherwise it returns a tuple. + """ + dummy_args = [] + for arg in args: + # these dummy_args, the list of parameters, are not going to be + # used directly and no grad will be returned for them; the purpose of + # adding them is to make PyTorch think that a grad might be returned, + # so it doesn't assign a zero grad if the training loop does backward() + # with find_unused_args=True. + if isinstance(arg, nn.Module): + dummy_args = dummy_args + list(arg.parameters()) + orig_num_args = len(args) + + # args we give to the function: n + function_args = [ orig_num_args ] + list(args) + dummy_args + # This function returns a single element (probably a Tensor) or a tuple; + # it returns whatever + return CachingEvalFunction.apply(*function_args) + class RandomGradFunction(torch.autograd.Function): @@ -2082,11 +2151,38 @@ def _test_softmax(): assert torch.allclose(a.grad, b.grad) +def _test_caching_eval(): + m = nn.Sequential(nn.Linear(10, 100, bias=False), nn.ReLU(), nn.Linear(100, 10, bias=False)) + + x = torch.randn(50, 10) + y_grad = torch.randn(50, 10) + x.requires_grad = True + + y1 = m(x) + y1.backward(gradient=y_grad) + x_grad1 = x.grad + x.grad = None + weight_grad1a = m[0].weight.grad + weight_grad1b = m[2].weight.grad + m[0].weight.grad = None + m[2].weight.grad = None + m.zero_grad() + + y2 = caching_eval(m, x) + assert torch.allclose(y1, y2) + y2.backward(gradient=y_grad) + assert torch.allclose(x.grad, x_grad1) + assert torch.allclose(m[0].weight.grad, weight_grad1a) + assert torch.allclose(m[2].weight.grad, weight_grad1b) + + + if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) torch.set_num_interop_threads(1) + _test_caching_eval() _test_softmax() _test_whiten() _test_max_eig() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index d37910b12..6945f5151 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1853,8 +1853,22 @@ class ConvNeXt(nn.Module): prob=(0.025, 0.25), grad_scale=0.01) - def forward(self, x: Tensor) -> Tensor: + if torch.jit.is_scripting() or not self.training: + return self.forward_internal(x) + layerdrop_rate = float(self.layerdrop_rate) + + if layerdrop_rate != 0.0: + batch_size = x.shape[0] + mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate + else: + mask = None + return caching_eval(self.forward_internal, x, mask) + + + def forward_internal(self, + x: Tensor, + layer_skip_mask: Optional[Tensor] = None) -> Tensor: """ x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs) @@ -1867,11 +1881,8 @@ class ConvNeXt(nn.Module): x = self.activation(x) x = self.pointwise_conv2(x) - layerdrop_rate = float(self.layerdrop_rate) - if not torch.jit.is_scripting() and self.training and layerdrop_rate != 0.0: - batch_size = x.shape[0] - mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate - x = x * mask + if layer_skip_mask is not None: + x = x * layer_skip_mask x = bypass + x x = self.out_balancer(x)