Introduce scalar multiplication and change rules for updating gradient scale.

Daniel Povey 2022-12-05 16:15:20 +08:00
parent 12fb2081b1
commit 7999dd0dbe
2 changed files with 22 additions and 2 deletions


@@ -910,7 +910,8 @@ def train_one_epoch(
             # of the grad scaler is configurable, but we can't configure it to have different
             # behavior depending on the current grad scale.
             cur_grad_scale = scaler._scale.item()
-            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+            if cur_grad_scale < 8.0 or (cur_grad_scale < 128.0 and batch_idx % 400 == 0):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
@@ -947,7 +948,7 @@ def train_one_epoch(
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics and False:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,


@@ -1665,6 +1665,14 @@ class ConvolutionModule(nn.Module):
         return x


+class ScalarMultiply(nn.Module):
+    def __init__(self, scale: float):
+        super().__init__()
+        self.scale = scale
+
+    def forward(self, x):
+        return x * self.scale
+
+
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/2 length).
@@ -1703,13 +1711,21 @@ class Conv2dSubsampling(nn.Module):
         assert in_channels >= 7
         super().__init__()
+        # The ScalarMultiply modules are there to prevent the gradients
+        # w.r.t. the weight or bias of the first Conv2d module in self.conv from
+        # exceeding the range of fp16 when using automatic mixed precision (amp)
+        # training. (The second one is necessary to stop its bias from getting
+        # a too-large gradient).
         self.conv = nn.Sequential(
+            ScalarMultiply(0.1),
             nn.Conv2d(
                 in_channels=1,
                 out_channels=layer1_channels,
                 kernel_size=3,
                 padding=(0, 1),  # (time, freq)
             ),
+            ScalarMultiply(0.25),
             ActivationBalancer(layer1_channels,
                                channel_dim=1),
             DoubleSwish(),
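The code comment above says the second ScalarMultiply (with scale 0.25) is needed to stop the first conv's bias from getting a too-large gradient. A toy illustration of the mechanism, not from the commit, using an arbitrary small Conv2d and a sum() loss: a scalar multiply placed after a conv scales the gradient flowing back into the conv's output, and therefore scales the gradients of both its weight and its bias.

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 4, kernel_size=3, padding=1)
x = torch.randn(2, 1, 16, 16)

conv(x).sum().backward()
w_grad = conv.weight.grad.clone()
b_grad = conv.bias.grad.clone()

conv.zero_grad()
(0.25 * conv(x)).sum().backward()   # post-multiply, like ScalarMultiply(0.25)
# Both the weight and the bias gradients shrink by the post-multiply factor.
assert torch.allclose(conv.weight.grad, 0.25 * w_grad, rtol=1e-4, atol=1e-4)
assert torch.allclose(conv.bias.grad, 0.25 * b_grad, rtol=1e-4, atol=1e-4)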
@@ -1757,6 +1773,9 @@ class Conv2dSubsampling(nn.Module):
         """
         # On entry, x is (N, T, idim)
         x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
+        # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
+        # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
+        # gradients.
         x = self.conv(x)
         # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
         b, c, t, f = x.size()
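A companion sketch for the comment above, again with a toy conv rather than the real Conv2dSubsampling: pre-scaling the input (as ScalarMultiply(0.1) does) shrinks the gradient w.r.t. the conv's weights by the same factor, because that gradient is proportional to the input, while the bias gradient is unaffected; the smaller weight gradients leave more headroom before fp16 overflows, so a larger amp grad scale can be used.

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 4, kernel_size=3, padding=1)
x = torch.randn(2, 1, 16, 16)

conv(x).sum().backward()
w_grad = conv.weight.grad.clone()
b_grad = conv.bias.grad.clone()

conv.zero_grad()
conv(0.1 * x).sum().backward()      # pre-multiply, like ScalarMultiply(0.1)
# The weight gradient scales with the input; the bias gradient does not,
# which is why the commit also adds ScalarMultiply(0.25) after the conv.
assert torch.allclose(conv.weight.grad, 0.1 * w_grad, rtol=1e-3, atol=1e-4)
assert torch.allclose(conv.bias.grad, b_grad)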