Reduce ConvNeXt parameters.

Daniel Povey 2022-12-16 00:21:04 +08:00
parent 8d9301e225
commit 083e5474c4
2 changed files with 12 additions and 3 deletions

View File

@@ -246,7 +246,8 @@ class CachingEvalFunction(torch.autograd.Function):
        # that matters in the forward pass), e.g. there should be no dropout.
        ctx.random_state = random.getstate()
        # we are inside torch.no_grad() here, so the following won't create the computation graph.
        y = m(x)
        with torch.no_grad():
            y = m(x)
        ctx.save_for_backward(x, y)
        return y
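
For background, here is a small standalone sketch of the two ingredients this hunk leans on (NoisyRelu is a made-up module, not from this repo): torch.no_grad() keeps the module call out of the autograd graph, and random.getstate()/random.setstate() let the same Python-level randomness be replayed later, which is what ctx.random_state is captured for.

import random
import torch
from torch import nn

# Toy module whose forward uses Python-level randomness, so replaying it
# later requires restoring the random state captured before the first call.
class NoisyRelu(nn.Module):
    def forward(self, x):
        return torch.relu(x) * (1.0 + 0.1 * random.random())

m = NoisyRelu()
x = torch.randn(3, requires_grad=True)

state = random.getstate()      # capture the Python RNG state
with torch.no_grad():          # no computation graph is built for this call
    y1 = m(x)
assert not y1.requires_grad    # confirms no autograd history was recorded

random.setstate(state)         # replay the same randomness
y2 = m(x)
assert torch.allclose(y1, y2)  # identical values on the re-run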
@@ -254,6 +255,7 @@ class CachingEvalFunction(torch.autograd.Function):
    @custom_bwd
    def backward(ctx, y_grad: Tensor):
        x, y = ctx.saved_tensors
        x = x.detach()
        x.requires_grad = ctx.x_requires_grad
        m = ctx.m  # e.g. a nn.Module
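
Putting this hunk together with the previous one, a minimal self-contained sketch of the recompute-in-backward pattern looks like the following (the name RecomputeEval and the toy usage are illustrative only; the real CachingEvalFunction also uses the AMP custom_fwd/custom_bwd decorators and, per the comment above, assumes m contains no dropout). The forward runs the module without building a graph; the backward restores the random state, re-runs the module under torch.enable_grad(), and backpropagates through that recomputed graph.

import random
import torch
from torch import Tensor, nn


class RecomputeEval(torch.autograd.Function):
    """Illustrative sketch, not the repo's CachingEvalFunction."""

    @staticmethod
    def forward(ctx, x: Tensor, m: nn.Module) -> Tensor:
        ctx.m = m
        ctx.x_requires_grad = x.requires_grad
        ctx.random_state = random.getstate()  # to replay Python-level randomness
        with torch.no_grad():                 # build no graph in the forward pass
            y = m(x)
        ctx.save_for_backward(x, y)
        return y

    @staticmethod
    def backward(ctx, y_grad: Tensor):
        x, _ = ctx.saved_tensors
        x = x.detach()
        x.requires_grad = ctx.x_requires_grad
        m = ctx.m
        random.setstate(ctx.random_state)     # same randomness as in forward
        with torch.enable_grad():
            y = m(x)                          # re-run, this time recording a graph
        # Backprop through the recomputed graph; this populates .grad on m's
        # parameters and on the detached leaf x.
        torch.autograd.backward(y, grad_tensors=y_grad)
        return (x.grad if ctx.x_requires_grad else None), None


# Toy usage: gradients end up on both x and m's parameters.
m = nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
RecomputeEval.apply(x, m).sum().backward()
assert x.grad is not None and m.weight.grad is not None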
@@ -273,6 +275,14 @@ class CachingEvalFunction(torch.autograd.Function):
def caching_eval(x: Tensor, m: nn.Module) -> Tensor:
    if m.training:
        # The purpose of this code is to make all parameters of m reachable in
        # the computation graph, so that if we give find_unused_parameters=True
        # to PyTorch's DistributedDataParallel it does not assign them zero gradient.
        tot = 0.0
        for p in m.parameters():
            tot = tot + 0.0 * p.flatten()[0]
        x = x + tot  # tot will be 0, so this does not change x numerically.
    return CachingEvalFunction.apply(x, m)
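
find_unused_parameters=True is an argument to torch.nn.parallel.DistributedDataParallel, which walks the autograd graph from the forward output to decide which parameters will receive gradients. Because the cached forward above runs under torch.no_grad(), none of m's parameters would be reachable from the output without this trick. A single-process sketch of the trick in isolation (no DDP is constructed here):

import torch
from torch import nn

m = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
x = torch.randn(2, 4)

# Tie every parameter of m into the graph with a weight of zero, exactly as in
# the hunk above; numerically this adds 0 to x.
tot = 0.0
for p in m.parameters():
    tot = tot + 0.0 * p.flatten()[0]
x2 = x + tot
assert torch.equal(x2, x)  # the values are unchanged

# Even though the rest of this computation never uses m, its parameters are
# now reachable from the loss, so autograd gives each of them a (zero)
# gradient instead of leaving it out of the graph entirely.
loss = x2.sum()
loss.backward()
assert all(p.grad is not None for p in m.parameters())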

View File

@@ -1767,7 +1767,7 @@ class Conv2dSubsampling(nn.Module):
        out_channels: int,
        layer1_channels: int = 8,
        layer2_channels: int = 32,
        layer3_channels: int = 128,
        layer3_channels: int = 96,
        dropout: FloatLike = 0.1,
    ) -> None:
        """
@@ -1829,7 +1829,6 @@ class Conv2dSubsampling(nn.Module):
        ))
        self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
                                       ConvNeXt(layer3_channels),
                                       ConvNeXt(layer3_channels),
                                       BasicNorm(layer3_channels,
                                                 channel_dim=1))
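
This commit shrinks the ConvNeXt part of Conv2dSubsampling in two ways: the default layer3_channels drops from 128 to 96, and the convnext2 stack loses one of its three ConvNeXt blocks. A rough illustration of the effect on parameter count, using a simplified ConvNeXt-style block (a depthwise 7x7 convolution followed by a pointwise expand/contract pair; the repo's ConvNeXt class differs in detail, and other layers built from layer3_channels shrink as well):

from torch import nn

def convnext_like(channels: int, expansion: int = 4) -> nn.Module:
    # Simplified stand-in for a ConvNeXt block, used only to illustrate how
    # the parameter count scales with channel width and block count.
    return nn.Sequential(
        nn.Conv2d(channels, channels, kernel_size=7, padding=3, groups=channels),
        nn.Conv2d(channels, expansion * channels, kernel_size=1),
        nn.Conv2d(expansion * channels, channels, kernel_size=1),
    )

def num_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())

before = nn.Sequential(*[convnext_like(128) for _ in range(3)])  # old: 3 blocks at 128 channels
after = nn.Sequential(*[convnext_like(96) for _ in range(2)])    # new: 2 blocks at 96 channels
print(num_params(before), num_params(after))

With these toy blocks the stack's parameter count drops to well under half of its previous size; the exact saving in the real model depends on the repo's actual ConvNeXt definition.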