From 083e5474c44db5a59d6e4b3534239a5c2d51013b Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 16 Dec 2022 00:21:04 +0800
Subject: [PATCH] Reduce ConvNeXt parameters.

---
 .../ASR/pruned_transducer_stateless7/scaling.py   | 12 +++++++++++-
 .../ASR/pruned_transducer_stateless7/zipformer.py |  3 +--
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 13d2d890e..672fd3465 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -246,7 +246,8 @@ class CachingEvalFunction(torch.autograd.Function):
         # that matters in the forward pass), e.g. there should be no dropout.
         ctx.random_state = random.getstate()
         # we are inside torch.no_grad() here, so the following won't create the computation graph.
-        y = m(x)
+        with torch.no_grad():
+            y = m(x)
         ctx.save_for_backward(x, y)
         return y

@@ -254,6 +255,7 @@ class CachingEvalFunction(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, y_grad: Tensor):
         x, y = ctx.saved_tensors
+        x = x.detach()
         x.requires_grad = ctx.x_requires_grad
         m = ctx.m  # e.g. a nn.Module
@@ -273,6 +275,14 @@ def caching_eval(x: Tensor, m: nn.Module) -> Tensor:
+    if m.training:
+        # The purpose of this code is to make all parameters of m reachable in
+        # the computation graph, so that if we give find_unused_parameters=True
+        # to PyTorch's autograd code it does not assign them zero gradient.
+        tot = 0.0
+        for p in m.parameters():
+            tot = tot + 0.0 * p.flatten()[0]
+        x = x + tot  # tot will be 0, this does nothing.
     return CachingEvalFunction.apply(x, m)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 5f49f220e..1108d076e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1767,7 +1767,7 @@ class Conv2dSubsampling(nn.Module):
         out_channels: int,
         layer1_channels: int = 8,
         layer2_channels: int = 32,
-        layer3_channels: int = 128,
+        layer3_channels: int = 96,
         dropout: FloatLike = 0.1,
     ) -> None:
         """
@@ -1829,7 +1829,6 @@
         ))

         self.convnext2 = nn.Sequential(ConvNeXt(layer3_channels),
-                                       ConvNeXt(layer3_channels),
                                        ConvNeXt(layer3_channels),
                                        BasicNorm(layer3_channels, channel_dim=1))
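For context on the scaling.py hunks: judging from what is visible here, caching_eval() wraps a module in a checkpointing-style autograd.Function that runs the forward pass without building a computation graph and re-runs the module with grad enabled during backward, and the new "if m.training:" block touches every parameter with zero weight so that DistributedDataParallel with find_unused_parameters=True does not treat them as unused. Below is a minimal, self-contained sketch of that pattern, not the icefall implementation: RecomputeFunction and recompute_eval are illustrative names, the random-state restoration, AMP decorators and cached output of the real CachingEvalFunction are omitted, and the module is assumed to be deterministic.

import torch
import torch.nn as nn
from torch import Tensor


class RecomputeFunction(torch.autograd.Function):
    # Sketch only: run m without a graph in forward(), then rebuild the graph
    # and backprop through it in backward().  Assumes m(x) is deterministic
    # (no dropout etc.), like the real CachingEvalFunction requires.

    @staticmethod
    def forward(ctx, x: Tensor, m: nn.Module) -> Tensor:
        ctx.m = m
        ctx.x_requires_grad = x.requires_grad
        with torch.no_grad():  # no computation graph is built here
            y = m(x)
        ctx.save_for_backward(x)
        return y

    @staticmethod
    def backward(ctx, y_grad: Tensor):
        (x,) = ctx.saved_tensors
        x = x.detach()
        x.requires_grad = ctx.x_requires_grad
        m = ctx.m
        with torch.enable_grad():  # redo the forward, this time with a graph
            y = m(x)
        # Backprop through the recomputed graph; this accumulates .grad on
        # m's parameters and, if requested, on x.
        torch.autograd.backward(y, y_grad)
        return (x.grad if ctx.x_requires_grad else None), None


def recompute_eval(x: Tensor, m: nn.Module) -> Tensor:
    if m.training:
        # Touch every parameter of m with zero weight so that DDP with
        # find_unused_parameters=True does not treat them as unused.
        tot = 0.0
        for p in m.parameters():
            tot = tot + 0.0 * p.flatten()[0]
        x = x + tot  # numerically a no-op, but keeps the parameters reachable
    return RecomputeFunction.apply(x, m)

A quick check of the sketch under these assumptions: with m = nn.Sequential(nn.Linear(4, 4), nn.ReLU()) in training mode and x = torch.randn(2, 4, requires_grad=True), calling recompute_eval(x, m).sum().backward() populates both x.grad and the gradients of m's parameters, even though the forward pass built no graph through m.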