Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Implement caching evaluation for ConvNeXt
commit b3527fe4ac
parent 9242800d42
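Note (not part of the commit message): the diff below adds a caching_eval() helper that re-runs the wrapped module or function during the backward pass instead of keeping its intermediate activations in the autograd graph, and uses it for ConvNeXt layers. A minimal usage sketch, mirroring the docstring example added in the first hunk (the names m, x, y are illustrative):

    import torch
    import torch.nn as nn
    # caching_eval is the helper added by this commit (first hunk below).

    m = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 10))
    x = torch.randn(50, 10, requires_grad=True)

    # The forward call runs without keeping a graph through m; the inputs are
    # saved, and m is re-evaluated under torch.enable_grad() inside backward(),
    # so gradients for x and for m's parameters still come out correct.
    y = caching_eval(m, x)
    y.sum().backward()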
@@ -264,60 +264,129 @@ class CutoffEstimator:
         return ans


 class CachingEvalFunction(torch.autograd.Function):
     # @custom_fwd and @custom_bwd related to automatic mixed precision (amp) an ensure
     # that the backward path runs with the same autocast context as the forward pass.
     @staticmethod
     @custom_fwd
-    def forward(ctx, x: Tensor, m) -> Tensor:
+    def forward(ctx, *args):
         """
         m might be an nn.Module
         """
-        ctx.x_requires_grad = x.requires_grad
-        ctx.m = m
-        # we need any random numbers used in this evaluation and the next evaluation to be identical.
-        # Caution: this assumes you are not going to use any random numbers from torch (for any purpose
-        # that matters in the forward pass), e.g. there should be no dropout.
-        ctx.random_state = random.getstate()
-        # we are inside torch.no_grad() here, so the following won't create the computation graph.
-        with torch.no_grad():
-            y = m(x)
-        ctx.save_for_backward(x, y)
-        return y
+        tot_num_args = len(args)
+        orig_num_args = args[0]
+        orig_args = args[1:1+orig_num_args]
+
+        ctx.num_dummy_args = tot_num_args - orig_num_args - 1
+
+        tensor_args = []
+        non_tensor_args = []
+        is_tensor_arg = []
+        tensor_requires_grad = []
+        for i in range(orig_num_args):
+            arg = args[1 + i]
+            is_tensor = isinstance(arg, torch.Tensor)
+            if is_tensor:
+                t = arg.detach()
+                tensor_requires_grad.append(arg.requires_grad)
+                tensor_args.append(t)
+            else:
+                non_tensor_args.append(arg)
+            is_tensor_arg.append(is_tensor)
+        ctx.is_tensor_arg = is_tensor_arg # list of bool
+        ctx.non_tensor_args = non_tensor_args
+        ctx.tensor_requires_grad = tensor_requires_grad
+        ctx.save_for_backward(*tensor_args)
+
+        # m is module, function or lambda.
+        m = orig_args[0]
+        # call m with the remaining elements of orig_args
+        ans = m(*orig_args[1:])
+
+        return ans

     @staticmethod
     @custom_bwd
-    def backward(ctx, y_grad: Tensor):
-        x, y = ctx.saved_tensors
-        x = x.detach()
-        x.requires_grad = ctx.x_requires_grad
-        m = ctx.m # e.g. a nn.Module
-
-        random_state = random.getstate()
-        # set the state to what we used in the 1st forward pass.
-        random.setstate(ctx.random_state)
+    def backward(ctx, *grads):
         with torch.enable_grad():
-            y2 = m(x)
-            assert torch.allclose(y, y2, atol=1.0e-02)
-            # this call to backward() should create grads in the module's parameters
-            y2.backward(gradient=y_grad)
+            tensor_args = ctx.saved_tensors
+            tensor_requires_grad = ctx.tensor_requires_grad
+            non_tensor_args = ctx.non_tensor_args
+            is_tensor_arg = ctx.is_tensor_arg
+            args = []
+            tensor_idx = 0
+            non_tensor_idx = 0
+            for b in ctx.is_tensor_arg:
+                if b:
+                    t = tensor_args[tensor_idx]
+                    t.requires_grad = tensor_requires_grad[tensor_idx]
+                    args.append(t)
+                    tensor_idx += 1
+                else:
+                    args.append(non_tensor_args[non_tensor_idx])
+                    non_tensor_idx += 1
+            m = args[0]
+            # ans should the same as the original ans.
+            ans = m(*args[1:])
+            if isinstance(ans, Tensor):
+                ans = [ans]
+            # keep only the tensors from ans.
+            filtered_grads = []
+            filtered_ans = []
+            assert len(ans) == len(grads)
+            for i, a in enumerate(ans):
+                if isinstance(a, Tensor):
+                    filtered_ans.append(a)
+                    filtered_grads.append(grads[i])
+                else:
+                    assert grads[i] is None

-        # restore the state from before we entered this function
-        random.setstate(random_state)
+            torch.autograd.backward(filtered_ans, filtered_grads)

-        return x.grad, None # x.grad will be None if x.requires_grad is False.
+        returned_grads = [ a.grad if isinstance(a, Tensor) else None for a in args ]
+
+        return tuple([None] + returned_grads + [None] * ctx.num_dummy_args)


-def caching_eval(x: Tensor, m: nn.Module) -> Tensor:
-    if m.training:
-        # The purpose of this code is to make all parameters of m reachable in
-        # the computation graph, so that if we give find_unused_parameters=True
-        # to PyTorch's autograd code it does not assign them zero gradient.
-        tot = 0.0
-        for p in m.parameters():
-            tot = tot + 0.0 * p.flatten()[0]
-        x = x + tot # tot will be 0, this does nothing.
-    return CachingEvalFunction.apply(x, m)
+def caching_eval(*args):
+    """
+    A memory-efficient way to evaluate a nn.Module (or function or lambda), that
+    recomputes the forward pass during the backward pass so we don't have to
+    store intermediate quantities in the graph.
+
+    Example:
+        m = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 10))
+        y = caching_eval(m, x)
+    This function will treat the first arg as a function and give it the
+    remaining args. If m is a lambda, you should not capture
+    any nn.Module or Tensor arguments in the lambda; instead you should make
+    them arguments to the lambda.
+
+    m must return a single item or a tuple of items; the items may be Tensors
+    or other types (Tensor will be treated specially).
+
+    This function returns a single element (probably a Tensor) if m returned
+    a single element; otherwise it returns a tuple.
+    """
+    dummy_args = []
+    for arg in args:
+        # these dummy_args, the list of parameters, are not going to be
+        # used directly and no grad will be returned for them; the purpose of
+        # adding them is to make PyTorch think that a grad might be returned,
+        # so it doesn't assign a zero grad if the training loop does backward()
+        # with find_unused_args=True.
+        if isinstance(arg, nn.Module):
+            dummy_args = dummy_args + list(arg.parameters())
+    orig_num_args = len(args)

+    # args we give to the function: n
+    function_args = [ orig_num_args ] + list(args) + dummy_args
+    # This function returns a single element (probably a Tensor) or a tuple;
+    # it returns whatever
+    return CachingEvalFunction.apply(*function_args)


 class RandomGradFunction(torch.autograd.Function):
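For readers of the hunk above, a sketch (not part of the commit) of the argument layout that caching_eval() builds before calling CachingEvalFunction.apply(); the module and tensor names here are made up:

    import torch
    import torch.nn as nn

    m = nn.Linear(4, 4, bias=False)
    x = torch.randn(2, 4, requires_grad=True)
    args = (m, x)                         # what the caller passes to caching_eval

    dummy_args = list(m.parameters())     # appended so that DDP's find_unused_parameters
                                          # still sees every parameter of m
    function_args = [len(args)] + list(args) + dummy_args
    # -> [2, m, x, m.weight]
    # forward() reads args[0] == 2 to recover (m, x); backward() returns one grad
    # slot per apply() argument: (None for the count, None for m, x.grad,
    # and None for each trailing dummy parameter).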
@@ -2082,11 +2151,38 @@ def _test_softmax():
     assert torch.allclose(a.grad, b.grad)


+def _test_caching_eval():
+    m = nn.Sequential(nn.Linear(10, 100, bias=False), nn.ReLU(), nn.Linear(100, 10, bias=False))
+
+    x = torch.randn(50, 10)
+    y_grad = torch.randn(50, 10)
+    x.requires_grad = True
+
+    y1 = m(x)
+    y1.backward(gradient=y_grad)
+    x_grad1 = x.grad
+    x.grad = None
+    weight_grad1a = m[0].weight.grad
+    weight_grad1b = m[2].weight.grad
+    m[0].weight.grad = None
+    m[2].weight.grad = None
+    m.zero_grad()
+
+    y2 = caching_eval(m, x)
+    assert torch.allclose(y1, y2)
+    y2.backward(gradient=y_grad)
+    assert torch.allclose(x.grad, x_grad1)
+    assert torch.allclose(m[0].weight.grad, weight_grad1a)
+    assert torch.allclose(m[2].weight.grad, weight_grad1b)
+
+
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_caching_eval()
     _test_softmax()
     _test_whiten()
     _test_max_eig()
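The new docstring also allows passing a lambda plus non-Tensor arguments; a hedged sketch of that usage (not exercised by _test_caching_eval above):

    import torch
    import torch.nn as nn

    lin = nn.Linear(8, 8)
    x = torch.randn(4, 8, requires_grad=True)
    scale = 0.5   # non-Tensor argument: passed through unchanged, gets a None grad slot

    # Per the docstring, do not capture lin or x inside the lambda; pass them as
    # arguments so they are saved/detached and reused in the recomputed backward pass.
    y = caching_eval(lambda m, t, s: m(t) * s, lin, x, scale)
    y.sum().backward()    # fills x.grad, lin.weight.grad and lin.bias.grad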
@@ -1853,8 +1853,22 @@ class ConvNeXt(nn.Module):
                                  prob=(0.025, 0.25),
                                  grad_scale=0.01)


     def forward(self, x: Tensor) -> Tensor:
+        if torch.jit.is_scripting() or not self.training:
+            return self.forward_internal(x)
+        layerdrop_rate = float(self.layerdrop_rate)
+
+        if layerdrop_rate != 0.0:
+            batch_size = x.shape[0]
+            mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
+        else:
+            mask = None
+        return caching_eval(self.forward_internal, x, mask)
+
+
+    def forward_internal(self,
+                         x: Tensor,
+                         layer_skip_mask: Optional[Tensor] = None) -> Tensor:
         """
         x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)

@@ -1867,11 +1881,8 @@ class ConvNeXt(nn.Module):
         x = self.activation(x)
         x = self.pointwise_conv2(x)

-        layerdrop_rate = float(self.layerdrop_rate)
-        if not torch.jit.is_scripting() and self.training and layerdrop_rate != 0.0:
-            batch_size = x.shape[0]
-            mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
-            x = x * mask
+        if layer_skip_mask is not None:
+            x = x * layer_skip_mask

         x = bypass + x
         x = self.out_balancer(x)
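A note on the ConvNeXt hunks: the layerdrop mask is now sampled once in forward() and handed to forward_internal() as layer_skip_mask, so the re-evaluation inside CachingEvalFunction.backward() sees exactly the same mask (the new CachingEvalFunction no longer saves and restores Python's random state the way the old one did). A minimal sketch of that pattern, with hypothetical names block and drop_rate, assuming block.forward_internal accepts (x, mask) like ConvNeXt's does:

    import torch
    import torch.nn as nn

    def checkpointed_forward(block: nn.Module, x: torch.Tensor,
                             drop_rate: float = 0.1) -> torch.Tensor:
        # Draw the random mask *outside* the function that will be re-run in
        # backward(), and pass it in as an argument, so the recomputation is
        # deterministic and matches the original forward pass.
        mask = None
        if block.training and drop_rate != 0.0:
            mask = torch.rand((x.shape[0], 1, 1, 1),
                              dtype=x.dtype, device=x.device) > drop_rate
        return caching_eval(block.forward_internal, x, mask)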