Implement caching evaluation for ConvNeXt

Daniel Povey 2023-01-07 17:31:20 +08:00
parent 9242800d42
commit b3527fe4ac
2 changed files with 151 additions and 44 deletions

@@ -264,60 +264,129 @@ class CutoffEstimator:
return ans
class CachingEvalFunction(torch.autograd.Function):
# @custom_fwd and @custom_bwd relate to automatic mixed precision (amp) and ensure
# that the backward pass runs with the same autocast context as the forward pass.
@staticmethod
@custom_fwd
def forward(ctx, x: Tensor, m) -> Tensor:
def forward(ctx, *args):
"""
m might be an nn.Module
"""
ctx.x_requires_grad = x.requires_grad
ctx.m = m
# we need any random numbers used in this evaluation and in the re-evaluation done
# in backward() to be identical.
# Caution: this assumes you are not going to use any random numbers from torch (for any purpose
# that matters in the forward pass), e.g. there should be no dropout.
ctx.random_state = random.getstate()
# we are inside torch.no_grad() here, so the following won't create the computation graph.
with torch.no_grad():
y = m(x)
ctx.save_for_backward(x, y)
return y
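# The packed args layout (see caching_eval() below) is:
# [orig_num_args, *original_args, *dummy_args], where original_args[0] is the
# module / function / lambda to call and dummy_args are module parameters that
# are included only so they stay reachable in the autograd graph.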
tot_num_args = len(args)
orig_num_args = args[0]
orig_args = args[1:1+orig_num_args]
ctx.num_dummy_args = tot_num_args - orig_num_args - 1
tensor_args = []
non_tensor_args = []
is_tensor_arg = []
tensor_requires_grad = []
for i in range(orig_num_args):
arg = args[1 + i]
is_tensor = isinstance(arg, torch.Tensor)
if is_tensor:
t = arg.detach()
tensor_requires_grad.append(arg.requires_grad)
tensor_args.append(t)
else:
non_tensor_args.append(arg)
is_tensor_arg.append(is_tensor)
ctx.is_tensor_arg = is_tensor_arg # list of bool
ctx.non_tensor_args = non_tensor_args
ctx.tensor_requires_grad = tensor_requires_grad
ctx.save_for_backward(*tensor_args)
# m is a module, a function or a lambda.
m = orig_args[0]
# call m with the remaining elements of orig_args
ans = m(*orig_args[1:])
return ans
@staticmethod
@custom_bwd
def backward(ctx, y_grad: Tensor):
x, y = ctx.saved_tensors
x = x.detach()
x.requires_grad = ctx.x_requires_grad
m = ctx.m # e.g. a nn.Module
random_state = random.getstate()
# set the state to what we used in the 1st forward pass.
random.setstate(ctx.random_state)
def backward(ctx, *grads):
with torch.enable_grad():
y2 = m(x)
assert torch.allclose(y, y2, atol=1.0e-02)
# this call to backward() should create grads in the module's parameters
y2.backward(gradient=y_grad)
tensor_args = ctx.saved_tensors
tensor_requires_grad = ctx.tensor_requires_grad
non_tensor_args = ctx.non_tensor_args
is_tensor_arg = ctx.is_tensor_arg
args = []
tensor_idx = 0
non_tensor_idx = 0
for b in ctx.is_tensor_arg:
if b:
t = tensor_args[tensor_idx]
t.requires_grad = tensor_requires_grad[tensor_idx]
args.append(t)
tensor_idx += 1
else:
args.append(non_tensor_args[non_tensor_idx])
non_tensor_idx += 1
m = args[0]
# ans should be the same as the original ans from the first forward pass.
ans = m(*args[1:])
if isinstance(ans, Tensor):
ans = [ans]
# keep only the tensors from ans.
filtered_grads = []
filtered_ans = []
assert len(ans) == len(grads)
for i, a in enumerate(ans):
if isinstance(a, Tensor):
filtered_ans.append(a)
filtered_grads.append(grads[i])
else:
assert grads[i] is None
# restore the state from before we entered this function
random.setstate(random_state)
torch.autograd.backward(filtered_ans, filtered_grads)
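# This backward() call populates .grad on the detached tensor args (returned
# below) and, as a side effect, accumulates gradients in the parameters of any
# modules that m uses.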
return x.grad, None # x.grad will be None if x.requires_grad is False.
returned_grads = [ a.grad if isinstance(a, Tensor) else None for a in args ]
return tuple([None] + returned_grads + [None] * ctx.num_dummy_args)
def caching_eval(x: Tensor, m: nn.Module) -> Tensor:
if m.training:
# The purpose of this code is to make all parameters of m reachable in
# the computation graph, so that if we give find_unused_parameters=True
# to DistributedDataParallel it does not assign them a zero gradient.
tot = 0.0
for p in m.parameters():
tot = tot + 0.0 * p.flatten()[0]
x = x + tot # tot will be 0, this does nothing.
return CachingEvalFunction.apply(x, m)
def caching_eval(*args):
"""
A memory-efficient way to evaluate a nn.Module (or function or lambda), that
recomputes the forward pass during the backward pass so we don't have to
store intermediate quantities in the graph.
Example:
m = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 10))
y = caching_eval(m, x)
This function will treat the first arg as a function and give it the
remaining args. If m is a lambda, you should not capture
any nn.Module or Tensor inside the lambda; instead you should make
them arguments to the lambda.
m must return a single item or a tuple of items; the items may be Tensors
or other types (Tensor will be treated specially).
This function returns a single element (probably a Tensor) if m returned
a single element; otherwise it returns a tuple.
"""
dummy_args = []
for arg in args:
# these dummy_args, the list of parameters, are not going to be
# used directly and no grad will be returned for them; the purpose of
# adding them is to make PyTorch think that a grad might be returned,
# so it doesn't assign a zero grad when the training loop does backward()
# under DistributedDataParallel with find_unused_parameters=True.
if isinstance(arg, nn.Module):
dummy_args = dummy_args + list(arg.parameters())
orig_num_args = len(args)
# args we give to the function: the original arg count, then the original args, then the dummy args.
function_args = [ orig_num_args ] + list(args) + dummy_args
# This function returns a single element (probably a Tensor) or a tuple;
# it returns whatever m returned.
return CachingEvalFunction.apply(*function_args)
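# A minimal usage sketch of the multi-argument interface described above
# (illustrative only; the names below are invented for the example).  The module
# and tensor are passed as arguments to the lambda rather than captured by it,
# and no grad is returned for the non-tensor `scale` argument:
#
#   m = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 10))
#   x = torch.randn(50, 10, requires_grad=True)
#   scale = 0.5
#   y = caching_eval(lambda mod, inp, s: mod(inp) * s, m, x, scale)
#   y.sum().backward()   # re-runs the lambda's forward pass internally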
class RandomGradFunction(torch.autograd.Function):
@@ -2082,11 +2151,38 @@ def _test_softmax():
assert torch.allclose(a.grad, b.grad)
def _test_caching_eval():
m = nn.Sequential(nn.Linear(10, 100, bias=False), nn.ReLU(), nn.Linear(100, 10, bias=False))
x = torch.randn(50, 10)
y_grad = torch.randn(50, 10)
x.requires_grad = True
y1 = m(x)
y1.backward(gradient=y_grad)
x_grad1 = x.grad
x.grad = None
weight_grad1a = m[0].weight.grad
weight_grad1b = m[2].weight.grad
m[0].weight.grad = None
m[2].weight.grad = None
m.zero_grad()
y2 = caching_eval(m, x)
assert torch.allclose(y1, y2)
y2.backward(gradient=y_grad)
assert torch.allclose(x.grad, x_grad1)
assert torch.allclose(m[0].weight.grad, weight_grad1a)
assert torch.allclose(m[2].weight.grad, weight_grad1b)
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
_test_caching_eval()
_test_softmax()
_test_whiten()
_test_max_eig()

@@ -1853,8 +1853,22 @@ class ConvNeXt(nn.Module):
prob=(0.025, 0.25),
grad_scale=0.01)
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or not self.training:
return self.forward_internal(x)
layerdrop_rate = float(self.layerdrop_rate)
if layerdrop_rate != 0.0:
batch_size = x.shape[0]
mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
else:
mask = None
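# Note: forward_internal() will be re-evaluated inside caching_eval() during the
# backward pass, so the random layer-drop mask is drawn once here and passed in
# as an argument, keeping both evaluations identical.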
return caching_eval(self.forward_internal, x, mask)
def forward_internal(self,
x: Tensor,
layer_skip_mask: Optional[Tensor] = None) -> Tensor:
"""
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
@@ -1867,11 +1881,8 @@ class ConvNeXt(nn.Module):
x = self.activation(x)
x = self.pointwise_conv2(x)
layerdrop_rate = float(self.layerdrop_rate)
if not torch.jit.is_scripting() and self.training and layerdrop_rate != 0.0:
batch_size = x.shape[0]
mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
x = x * mask
if layer_skip_mask is not None:
x = x * layer_skip_mask
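# Examples whose mask entry is 0 contribute nothing here, so for them the layer
# reduces to the bypass connection below, i.e. the layer is skipped.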
x = bypass + x
x = self.out_balancer(x)