From 5fa8de5c05966ae8beed994abc27b1eacab4874e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 20 Dec 2022 13:51:08 +0800 Subject: [PATCH] Implement layerdrop per-sequence for convnext; lower, slower-decreasing layerdrop rate. --- .../pruned_transducer_stateless7/zipformer.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 3172838af..1ff2b1ed9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1660,14 +1660,14 @@ class ConvNeXt(nn.Module): def __init__(self, channels: int, hidden_ratio: int = 4, - layerdrop_prob: FloatLike = None): + layerdrop_rate: FloatLike = None): super().__init__() kernel_size = 7 pad = (kernel_size - 1) // 2 hidden_channels = channels * hidden_ratio - if layerdrop_prob is None: - layerdrop_prob = ScheduledFloat((0.0, 0.2), (16000.0, 0.025)) - self.layerdrop_prob = layerdrop_prob + if layerdrop_rate is None: + layerdrop_rate = ScheduledFloat((0.0, 0.1), (20000.0, 0.01)) + self.layerdrop_rate = layerdrop_rate self.depthwise_conv = nn.Conv2d( in_channels=channels, @@ -1702,17 +1702,20 @@ class ConvNeXt(nn.Module): The returned value has the same shape as x. """ - if torch.jit.is_scripting() or (self.training and random.random() < float(self.layerdrop_prob)): - return x - bypass = x x = self.depthwise_conv(x) x = self.pointwise_conv1(x) x = self.hidden_balancer(x) x = self.activation(x) x = self.pointwise_conv2(x) - return bypass + x + layerdrop_rate = float(self.layerdrop_rate) + if not torch.jit.is_scripting() and self.training and layerdrop_rate != 0.0: + batch_size = x.shape[0] + mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate + x = x * mask + + return bypass + x