mirror of https://github.com/k2-fsa/icefall.git
Reduce ConvNeXt parameters.
parent 8d9301e225
commit 083e5474c4
@@ -246,7 +246,8 @@ class CachingEvalFunction(torch.autograd.Function):
         # that matters in the forward pass), e.g. there should be no dropout.
         ctx.random_state = random.getstate()
         # we are inside torch.no_grad() here, so the following won't create the computation graph.
-        y = m(x)
+        with torch.no_grad():
+            y = m(x)
         ctx.save_for_backward(x, y)
         return y
 
@@ -254,6 +255,7 @@ class CachingEvalFunction(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, y_grad: Tensor):
         x, y = ctx.saved_tensors
+        x = x.detach()
         x.requires_grad = ctx.x_requires_grad
         m = ctx.m # e.g. a nn.Module
 
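For orientation, these two hunks belong to a checkpoint-style autograd function: the forward pass evaluates the module under torch.no_grad() so no graph is stored, and the backward pass replays the evaluation with gradients enabled. Below is a minimal self-contained sketch of that pattern, reconstructed from the context lines above; the class name and the RNG details are illustrative assumptions, not the actual icefall source.

import random

import torch
from torch import Tensor, nn
from torch.cuda.amp import custom_bwd, custom_fwd

class CachingEvalSketch(torch.autograd.Function):
    # Runs m(x) without building a graph in forward; rebuilds the graph
    # on demand in backward by replaying m(x) with gradients enabled.

    @staticmethod
    @custom_fwd
    def forward(ctx, x: Tensor, m: nn.Module) -> Tensor:
        ctx.m = m
        ctx.x_requires_grad = x.requires_grad
        # Save Python's RNG state so any randomness in m is replayed
        # identically in backward (real code would save torch's RNG too).
        ctx.random_state = random.getstate()
        with torch.no_grad():
            y = m(x)
        ctx.save_for_backward(x, y)
        return y

    @staticmethod
    @custom_bwd
    def backward(ctx, y_grad: Tensor):
        x, y = ctx.saved_tensors
        x = x.detach()  # start the replay from a fresh leaf tensor
        x.requires_grad = ctx.x_requires_grad
        m = ctx.m
        state = random.getstate()
        random.setstate(ctx.random_state)
        with torch.enable_grad():
            y2 = m(x)  # same computation, this time building a graph
        random.setstate(state)
        y2.backward(y_grad)  # fills x.grad and m's parameter .grad fields
        return x.grad, None  # None: the module argument gets no gradient

In this reading, the commit simply makes the no-grad evaluation explicit with a with-block in forward and detaches the saved input before the replay in backward.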
@@ -273,6 +275,14 @@ class CachingEvalFunction(torch.autograd.Function):
 
 
 def caching_eval(x: Tensor, m: nn.Module) -> Tensor:
+    if m.training:
+        # The purpose of this code is to make all parameters of m reachable in
+        # the computation graph, so that if we give find_unused_parameters=True
+        # to PyTorch's autograd code it does not assign them zero gradient.
+        tot = 0.0
+        for p in m.parameters():
+            tot = tot + 0.0 * p.flatten()[0]
+        x = x + tot # tot will be 0, this does nothing.
     return CachingEvalFunction.apply(x, m)
 
 
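The training-mode branch added here keeps every parameter of m reachable from the output through a zero-valued term, which is what lets DDP's find_unused_parameters=True treat them as used. A standalone illustration of the trick (touch_all_params is a made-up name):

import torch
from torch import nn

def touch_all_params(x: torch.Tensor, m: nn.Module) -> torch.Tensor:
    # Add 0.0 * (first element of each parameter) to x: numerically a
    # no-op, but it connects every parameter of m to x in the graph.
    tot = torch.zeros((), dtype=x.dtype, device=x.device)
    for p in m.parameters():
        tot = tot + 0.0 * p.flatten()[0]
    return x + tot

m = nn.Linear(4, 4)
x = torch.randn(2, 4)
touch_all_params(x, m).sum().backward()
print(m.weight.grad)  # all zeros, but not None: the parameter was "used"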
@@ -1767,7 +1767,7 @@ class Conv2dSubsampling(nn.Module):
         out_channels: int,
         layer1_channels: int = 8,
         layer2_channels: int = 32,
-        layer3_channels: int = 128,
+        layer3_channels: int = 96,
         dropout: FloatLike = 0.1,
     ) -> None:
         """
@@ -1829,7 +1829,6 @@ class Conv2dSubsampling(nn.Module):
             ))
 
         self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
-                                       ConvNeXt(layer3_channels),
                                        ConvNeXt(layer3_channels),
                                        BasicNorm(layer3_channels,
                                                  channel_dim=1))
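These last two hunks are the actual size reduction named in the commit message: layer3_channels drops from 128 to 96, and convnext2 loses one of its three ConvNeXt blocks. To measure the effect on a real checkout, a generic parameter counter suffices; the Conv2d below is only a stand-in, since ConvNeXt's internals are not shown in this diff.

import torch.nn as nn

def num_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())

# Parameter count scales roughly quadratically with channel width:
print(num_params(nn.Conv2d(128, 128, kernel_size=3)))  # 147584
print(num_params(nn.Conv2d(96, 96, kernel_size=3)))    # 83040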