mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Implement caching evaluation for ConvNeXt
This commit is contained in:
parent
9242800d42
commit
b3527fe4ac
@ -264,60 +264,129 @@ class CutoffEstimator:
|
||||
return ans
|
||||
|
||||
|
||||
|
||||
class CachingEvalFunction(torch.autograd.Function):
|
||||
# @custom_fwd and @custom_bwd related to automatic mixed precision (amp) an ensure
|
||||
# that the backward path runs with the same autocast context as the forward pass.
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, x: Tensor, m) -> Tensor:
|
||||
def forward(ctx, *args):
|
||||
"""
|
||||
m might be an nn.Module
|
||||
"""
|
||||
ctx.x_requires_grad = x.requires_grad
|
||||
ctx.m = m
|
||||
# we need any random numbers used in this evaluation and the next evaluation to be identical.
|
||||
# Caution: this assumes you are not going to use any random numbers from torch (for any purpose
|
||||
# that matters in the forward pass), e.g. there should be no dropout.
|
||||
ctx.random_state = random.getstate()
|
||||
# we are inside torch.no_grad() here, so the following won't create the computation graph.
|
||||
with torch.no_grad():
|
||||
y = m(x)
|
||||
ctx.save_for_backward(x, y)
|
||||
return y
|
||||
tot_num_args = len(args)
|
||||
orig_num_args = args[0]
|
||||
orig_args = args[1:1+orig_num_args]
|
||||
|
||||
ctx.num_dummy_args = tot_num_args - orig_num_args - 1
|
||||
|
||||
tensor_args = []
|
||||
non_tensor_args = []
|
||||
is_tensor_arg = []
|
||||
tensor_requires_grad = []
|
||||
for i in range(orig_num_args):
|
||||
arg = args[1 + i]
|
||||
is_tensor = isinstance(arg, torch.Tensor)
|
||||
if is_tensor:
|
||||
t = arg.detach()
|
||||
tensor_requires_grad.append(arg.requires_grad)
|
||||
tensor_args.append(t)
|
||||
|
||||
else:
|
||||
non_tensor_args.append(arg)
|
||||
is_tensor_arg.append(is_tensor)
|
||||
ctx.is_tensor_arg = is_tensor_arg # list of bool
|
||||
ctx.non_tensor_args = non_tensor_args
|
||||
ctx.tensor_requires_grad = tensor_requires_grad
|
||||
ctx.save_for_backward(*tensor_args)
|
||||
|
||||
# m is module, function or lambda.
|
||||
m = orig_args[0]
|
||||
# call m with the remaining elements of orig_args
|
||||
ans = m(*orig_args[1:])
|
||||
|
||||
return ans
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, y_grad: Tensor):
|
||||
x, y = ctx.saved_tensors
|
||||
x = x.detach()
|
||||
x.requires_grad = ctx.x_requires_grad
|
||||
m = ctx.m # e.g. a nn.Module
|
||||
|
||||
random_state = random.getstate()
|
||||
# set the state to what we used in the 1st forward pass.
|
||||
random.setstate(ctx.random_state)
|
||||
def backward(ctx, *grads):
|
||||
with torch.enable_grad():
|
||||
y2 = m(x)
|
||||
assert torch.allclose(y, y2, atol=1.0e-02)
|
||||
# this call to backward() should create grads in the module's parameters
|
||||
y2.backward(gradient=y_grad)
|
||||
tensor_args = ctx.saved_tensors
|
||||
tensor_requires_grad = ctx.tensor_requires_grad
|
||||
non_tensor_args = ctx.non_tensor_args
|
||||
is_tensor_arg = ctx.is_tensor_arg
|
||||
args = []
|
||||
tensor_idx = 0
|
||||
non_tensor_idx = 0
|
||||
for b in ctx.is_tensor_arg:
|
||||
if b:
|
||||
t = tensor_args[tensor_idx]
|
||||
t.requires_grad = tensor_requires_grad[tensor_idx]
|
||||
args.append(t)
|
||||
tensor_idx += 1
|
||||
else:
|
||||
args.append(non_tensor_args[non_tensor_idx])
|
||||
non_tensor_idx += 1
|
||||
m = args[0]
|
||||
# ans should the same as the original ans.
|
||||
ans = m(*args[1:])
|
||||
if isinstance(ans, Tensor):
|
||||
ans = [ans]
|
||||
# keep only the tensors from ans.
|
||||
filtered_grads = []
|
||||
filtered_ans = []
|
||||
assert len(ans) == len(grads)
|
||||
for i, a in enumerate(ans):
|
||||
if isinstance(a, Tensor):
|
||||
filtered_ans.append(a)
|
||||
filtered_grads.append(grads[i])
|
||||
else:
|
||||
assert grads[i] is None
|
||||
|
||||
# restore the state from before we entered this function
|
||||
random.setstate(random_state)
|
||||
torch.autograd.backward(filtered_ans, filtered_grads)
|
||||
|
||||
return x.grad, None # x.grad will be None if x.requires_grad is False.
|
||||
returned_grads = [ a.grad if isinstance(a, Tensor) else None for a in args ]
|
||||
|
||||
return tuple([None] + returned_grads + [None] * ctx.num_dummy_args)
|
||||
|
||||
|
||||
def caching_eval(x: Tensor, m: nn.Module) -> Tensor:
|
||||
if m.training:
|
||||
# The purpose of this code is to make all parameters of m reachable in
|
||||
# the computation graph, so that if we give find_unused_parameters=True
|
||||
# to PyTorch's autograd code it does not assign them zero gradient.
|
||||
tot = 0.0
|
||||
for p in m.parameters():
|
||||
tot = tot + 0.0 * p.flatten()[0]
|
||||
x = x + tot # tot will be 0, this does nothing.
|
||||
return CachingEvalFunction.apply(x, m)
|
||||
def caching_eval(*args):
|
||||
"""
|
||||
A memory-efficient way to evaluate a nn.Module (or function or lambda), that
|
||||
recomputes the forward pass during the backward pass so we don't have to
|
||||
store intermediate quantities in the graph.
|
||||
|
||||
Example:
|
||||
m = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 10))
|
||||
y = caching_eval(m, x)
|
||||
This function will treat the first arg as a function and give it the
|
||||
remaining args. If m is a lambda, you should not capture
|
||||
any nn.Module or Tensor arguments in the lambda; instead you should make
|
||||
them arguments to the lambda.
|
||||
|
||||
m must return a single item or a tuple of items; the items may be Tensors
|
||||
or other types (Tensor will be treated specially).
|
||||
|
||||
This function returns a single element (probably a Tensor) if m returned
|
||||
a single element; otherwise it returns a tuple.
|
||||
"""
|
||||
dummy_args = []
|
||||
for arg in args:
|
||||
# these dummy_args, the list of parameters, are not going to be
|
||||
# used directly and no grad will be returned for them; the purpose of
|
||||
# adding them is to make PyTorch think that a grad might be returned,
|
||||
# so it doesn't assign a zero grad if the training loop does backward()
|
||||
# with find_unused_args=True.
|
||||
if isinstance(arg, nn.Module):
|
||||
dummy_args = dummy_args + list(arg.parameters())
|
||||
orig_num_args = len(args)
|
||||
|
||||
# args we give to the function: n
|
||||
function_args = [ orig_num_args ] + list(args) + dummy_args
|
||||
# This function returns a single element (probably a Tensor) or a tuple;
|
||||
# it returns whatever
|
||||
return CachingEvalFunction.apply(*function_args)
|
||||
|
||||
|
||||
|
||||
class RandomGradFunction(torch.autograd.Function):
|
||||
@ -2082,11 +2151,38 @@ def _test_softmax():
|
||||
assert torch.allclose(a.grad, b.grad)
|
||||
|
||||
|
||||
def _test_caching_eval():
|
||||
m = nn.Sequential(nn.Linear(10, 100, bias=False), nn.ReLU(), nn.Linear(100, 10, bias=False))
|
||||
|
||||
x = torch.randn(50, 10)
|
||||
y_grad = torch.randn(50, 10)
|
||||
x.requires_grad = True
|
||||
|
||||
y1 = m(x)
|
||||
y1.backward(gradient=y_grad)
|
||||
x_grad1 = x.grad
|
||||
x.grad = None
|
||||
weight_grad1a = m[0].weight.grad
|
||||
weight_grad1b = m[2].weight.grad
|
||||
m[0].weight.grad = None
|
||||
m[2].weight.grad = None
|
||||
m.zero_grad()
|
||||
|
||||
y2 = caching_eval(m, x)
|
||||
assert torch.allclose(y1, y2)
|
||||
y2.backward(gradient=y_grad)
|
||||
assert torch.allclose(x.grad, x_grad1)
|
||||
assert torch.allclose(m[0].weight.grad, weight_grad1a)
|
||||
assert torch.allclose(m[2].weight.grad, weight_grad1b)
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
_test_caching_eval()
|
||||
_test_softmax()
|
||||
_test_whiten()
|
||||
_test_max_eig()
|
||||
|
||||
@ -1853,8 +1853,22 @@ class ConvNeXt(nn.Module):
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.01)
|
||||
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting() or not self.training:
|
||||
return self.forward_internal(x)
|
||||
layerdrop_rate = float(self.layerdrop_rate)
|
||||
|
||||
if layerdrop_rate != 0.0:
|
||||
batch_size = x.shape[0]
|
||||
mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
|
||||
else:
|
||||
mask = None
|
||||
return caching_eval(self.forward_internal, x, mask)
|
||||
|
||||
|
||||
def forward_internal(self,
|
||||
x: Tensor,
|
||||
layer_skip_mask: Optional[Tensor] = None) -> Tensor:
|
||||
"""
|
||||
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
|
||||
|
||||
@ -1867,11 +1881,8 @@ class ConvNeXt(nn.Module):
|
||||
x = self.activation(x)
|
||||
x = self.pointwise_conv2(x)
|
||||
|
||||
layerdrop_rate = float(self.layerdrop_rate)
|
||||
if not torch.jit.is_scripting() and self.training and layerdrop_rate != 0.0:
|
||||
batch_size = x.shape[0]
|
||||
mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
|
||||
x = x * mask
|
||||
if layer_skip_mask is not None:
|
||||
x = x * layer_skip_mask
|
||||
|
||||
x = bypass + x
|
||||
x = self.out_balancer(x)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user