Add (back) straight_through_rate, with rate 0.025; try to handle memory allocation failures in backprop better.

This commit is contained in:
Daniel Povey 2023-04-30 15:19:34 +08:00
parent e4626a14b8
commit 6f5c4688ef
2 changed files with 68 additions and 50 deletions

View File

@ -683,40 +683,44 @@ class BalancerFunction(torch.autograd.Function):
x, = ctx.saved_tensors
(min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
x = x.to(torch.float32)
x = x.detach()
x.requires_grad = True
mean_dims = [ i for i in range(x.ndim) if i != channel_dim ]
uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True)
mean = x.mean(dim=mean_dims, keepdim=True)
stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
rms = uncentered_var.clamp(min=1.0e-20).sqrt()
m = mean / stddev
# part of loss that relates to mean / stddev
m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
try:
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
x = x.to(torch.float32)
x = x.detach()
x.requires_grad = True
mean_dims = [ i for i in range(x.ndim) if i != channel_dim ]
uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True)
mean = x.mean(dim=mean_dims, keepdim=True)
stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
rms = uncentered_var.clamp(min=1.0e-20).sqrt()
# put a much larger scale on the RMS-max-limit loss, so that if both it and the
# m_loss are violated we fix the RMS loss first.
rms_clamped = rms.clamp(min=min_rms, max=max_rms)
r_loss = (rms_clamped / rms).log().abs()
m = mean / stddev
# part of loss that relates to mean / stddev
m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
loss = (m_loss + r_loss)
# put a much larger scale on the RMS-max-limit loss, so that if both it and the
# m_loss are violated we fix the RMS loss first.
rms_clamped = rms.clamp(min=min_rms, max=max_rms)
r_loss = (rms_clamped / rms).log().abs()
loss.backward(gradient=torch.ones_like(loss))
loss_grad = x.grad
loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)
loss = (m_loss + r_loss)
loss_grad = loss_grad * (grad_scale / loss_grad_rms)
loss.backward(gradient=torch.ones_like(loss))
loss_grad = x.grad
loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)
x_grad_float = x_grad.to(torch.float32)
# scale each element of loss_grad by the absolute value of the corresponding
# element of x_grad, which we view as a noisy estimate of its magnitude for that
# (frame and dimension). later we can consider factored versions.
x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
x_grad = x_grad_mod.to(x_grad.dtype)
loss_grad = loss_grad * (grad_scale / loss_grad_rms)
x_grad_float = x_grad.to(torch.float32)
# scale each element of loss_grad by the absolute value of the corresponding
# element of x_grad, which we view as a noisy estimate of its magnitude for that
# (frame and dimension). later we can consider factored versions.
x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
x_grad = x_grad_mod.to(x_grad.dtype)
except Exception as e:
logging.info(f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue.")
return x_grad, None, None, None, None, None, None
@ -924,28 +928,34 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
x_grad: Tensor):
x_orig, = ctx.saved_tensors
w = ctx.module
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True
metric = _whitening_metric(x_detached, w.num_groups)
try:
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
x_detached = x_orig.to(torch.float32).detach()
x_detached.requires_grad = True
if random.random() < 0.005 or __name__ == "__main__":
logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, "
f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}")
metric = _whitening_metric(x_detached, w.num_groups)
if random.random() < 0.005 or __name__ == "__main__":
logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, "
f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}")
if metric < float(w.whitening_limit):
w.prob = w.min_prob
return x_grad, None
else:
w.prob = w.max_prob
metric.backward()
penalty_grad = x_detached.grad
scale = w.grad_scale * (x_grad.to(torch.float32).norm() /
(penalty_grad.norm() + 1.0e-20))
penalty_grad = penalty_grad * scale
return x_grad + penalty_grad.to(x_grad.dtype), None
except Exception as e:
logging.info(f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue.")
return x_grad, None
if metric < float(w.whitening_limit):
w.prob = w.min_prob
return x_grad, None
else:
w.prob = w.max_prob
metric.backward()
penalty_grad = x_detached.grad
scale = w.grad_scale * (x_grad.to(torch.float32).norm() /
(penalty_grad.norm() + 1.0e-20))
penalty_grad = penalty_grad * scale
return x_grad + penalty_grad.to(x_grad.dtype), None
class Whiten(nn.Module):

View File

@ -427,9 +427,10 @@ class Zipformer2EncoderLayer(nn.Module):
self.embed_dim = embed_dim
# self.bypass implements layer skipping as well as bypass; see its default values.
self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate)
self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate,
straight_through_rate=0.025)
# bypass_mid is bypass used in the middle of the layer.
self.bypass_mid = BypassModule(embed_dim)
self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0.025)
# skip probability for dynamic modules (meaning: anything but feedforward).
@ -768,11 +769,13 @@ class BypassModule(nn.Module):
self,
embed_dim: int,
skip_rate: FloatLike = 0.0,
straight_through_rate: FloatLike = 0.0,
scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
scale_max: FloatLike = 1.0):
super().__init__()
self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
self.skip_rate = copy.deepcopy(skip_rate)
self.straight_through_rate = copy.deepcopy(straight_through_rate)
self.scale_min = copy.deepcopy(scale_min)
self.scale_max = copy.deepcopy(scale_max)
@ -794,6 +797,11 @@ class BypassModule(nn.Module):
ans = ans * mask
# now ans is of shape (batch_size, num_channels), and is zero for sequences
# on which we have randomly chosen to do layer-skipping.
straight_through_rate = float(self.straight_through_rate)
if straight_through_rate != 0.0:
mask = torch.rand((batch_size, 1), device=ans.device) < straight_through_rate
ans = torch.maximum(ans, mask.to(ans.dtype))
return ans
def forward(self,
@ -826,7 +834,7 @@ class DownsampledZipformer2Encoder(nn.Module):
downsample, dropout)
self.encoder = encoder
self.upsample = SimpleUpsample(dim, downsample)
self.out_combiner = BypassModule(dim)
self.out_combiner = BypassModule(dim, straight_through_rate=0.025)
def forward(self,