From d26ee2bf81606210381d10b261608c19f698aaa6 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 15 Dec 2022 23:06:40 +0800
Subject: [PATCH] Try to implement caching evaluation for memory efficient training

---
 .../pruned_transducer_stateless7/scaling.py   | 49 +++++++++++++++++++-
 .../pruned_transducer_stateless7/zipformer.py |  8 ++-
 2 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index f396b5c65..4febd2034 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -20,7 +20,7 @@ from itertools import repeat
 from typing import Optional, Tuple, Union
 from functools import reduce
 import logging
-
+from torch.cuda.amp import custom_fwd, custom_bwd
 import random
 import torch
 import torch.nn as nn
@@ -230,6 +230,53 @@ def random_cast_to_half(x: Tensor,
     return torch.where(is_too_small, random_val, x).to(torch.float16)
 
 
+class CachingEvalFunction(torch.autograd.Function):
+    # @custom_fwd and @custom_bwd relate to automatic mixed precision (amp) and ensure
+    # that the backward pass runs under the same autocast context as the forward pass.
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x: Tensor, m) -> Tensor:
+        """
+        m might be an nn.Module
+        """
+        ctx.x_requires_grad = x.requires_grad
+        ctx.m = m
+        # any random numbers used in this evaluation and in the re-evaluation in backward() must be identical.
+        # Caution: this assumes you are not going to use any random numbers from torch (for any purpose
+        # that matters in the forward pass), e.g. there should be no dropout.
+        ctx.random_state = random.getstate()
+        # we are inside torch.no_grad() here, so the following won't create the computation graph.
+        y = m(x)
+        ctx.save_for_backward(x, y)
+        return y
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, y_grad: Tensor) -> Tuple[Tensor, None]:
+        x, y = ctx.saved_tensors
+        x = x.detach()  # ensure backprop of the re-computed graph stops at x.
+        x.requires_grad = ctx.x_requires_grad
+        m = ctx.m  # e.g. an nn.Module
+
+        random_state = random.getstate()
+        # set the state to what we used in the 1st forward pass.
+        random.setstate(ctx.random_state)
+        with torch.enable_grad():
+            y2 = m(x)
+            assert torch.allclose(y, y2, atol=1.0e-02)
+            # backprop through the re-computed y2 (y has no graph); this creates grads in the module's parameters.
+            y2.backward(gradient=y_grad)
+
+        # restore the state from before we entered this function
+        random.setstate(random_state)
+
+        return x.grad, None  # x.grad will be None if x.requires_grad is False.
+
+
+def caching_eval(x: Tensor, m: nn.Module) -> Tensor:
+    return CachingEvalFunction.apply(x, m)
+
+
 class RandomGradFunction(torch.autograd.Function):
     """
     Does nothing in forward pass; in backward pass, gets rid of very small grads using
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 9f445640c..5f49f220e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -41,6 +41,7 @@ from scaling import (
     Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
     penalize_abs_values_gt,
     softmax,
+    caching_eval,
     ScheduledFloat,
     FloatLike,
     limit_param_value,
@@ -317,7 +318,12 @@ class Zipformer(EncoderInterface):
           - lengths, a tensor of shape (batch_size,) containing the number of
             frames in `embeddings` before padding.
""" - x = self.encoder_embed(x) + if not torch.jit.is_scripting(): + # This saves memory during training, at the expense of re-doing the encoder_embed + # computation in the backward pass. + x = caching_eval(x, self.encoder_embed) + else: + x = self.encoder_embed(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)