Try to implement caching evaluation for memory efficient training

This commit is contained in:
Daniel Povey 2022-12-15 23:06:40 +08:00
parent f66c1600f4
commit d26ee2bf81
2 changed files with 55 additions and 2 deletions

View File

@ -20,7 +20,7 @@ from itertools import repeat
from typing import Optional, Tuple, Union
from functools import reduce
import logging
from torch.cuda.amp import custom_fwd, custom_bwd
import random
import torch
import torch.nn as nn
@ -230,6 +230,53 @@ def random_cast_to_half(x: Tensor,
return torch.where(is_too_small, random_val, x).to(torch.float16)
class CachingEvalFunction(torch.autograd.Function):
# @custom_fwd and @custom_bwd related to automatic mixed precision (amp) an ensure
# that the backward path runs with the same autocast context as the forward pass.
@staticmethod
@custom_fwd
def forward(ctx, x: Tensor, m) -> Tensor:
"""
m might be an nn.Module
"""
ctx.x_requires_grad = x.requires_grad
ctx.m = m
# we need any random numbers used in this evaluation and the next evaluation to be identical.
# Caution: this assumes you are not going to use any random numbers from torch (for any purpose
# that matters in the forward pass), e.g. there should be no dropout.
ctx.random_state = random.getstate()
ctx.save_for_backward(x)
# we are inside torch.no_grad() here, so the following won't create the computation graph.
y = m(x)
ctx.save_for_backward(x, y)
return y
@staticmethod
@custom_bwd
def backward(ctx, y_grad: Tensor) -> Tuple[Tensor, None]:
x, y = ctx.saved_tensors
x.requires_grad = ctx.x_requires_grad
m = ctx.m # e.g. a nn.Module
random_state = random.getstate()
# set the state to what we used in the 1st forward pass.
random.setstate(ctx.random_state)
with torch.enable_grad():
y2 = m(x)
assert torch.allclose(y, y2, atol=1.0e-02)
# this call to backward() should create grads in the module's parameters
y.backward(gradient=y_grad)
# restore the state from before we entered this function
random.setstate(random_state)
return x.grad, None # x.grad will be None if x.requires_grad is False.
def caching_eval(x: Tensor, m: nn.Module) -> Tensor:
return CachingEvalFunction.apply(x, m)
class RandomGradFunction(torch.autograd.Function):
"""
Does nothing in forward pass; in backward pass, gets rid of very small grads using

View File

@ -41,6 +41,7 @@ from scaling import (
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
penalize_abs_values_gt,
softmax,
caching_eval,
ScheduledFloat,
FloatLike,
limit_param_value,
@ -317,7 +318,12 @@ class Zipformer(EncoderInterface):
- lengths, a tensor of shape (batch_size,) containing the number
of frames in `embeddings` before padding.
"""
x = self.encoder_embed(x)
if not torch.jit.is_scripting():
# This saves memory during training, at the expense of re-doing the encoder_embed
# computation in the backward pass.
x = caching_eval(x, self.encoder_embed)
else:
x = self.encoder_embed(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)