mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Reduce ConvNeXt parameters.
parent 8d9301e225
commit 083e5474c4
@@ -246,6 +246,7 @@ class CachingEvalFunction(torch.autograd.Function):
         # that matters in the forward pass), e.g. there should be no dropout.
         ctx.random_state = random.getstate()
         # we are inside torch.no_grad() here, so the following won't create the computation graph.
+        with torch.no_grad():
             y = m(x)
         ctx.save_for_backward(x, y)
         return y
@@ -254,6 +255,7 @@ class CachingEvalFunction(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, y_grad: Tensor):
         x, y = ctx.saved_tensors
+        x = x.detach()
         x.requires_grad = ctx.x_requires_grad
         m = ctx.m  # e.g. a nn.Module

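The two hunks above only show fragments of CachingEvalFunction. As a rough sketch of the pattern they modify (not the actual icefall implementation; the class name and exact return values below are illustrative, and the real code additionally saves and restores the Python RNG state so the recomputation matches, which is why the comment requires there be no dropout), a custom autograd function that caches an eval-style forward and recomputes it with gradients enabled in backward might look like this:

import torch
from torch import Tensor, nn


class CachedEvalSketch(torch.autograd.Function):
    """Sketch of the idea: run the module without building a graph in the
    forward pass, then redo the computation with grad enabled in backward
    (similar in spirit to activation checkpointing)."""

    @staticmethod
    def forward(ctx, x: Tensor, m: nn.Module) -> Tensor:
        ctx.m = m
        ctx.x_requires_grad = x.requires_grad
        with torch.no_grad():        # no computation graph is built here
            y = m(x)
        ctx.save_for_backward(x, y)
        return y

    @staticmethod
    def backward(ctx, y_grad: Tensor):
        x, _ = ctx.saved_tensors
        x = x.detach()               # start a fresh graph rooted at x
        x.requires_grad = ctx.x_requires_grad
        with torch.enable_grad():    # backward() runs with grad disabled, so re-enable it
            y = ctx.m(x)             # recompute, this time recording the graph
        torch.autograd.backward(y, grad_tensors=y_grad)
        return x.grad, None          # one gradient per forward() input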
@@ -273,6 +275,14 @@ class CachingEvalFunction(torch.autograd.Function):


 def caching_eval(x: Tensor, m: nn.Module) -> Tensor:
+    if m.training:
+        # The purpose of this code is to make all parameters of m reachable in
+        # the computation graph, so that if we give find_unused_parameters=True
+        # to PyTorch's autograd code it does not assign them zero gradient.
+        tot = 0.0
+        for p in m.parameters():
+            tot = tot + 0.0 * p.flatten()[0]
+        x = x + tot  # tot will be 0, this does nothing.
     return CachingEvalFunction.apply(x, m)


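The `if m.training:` block added in this hunk is a workaround for DistributedDataParallel with find_unused_parameters=True: because CachingEvalFunction hides m's parameters from the autograd graph seen at the call site, DDP would otherwise treat them as unused. A minimal standalone sketch of the same trick (the function name and the nn.Linear test module are illustrative, not from the diff):

import torch
from torch import Tensor, nn


def touch_all_parameters(x: Tensor, m: nn.Module) -> Tensor:
    """Add 0 * (one element of every parameter of m) to x.  Numerically a
    no-op, but it makes every parameter reachable from x in the autograd
    graph, so DDP(find_unused_parameters=True) does not mark them unused."""
    tot = 0.0
    for p in m.parameters():
        tot = tot + 0.0 * p.flatten()[0]
    return x + tot


if __name__ == "__main__":
    m = nn.Linear(4, 4)
    x = torch.randn(2, 4, requires_grad=True)
    y = touch_all_parameters(x, m)
    assert torch.equal(x, y)          # values are unchanged
    y.sum().backward()
    # Every parameter now participates in the graph and receives a (zero) grad.
    assert all(p.grad is not None for p in m.parameters())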
@@ -1767,7 +1767,7 @@ class Conv2dSubsampling(nn.Module):
         out_channels: int,
         layer1_channels: int = 8,
         layer2_channels: int = 32,
-        layer3_channels: int = 128,
+        layer3_channels: int = 96,
         dropout: FloatLike = 0.1,
     ) -> None:
         """
@@ -1829,7 +1829,6 @@ class Conv2dSubsampling(nn.Module):
             ))

        self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
-                                       ConvNeXt(layer3_channels),
                                        ConvNeXt(layer3_channels),
                                        BasicNorm(layer3_channels,
                                                  channel_dim=1))
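These last two hunks are where the parameter reduction named in the commit message happens: layer3_channels drops from 128 to 96, and convnext2 loses one of its three ConvNeXt blocks. A quick way to quantify such a change is to count trainable parameters before and after; the helper below is a generic sketch, and the commented-out Conv2dSubsampling calls are hypothetical (the in_channels/out_channels values are not taken from this diff, and the two constructions would have to come from the two checkouts since the removed block is hard-coded):

from torch import nn


def count_parameters(m: nn.Module) -> int:
    """Total number of trainable parameters in a module."""
    return sum(p.numel() for p in m.parameters() if p.requires_grad)


# Hypothetical usage against the old and new checkouts (values assumed):
#   before = Conv2dSubsampling(in_channels=80, out_channels=384, layer3_channels=128)
#   after  = Conv2dSubsampling(in_channels=80, out_channels=384, layer3_channels=96)
#   print(count_parameters(before), count_parameters(after))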