diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 5d88aeccc..5af2402c0 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -363,60 +363,129 @@ class CutoffEstimator:
         return ans

+
 class CachingEvalFunction(torch.autograd.Function):
     # @custom_fwd and @custom_bwd relate to automatic mixed precision (amp) and ensure
     # that the backward pass runs with the same autocast context as the forward pass.
     @staticmethod
     @custom_fwd
-    def forward(ctx, x: Tensor, m) -> Tensor:
+    def forward(ctx, *args):
         """
         m might be an nn.Module
         """
-        ctx.x_requires_grad = x.requires_grad
-        ctx.m = m
-        # we need any random numbers used in this evaluation and the next evaluation to be identical.
-        # Caution: this assumes you are not going to use any random numbers from torch (for any purpose
-        # that matters in the forward pass), e.g. there should be no dropout.
-        ctx.random_state = random.getstate()
-        # we are inside torch.no_grad() here, so the following won't create the computation graph.
-        with torch.no_grad():
-            y = m(x)
-        ctx.save_for_backward(x, y)
-        return y
+        tot_num_args = len(args)
+        orig_num_args = args[0]
+        orig_args = args[1:1+orig_num_args]
+
+        ctx.num_dummy_args = tot_num_args - orig_num_args - 1
+
+        tensor_args = []
+        non_tensor_args = []
+        is_tensor_arg = []
+        tensor_requires_grad = []
+        for i in range(orig_num_args):
+            arg = args[1 + i]
+            is_tensor = isinstance(arg, torch.Tensor)
+            if is_tensor:
+                t = arg.detach()
+                tensor_requires_grad.append(arg.requires_grad)
+                tensor_args.append(t)
+
+            else:
+                non_tensor_args.append(arg)
+            is_tensor_arg.append(is_tensor)
+        ctx.is_tensor_arg = is_tensor_arg  # list of bool
+        ctx.non_tensor_args = non_tensor_args
+        ctx.tensor_requires_grad = tensor_requires_grad
+        ctx.save_for_backward(*tensor_args)
+
+        # m is a module, function or lambda.
+        m = orig_args[0]
+        # call m with the remaining elements of orig_args
+        ans = m(*orig_args[1:])
+
+        return ans

     @staticmethod
     @custom_bwd
-    def backward(ctx, y_grad: Tensor):
-        x, y = ctx.saved_tensors
-        x = x.detach()
-        x.requires_grad = ctx.x_requires_grad
-        m = ctx.m  # e.g. a nn.Module
-
-        random_state = random.getstate()
-        # set the state to what we used in the 1st forward pass.
-        random.setstate(ctx.random_state)
+    def backward(ctx, *grads):
         with torch.enable_grad():
-            y2 = m(x)
-            assert torch.allclose(y, y2, atol=1.0e-02)
-            # this call to backward() should create grads in the module's parameters
-            y2.backward(gradient=y_grad)
+            tensor_args = ctx.saved_tensors
+            tensor_requires_grad = ctx.tensor_requires_grad
+            non_tensor_args = ctx.non_tensor_args
+            is_tensor_arg = ctx.is_tensor_arg
+            args = []
+            tensor_idx = 0
+            non_tensor_idx = 0
+            for b in ctx.is_tensor_arg:
+                if b:
+                    t = tensor_args[tensor_idx]
+                    t.requires_grad = tensor_requires_grad[tensor_idx]
+                    args.append(t)
+                    tensor_idx += 1
+                else:
+                    args.append(non_tensor_args[non_tensor_idx])
+                    non_tensor_idx += 1
+            m = args[0]
+            # ans should be the same as the original ans.
+            ans = m(*args[1:])
+            if isinstance(ans, Tensor):
+                ans = [ans]
+            # keep only the tensors from ans.
+            filtered_grads = []
+            filtered_ans = []
+            assert len(ans) == len(grads)
+            for i, a in enumerate(ans):
+                if isinstance(a, Tensor):
+                    filtered_ans.append(a)
+                    filtered_grads.append(grads[i])
+                else:
+                    assert grads[i] is None

-        # restore the state from before we entered this function
-        random.setstate(random_state)
+            torch.autograd.backward(filtered_ans, filtered_grads)

-        return x.grad, None  # x.grad will be None if x.requires_grad is False.
+        returned_grads = [ a.grad if isinstance(a, Tensor) else None for a in args ]
+
+        return tuple([None] + returned_grads + [None] * ctx.num_dummy_args)


-def caching_eval(x: Tensor, m: nn.Module) -> Tensor:
-    if m.training:
-        # The purpose of this code is to make all parameters of m reachable in
-        # the computation graph, so that if we give find_unused_parameters=True
-        # to PyTorch's autograd code it does not assign them zero gradient.
-        tot = 0.0
-        for p in m.parameters():
-            tot = tot + 0.0 * p.flatten()[0]
-        x = x + tot  # tot will be 0, this does nothing.
-    return CachingEvalFunction.apply(x, m)
+def caching_eval(*args):
+    """
+    A memory-efficient way to evaluate an nn.Module (or function or lambda) that
+    recomputes the forward pass during the backward pass, so we don't have to
+    store intermediate quantities in the graph.
+
+    Example:
+      m = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 10))
+      y = caching_eval(m, x)
+    This function will treat the first arg as a function and give it the
+    remaining args.  If m is a lambda, you should not capture
+    any nn.Module or Tensor arguments in the lambda; instead you should make
+    them arguments to the lambda.
+
+    m must return a single item or a tuple of items; the items may be Tensors
+    or other types (Tensors will be treated specially).
+
+    This function returns a single element (probably a Tensor) if m returned
+    a single element; otherwise it returns a tuple.
+    """
+    dummy_args = []
+    for arg in args:
+        # these dummy_args, the list of parameters, are not going to be
+        # used directly and no grad will be returned for them; the purpose of
+        # adding them is to make PyTorch think that a grad might be returned,
+        # so it doesn't assign them a zero grad if the model is wrapped in DDP
+        # with find_unused_parameters=True.
+        if isinstance(arg, nn.Module):
+            dummy_args = dummy_args + list(arg.parameters())
+    orig_num_args = len(args)
+
+    # args we give to the autograd function: the original number of args,
+    # then the original args themselves, then the dummy parameter args.
+    function_args = [ orig_num_args ] + list(args) + dummy_args
+    # CachingEvalFunction.apply returns whatever m returned: a single
+    # element or a tuple.
+    return CachingEvalFunction.apply(*function_args)
+

 class RandomGradFunction(torch.autograd.Function):
@@ -2181,6 +2250,32 @@ def _test_softmax():
     assert torch.allclose(a.grad, b.grad)


+def _test_caching_eval():
+    m = nn.Sequential(nn.Linear(10, 100, bias=False), nn.ReLU(), nn.Linear(100, 10, bias=False))
+
+    x = torch.randn(50, 10)
+    y_grad = torch.randn(50, 10)
+    x.requires_grad = True
+
+    y1 = m(x)
+    y1.backward(gradient=y_grad)
+    x_grad1 = x.grad
+    x.grad = None
+    weight_grad1a = m[0].weight.grad
+    weight_grad1b = m[2].weight.grad
+    m[0].weight.grad = None
+    m[2].weight.grad = None
+    m.zero_grad()
+
+    y2 = caching_eval(m, x)
+    assert torch.allclose(y1, y2)
+    y2.backward(gradient=y_grad)
+    assert torch.allclose(x.grad, x_grad1)
+    assert torch.allclose(m[0].weight.grad, weight_grad1a)
+    assert torch.allclose(m[2].weight.grad, weight_grad1b)
+
+
 def _test_piecewise_linear():
     p = PiecewiseLinear( (0, 10.0) )

@@ -2217,6 +2312,7 @@ if __name__ == "__main__":
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
     _test_piecewise_linear()
+    _test_caching_eval()
     _test_softmax()
     _test_whiten()
     _test_max_eig()
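
Usage sketch (not part of the patch itself): a minimal illustration of how the
caching_eval() helper added above is meant to be called. It assumes scaling.py
from this recipe directory is importable; the toy module and shapes mirror the
docstring's example and the _test_caching_eval() test.

    import torch
    import torch.nn as nn

    from scaling import caching_eval  # helper added in this patch

    m = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 10))
    x = torch.randn(50, 10, requires_grad=True)

    # The forward pass runs inside a custom autograd.Function, so no graph or
    # intermediate activations of m are stored; the backward pass re-runs m(x)
    # with grad enabled and routes the incoming gradient through
    # torch.autograd.backward(), which also populates the .grad of m's parameters.
    y = caching_eval(m, x)
    y.sum().backward()

    assert x.grad is not None
    assert m[0].weight.grad is not None

This trades extra compute for memory, similar in spirit to
torch.utils.checkpoint.checkpoint.
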
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 28ecd718d..6f6e9137d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -100,8 +100,9 @@ def get_adjusted_batch_count(
         params: AttributeDict) -> float:
     # returns the number of batches we would have used so far if we had used the reference
     # duration.  This is for purposes of set_batch_count().
-    return (params.batch_idx_train * params.ref_duration /
-            (params.max_duration * params.world_size))
+    return (params.batch_idx_train * (params.max_duration * params.world_size) /
+            params.ref_duration)
+

@@ -122,7 +123,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
         type=str,
-        default="4,4,4,4,4,4",
+        default="4,4,6,4",
         help="Number of zipformer encoder layers per stack, comma separated.",
     )

@@ -130,7 +131,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--downsampling-factor",
         type=str,
-        default="1,2,4,8,4,2",
+        default="1,2,4,2",
         help="Downsampling factor for each stack of encoder layers.",
     )

@@ -138,14 +139,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--feedforward-dim",
         type=str,
-        default="1536,1536,1536,1536,1536,1536",
+        default="1536,1536,1536,1536",
        help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
     )

     parser.add_argument(
         "--num-heads",
         type=str,
-        default="8,8,8,16,8,8",
+        default="8,8,8,8",
         help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
     )
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index aca34b568..4cfd64595 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1862,8 +1862,22 @@ class ConvNeXt(nn.Module):
             prob=(0.025, 0.25),
             grad_scale=0.01)

     def forward(self, x: Tensor) -> Tensor:
+        if torch.jit.is_scripting() or not self.training:
+            return self.forward_internal(x)
+        layerdrop_rate = float(self.layerdrop_rate)
+
+        if layerdrop_rate != 0.0:
+            batch_size = x.shape[0]
+            mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
+        else:
+            mask = None
+        return caching_eval(self.forward_internal, x, mask)
+
+
+    def forward_internal(self,
+                         x: Tensor,
+                         layer_skip_mask: Optional[Tensor] = None) -> Tensor:
         """
         x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
@@ -1876,11 +1890,8 @@ class ConvNeXt(nn.Module):
         x = self.activation(x)
         x = self.pointwise_conv2(x)

-        layerdrop_rate = float(self.layerdrop_rate)
-        if not torch.jit.is_scripting() and self.training and layerdrop_rate != 0.0:
-            batch_size = x.shape[0]
-            mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
-            x = x * mask
+        if layer_skip_mask is not None:
+            x = x * layer_skip_mask

         x = bypass + x
         x = self.out_balancer(x)
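
A small sketch (not part of the patch) of the pattern the ConvNeXt.forward change
follows: because caching_eval() re-runs the wrapped function during the backward
pass, anything random (here a layerdrop-style mask) is drawn once in forward()
and passed in as an ordinary tensor argument, so the recomputed call is
deterministic. The module, shapes and threshold below are illustrative only and
are not taken from the recipe.

    import torch
    import torch.nn as nn

    from scaling import caching_eval  # helper added in this patch

    conv = nn.Conv2d(8, 8, kernel_size=3, padding=1)

    def body(module, x, skip_mask):
        # Deterministic given its inputs, so it is safe to re-run in backward.
        y = module(x)
        if skip_mask is not None:
            y = y * skip_mask
        return x + y

    x = torch.randn(4, 8, 10, 10, requires_grad=True)
    # Randomness is sampled outside the recomputed function and passed in,
    # mirroring how forward() now builds the mask and hands it to forward_internal().
    skip_mask = (torch.rand(4, 1, 1, 1) > 0.1).to(x.dtype)

    out = caching_eval(body, conv, x, skip_mask)
    out.sum().backward()

Unlike the previous CachingEvalFunction, the rewritten one no longer saves and
restores the Python random state, so hoisting any randomness out of the
recomputed function in this way is what keeps the two evaluations consistent.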