Try to implement caching evaluation for memory efficient training

This commit is contained in:
Daniel Povey 2022-12-15 23:06:40 +08:00
parent f66c1600f4
commit d26ee2bf81
2 changed files with 55 additions and 2 deletions

scaling.py

@@ -20,7 +20,7 @@ from itertools import repeat
from typing import Optional, Tuple, Union
from functools import reduce
import logging
from torch.cuda.amp import custom_fwd, custom_bwd
import random
import torch
import torch.nn as nn
@@ -230,6 +230,53 @@ def random_cast_to_half(x: Tensor,
    return torch.where(is_too_small, random_val, x).to(torch.float16)
class CachingEvalFunction(torch.autograd.Function):
    # @custom_fwd and @custom_bwd relate to automatic mixed precision (amp) and ensure
    # that the backward pass runs with the same autocast context as the forward pass.
    @staticmethod
    @custom_fwd
    def forward(ctx, x: Tensor, m) -> Tensor:
        """
        m might be an nn.Module
        """
        ctx.x_requires_grad = x.requires_grad
        ctx.m = m
        # We need any random numbers used in this evaluation and the next evaluation
        # to be identical.
        # Caution: this assumes you are not going to use any random numbers from torch
        # (for any purpose that matters in the forward pass), e.g. there should be no dropout.
        ctx.random_state = random.getstate()
        # We are inside torch.no_grad() here, so the following won't create the
        # computation graph.
        y = m(x)
        ctx.save_for_backward(x, y)
        return y

    @staticmethod
    @custom_bwd
    def backward(ctx, y_grad: Tensor) -> Tuple[Tensor, None]:
        x, y = ctx.saved_tensors
        x.requires_grad = ctx.x_requires_grad
        m = ctx.m  # e.g. an nn.Module
        random_state = random.getstate()
        # Set the random state to what we used in the 1st forward pass, so the
        # re-evaluation below sees the same random numbers.
        random.setstate(ctx.random_state)
        with torch.enable_grad():
            y2 = m(x)
        assert torch.allclose(y, y2, atol=1.0e-02)
        # This call to backward() should create grads in the module's parameters.
        y2.backward(gradient=y_grad)
        # Restore the random state from before we entered this function.
        random.setstate(random_state)
        return x.grad, None  # x.grad will be None if x.requires_grad is False.


def caching_eval(x: Tensor, m: nn.Module) -> Tensor:
    return CachingEvalFunction.apply(x, m)
class RandomGradFunction(torch.autograd.Function):
    """
    Does nothing in forward pass; in backward pass, gets rid of very small grads using

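As a quick sanity check (not part of this commit), the following sketch applies caching_eval to a made-up stand-in module and compares its parameter gradients against an ordinary evaluation; the nn.Sequential stand-in and the tensor shapes are illustrative assumptions. Note that autograd only invokes the custom backward if at least one tensor input requires grad, so the input below is marked requires_grad=True.

    # Sketch only: `embed` is a hypothetical stand-in for a real submodule.
    import copy
    import torch
    import torch.nn as nn

    from scaling import caching_eval

    embed = nn.Sequential(nn.Linear(80, 192), nn.ReLU(), nn.Linear(192, 192))
    embed_ref = copy.deepcopy(embed)

    # The input must require grad so that autograd reaches CachingEvalFunction.backward.
    x = torch.randn(10, 4, 80, requires_grad=True)

    # Ordinary evaluation: intermediate activations inside embed_ref are kept for backward.
    embed_ref(x).sum().backward()

    # Caching evaluation: the forward pass runs under no_grad, so those activations are
    # freed; the module is re-run inside backward() to recompute them.
    caching_eval(x, embed).sum().backward()

    # Parameter gradients from the two paths should agree.
    for p, p_ref in zip(embed.parameters(), embed_ref.parameters()):
        assert torch.allclose(p.grad, p_ref.grad)
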
zipformer.py

@@ -41,6 +41,7 @@ from scaling import (
    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
    penalize_abs_values_gt,
    softmax,
    caching_eval,
    ScheduledFloat,
    FloatLike,
    limit_param_value,
@@ -317,7 +318,12 @@ class Zipformer(EncoderInterface):
          - lengths, a tensor of shape (batch_size,) containing the number
            of frames in `embeddings` before padding.
        """
        if not torch.jit.is_scripting():
            # This saves memory during training, at the expense of re-doing the encoder_embed
            # computation in the backward pass.
            x = caching_eval(x, self.encoder_embed)
        else:
            x = self.encoder_embed(x)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
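
For context (not part of this commit), the same memory-for-recompute trade-off is what torch.utils.checkpoint provides; a rough sketch of the analogous call at this point in Zipformer.forward() might look like the following. One difference is that caching_eval saves and restores Python's random state, whereas torch.utils.checkpoint preserves torch's RNG state by default.

    # Sketch only: activation checkpointing as an alternative to caching_eval,
    # shown for comparison; it likewise discards intermediate activations and
    # recomputes self.encoder_embed during the backward pass.
    from torch.utils.checkpoint import checkpoint

    if not torch.jit.is_scripting():
        x = checkpoint(self.encoder_embed, x)
    else:
        x = self.encoder_embed(x)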