From 7999dd0dbe67537882b144c2ce97c700c1502641 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 5 Dec 2022 16:15:20 +0800
Subject: [PATCH] Introduce scalar multiplication and change rules for updating gradient scale.

---
 .../ASR/pruned_transducer_stateless7/train.py |  5 +++--
 .../pruned_transducer_stateless7/zipformer.py | 19 +++++++++++++++++++
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index b801beccf..60ab12a4f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -910,7 +910,8 @@ def train_one_epoch(
             # of the grad scaler is configurable, but we can't configure it to have different
             # behavior depending on the current grad scale.
             cur_grad_scale = scaler._scale.item()
-            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+
+            if cur_grad_scale < 8.0 or (cur_grad_scale < 128.0 and batch_idx % 400 == 0):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
@@ -947,7 +948,7 @@
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics and False:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 038da0136..14eb2ca94 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1665,6 +1665,14 @@ class ConvolutionModule(nn.Module):
         return x
 
 
+class ScalarMultiply(nn.Module):
+    def __init__(self, scale: float):
+        super().__init__()
+        self.scale = scale
+
+    def forward(self, x):
+        return x * self.scale
+
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/2 length).
 
@@ -1703,13 +1711,21 @@
 
         assert in_channels >= 7
         super().__init__()
 
+        # The ScalarMultiply modules are there to prevent the gradients
+        # w.r.t. the weight or bias of the first Conv2d module in self.conv from
+        # exceeding the range of fp16 when using automatic mixed precision (amp)
+        # training.  (The second one is necessary to stop its bias from getting
+        # a too-large gradient).
+
         self.conv = nn.Sequential(
+            ScalarMultiply(0.1),
             nn.Conv2d(
                 in_channels=1,
                 out_channels=layer1_channels,
                 kernel_size=3,
                 padding=(0, 1),  # (time, freq)
             ),
+            ScalarMultiply(0.25),
             ActivationBalancer(layer1_channels, channel_dim=1),
             DoubleSwish(),
@@ -1757,6 +1773,9 @@
         """
         # On entry, x is (N, T, idim)
         x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
+        # Scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
+        # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
+        # gradients.
         x = self.conv(x)
         # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
         b, c, t, f = x.size()
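
A note on the train.py change: under amp, GradScaler halves its scale whenever a batch produces inf/nan gradients, and this manual rule pushes the scale back up, now doubling it on every batch while it is below 8.0 and every 400 batches while it is below 128.0 (previously 1.0 and 8.0).  Below is a minimal sketch of the rule as a standalone helper, assuming `scaler` is a torch.cuda.amp.GradScaler that has already scaled and stepped at least one batch; the helper name is illustrative, not from the patch.

import logging

from torch.cuda.amp import GradScaler


def maybe_grow_grad_scale(scaler: GradScaler, batch_idx: int) -> None:
    # Intended to be called once per batch, after scaler.step()/scaler.update().
    # GradScaler backs off (halves) its scale on inf/nan gradients; this rule
    # restores it faster than the default growth_interval-based behavior.
    cur_grad_scale = scaler._scale.item()
    if cur_grad_scale < 8.0 or (cur_grad_scale < 128.0 and batch_idx % 400 == 0):
        scaler.update(cur_grad_scale * 2.0)
    if cur_grad_scale < 0.01:
        logging.warning(f"Grad scale is small: {cur_grad_scale}")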
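
A note on the zipformer.py change: the gradient w.r.t. a convolution's weight is a sum of products of its input activations and its output gradients, so multiplying the input by 0.1 shrinks the weight gradient by roughly the same factor, keeping it inside the fp16 range under amp; the 0.25 placed after the Conv2d scales the gradients flowing back into that layer's output, which (unlike the input scaling) also limits its bias gradient.  The following is a small self-contained check of the first effect, reusing the ScalarMultiply module added by this patch; the Conv2d shape and random input are illustrative, not taken from Conv2dSubsampling.

import torch
import torch.nn as nn


class ScalarMultiply(nn.Module):
    """Multiplies its input by a fixed scalar (as added in this patch)."""

    def __init__(self, scale: float):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        return x * self.scale


conv = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
x = torch.randn(2, 1, 16, 16)

# Weight gradient without input scaling.
conv(x).sum().backward()
g_unscaled = conv.weight.grad.norm().item()

# Weight gradient when the input is first scaled by 0.1.
conv.zero_grad()
conv(ScalarMultiply(0.1)(x)).sum().backward()
g_scaled = conv.weight.grad.norm().item()

# The ratio is (almost exactly) 0.1: the weight gradient shrinks in
# proportion to the input scale, so it is less likely to overflow fp16.
print(g_scaled / g_unscaled)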