mirror of https://github.com/k2-fsa/icefall.git
Try to implement caching evaluation for memory efficient training
parent f66c1600f4
commit d26ee2bf81
@@ -20,7 +20,7 @@ from itertools import repeat
from typing import Optional, Tuple, Union
from functools import reduce
import logging

from torch.cuda.amp import custom_fwd, custom_bwd
import random
import torch
import torch.nn as nn
@@ -230,6 +230,53 @@ def random_cast_to_half(x: Tensor,
    return torch.where(is_too_small, random_val, x).to(torch.float16)


class CachingEvalFunction(torch.autograd.Function):
    # @custom_fwd and @custom_bwd relate to automatic mixed precision (amp) and ensure
    # that the backward pass runs with the same autocast context as the forward pass.
    @staticmethod
    @custom_fwd
    def forward(ctx, x: Tensor, m) -> Tensor:
        """
        m might be an nn.Module
        """
        ctx.x_requires_grad = x.requires_grad
        ctx.m = m
        # we need any random numbers used in this evaluation and in the re-evaluation
        # done in backward() to be identical.
        # Caution: this assumes you are not going to use any random numbers from torch (for any
        # purpose that matters in the forward pass), e.g. there should be no dropout.
        ctx.random_state = random.getstate()
        # we are inside torch.no_grad() here, so the following won't create the computation graph.
        y = m(x)
        ctx.save_for_backward(x, y)
        return y

    @staticmethod
    @custom_bwd
    def backward(ctx, y_grad: Tensor) -> Tuple[Tensor, None]:
        x, y = ctx.saved_tensors
        x.requires_grad = ctx.x_requires_grad
        m = ctx.m  # e.g. an nn.Module

        random_state = random.getstate()
        # set the state to what we used in the 1st forward pass.
        random.setstate(ctx.random_state)
        with torch.enable_grad():
            y2 = m(x)
            assert torch.allclose(y, y2, atol=1.0e-02)
            # this call to backward() should create grads in the module's parameters;
            # it must be made on y2, which carries a computation graph, not on the cached y.
            y2.backward(gradient=y_grad)

        # restore the state from before we entered this function
        random.setstate(random_state)

        return x.grad, None  # x.grad will be None if x.requires_grad is False.


def caching_eval(x: Tensor, m: nn.Module) -> Tensor:
    return CachingEvalFunction.apply(x, m)


class RandomGradFunction(torch.autograd.Function):
    """
    Does nothing in forward pass; in backward pass, gets rid of very small grads using
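
A quick sanity check of the new function (not part of the commit; the toy module, shapes and variable names below are made up) can compare the parameter gradients produced through caching_eval with those from an ordinary call. Per the caution in the comments above, the module must be deterministic at the torch level (no dropout), and at least one input has to require grad so that autograd actually invokes the cached backward:

import torch
import torch.nn as nn
from scaling import caching_eval  # assumes the modified scaling.py is importable

# Hypothetical deterministic module: no dropout / torch-level randomness,
# so re-running it in backward() reproduces the cached output exactly.
m = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
x = torch.randn(4, 8, requires_grad=True)

# Reference: ordinary call, activations inside `m` are kept for backward.
m(x).sum().backward()
ref_grads = [p.grad.clone() for p in m.parameters()]

# Reset and repeat through caching_eval: the forward runs without building a graph,
# and `m` is re-evaluated inside backward() to produce the parameter gradients.
for p in m.parameters():
    p.grad = None
x.grad = None
caching_eval(x, m).sum().backward()

for g_ref, p in zip(ref_grads, m.parameters()):
    assert torch.allclose(g_ref, p.grad, atol=1.0e-5)
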
@@ -41,6 +41,7 @@ from scaling import (
    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
    penalize_abs_values_gt,
    softmax,
    caching_eval,
    ScheduledFloat,
    FloatLike,
    limit_param_value,
@@ -317,7 +318,12 @@ class Zipformer(EncoderInterface):
          - lengths, a tensor of shape (batch_size,) containing the number
            of frames in `embeddings` before padding.
        """
        x = self.encoder_embed(x)
        if not torch.jit.is_scripting():
            # This saves memory during training, at the expense of re-doing the encoder_embed
            # computation in the backward pass.
            x = caching_eval(x, self.encoder_embed)
        else:
            x = self.encoder_embed(x)

        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
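
The trade-off described in the comment above, re-computing encoder_embed in the backward pass instead of storing its activations, is the same one made by activation checkpointing. For comparison only, a rough stand-in using PyTorch's built-in utility might look like the sketch below (an illustration with made-up module and shapes, not what this commit does). The main difference is the RNG handling: torch.utils.checkpoint can preserve torch's RNG state for the recomputation (preserve_rng_state defaults to True), whereas caching_eval saves and restores Python's random state and assumes there is no torch-level randomness such as dropout inside the module.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-ins for Zipformer's encoder_embed and its input features.
encoder_embed = nn.Sequential(nn.Linear(80, 512), nn.ReLU())
x = torch.randn(2, 100, 80, requires_grad=True)  # (N, T, C); shapes are illustrative only

# Same memory-for-compute trade as caching_eval(x, self.encoder_embed):
# encoder_embed's activations are not stored; it is re-run during backward.
y = checkpoint(encoder_embed, x)
y.sum().backward()
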