Mirror of https://github.com/k2-fsa/icefall.git
Add (back) straight_through_rate, with rate 0.025; try to handle memory allocation failures in backprop better.
commit 6f5c4688ef
parent e4626a14b8
@@ -683,40 +683,44 @@ class BalancerFunction(torch.autograd.Function):
         x, = ctx.saved_tensors
         (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config

-        with torch.enable_grad():
-            with torch.cuda.amp.autocast(enabled=False):
-                x = x.to(torch.float32)
-                x = x.detach()
-                x.requires_grad = True
-                mean_dims = [ i for i in range(x.ndim) if i != channel_dim ]
-                uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True)
-                mean = x.mean(dim=mean_dims, keepdim=True)
-                stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
-                rms = uncentered_var.clamp(min=1.0e-20).sqrt()
-
-                m = mean / stddev
-                # part of loss that relates to mean / stddev
-                m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
-                # put a much larger scale on the RMS-max-limit loss, so that if both it and the
-                # m_loss are violated we fix the RMS loss first.
-                rms_clamped = rms.clamp(min=min_rms, max=max_rms)
-                r_loss = (rms_clamped / rms).log().abs()
-
-                loss = (m_loss + r_loss)
-
-                loss.backward(gradient=torch.ones_like(loss))
-                loss_grad = x.grad
-                loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)
-
-                loss_grad = loss_grad * (grad_scale / loss_grad_rms)
-
-                x_grad_float = x_grad.to(torch.float32)
-                # scale each element of loss_grad by the absolute value of the corresponding
-                # element of x_grad, which we view as a noisy estimate of its magnitude for that
-                # (frame and dimension).  later we can consider factored versions.
-                x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
-                x_grad = x_grad_mod.to(x_grad.dtype)
+        try:
+            with torch.enable_grad():
+                with torch.cuda.amp.autocast(enabled=False):
+                    x = x.to(torch.float32)
+                    x = x.detach()
+                    x.requires_grad = True
+                    mean_dims = [ i for i in range(x.ndim) if i != channel_dim ]
+                    uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True)
+                    mean = x.mean(dim=mean_dims, keepdim=True)
+                    stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
+                    rms = uncentered_var.clamp(min=1.0e-20).sqrt()
+
+                    m = mean / stddev
+                    # part of loss that relates to mean / stddev
+                    m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
+                    # put a much larger scale on the RMS-max-limit loss, so that if both it and the
+                    # m_loss are violated we fix the RMS loss first.
+                    rms_clamped = rms.clamp(min=min_rms, max=max_rms)
+                    r_loss = (rms_clamped / rms).log().abs()
+
+                    loss = (m_loss + r_loss)
+
+                    loss.backward(gradient=torch.ones_like(loss))
+                    loss_grad = x.grad
+                    loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)
+
+                    loss_grad = loss_grad * (grad_scale / loss_grad_rms)
+
+                    x_grad_float = x_grad.to(torch.float32)
+                    # scale each element of loss_grad by the absolute value of the corresponding
+                    # element of x_grad, which we view as a noisy estimate of its magnitude for that
+                    # (frame and dimension).  later we can consider factored versions.
+                    x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
+                    x_grad = x_grad_mod.to(x_grad.dtype)
+        except Exception as e:
+            logging.info(f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue.")

         return x_grad, None, None, None, None, None, None
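The guard added in this hunk (and in the Whiten hunk below) follows a simple pattern: the auxiliary regularization gradient is computed inside a try block, and if that computation fails, for example with a CUDA out-of-memory error while re-running the small autograd graph, the error is logged and the incoming gradient is returned unchanged so training can continue. A minimal self-contained sketch of the pattern, not taken from this commit (the class name GuardedPenalty and the toy penalty are made up):

import logging
import torch

logging.basicConfig(level=logging.INFO)


class GuardedPenalty(torch.autograd.Function):
    # Identity in the forward pass; the backward pass tries to add a small
    # auxiliary penalty gradient, and falls back to the plain incoming gradient
    # if the extra computation fails (e.g. a memory allocation failure).

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, x_grad: torch.Tensor):
        x, = ctx.saved_tensors
        try:
            with torch.enable_grad():
                x_detached = x.detach().to(torch.float32)
                x_detached.requires_grad = True
                penalty = (x_detached ** 2).mean()  # stand-in auxiliary loss
                penalty.backward()
                return x_grad + 0.01 * x_detached.grad.to(x_grad.dtype)
        except Exception as e:
            logging.info(f"Caught exception in backward: {e}, will continue.")
            return x_grad


if __name__ == "__main__":
    x = torch.randn(4, 8, requires_grad=True)
    GuardedPenalty.apply(x).sum().backward()
    print(x.grad.shape)  # torch.Size([4, 8])

The trade-off is that a batch which triggers the failure simply loses the regularization term for that module on that step, rather than crashing the run.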
@@ -924,28 +928,34 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                  x_grad: Tensor):
         x_orig, = ctx.saved_tensors
         w = ctx.module
-        with torch.enable_grad():
-            with torch.cuda.amp.autocast(enabled=False):
-                x_detached = x_orig.to(torch.float32).detach()
-                x_detached.requires_grad = True
-
-                metric = _whitening_metric(x_detached, w.num_groups)
-
-                if random.random() < 0.005 or __name__ == "__main__":
-                    logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, "
-                                 f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}")
-
-                if metric < float(w.whitening_limit):
-                    w.prob = w.min_prob
-                    return x_grad, None
-                else:
-                    w.prob = w.max_prob
-                    metric.backward()
-                    penalty_grad = x_detached.grad
-                    scale = w.grad_scale * (x_grad.to(torch.float32).norm() /
-                                            (penalty_grad.norm() + 1.0e-20))
-                    penalty_grad = penalty_grad * scale
-                    return x_grad + penalty_grad.to(x_grad.dtype), None
+        try:
+            with torch.enable_grad():
+                with torch.cuda.amp.autocast(enabled=False):
+                    x_detached = x_orig.to(torch.float32).detach()
+                    x_detached.requires_grad = True
+
+                    metric = _whitening_metric(x_detached, w.num_groups)
+
+                    if random.random() < 0.005 or __name__ == "__main__":
+                        logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, "
+                                     f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}")
+
+                    if metric < float(w.whitening_limit):
+                        w.prob = w.min_prob
+                        return x_grad, None
+                    else:
+                        w.prob = w.max_prob
+                        metric.backward()
+                        penalty_grad = x_detached.grad
+                        scale = w.grad_scale * (x_grad.to(torch.float32).norm() /
+                                                (penalty_grad.norm() + 1.0e-20))
+                        penalty_grad = penalty_grad * scale
+                        return x_grad + penalty_grad.to(x_grad.dtype), None
+        except Exception as e:
+            logging.info(f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue.")
+            return x_grad, None


 class Whiten(nn.Module):
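For context, the logic that the new try block wraps is unchanged by this commit: while the whitening metric stays under w.whitening_limit, the module drops to its minimum application probability and the gradient passes through untouched; once the limit is exceeded, it switches to the maximum probability and adds a penalty gradient rescaled to a fixed fraction of the incoming gradient's norm. A small worked sketch of that rescaling, with illustrative numbers rather than the module's defaults:

import torch

grad_scale = 0.02                            # plays the role of w.grad_scale
x_grad = torch.randn(10, 16)                 # incoming gradient
penalty_grad = torch.randn(10, 16) * 100.0   # raw penalty gradient, arbitrary magnitude

# Rescale the penalty so its norm is grad_scale times the norm of the main
# gradient, mirroring the scale computation in the backward pass above.
scale = grad_scale * (x_grad.norm() / (penalty_grad.norm() + 1.0e-20))
combined = x_grad + penalty_grad * scale
print(float((penalty_grad * scale).norm() / x_grad.norm()))  # ~0.02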
@@ -427,9 +427,10 @@ class Zipformer2EncoderLayer(nn.Module):
         self.embed_dim = embed_dim

         # self.bypass implements layer skipping as well as bypass; see its default values.
-        self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate)
+        self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate,
+                                   straight_through_rate=0.025)
         # bypass_mid is bypass used in the middle of the layer.
-        self.bypass_mid = BypassModule(embed_dim)
+        self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0.025)


         # skip probability for dynamic modules (meaning: anything but feedforward).
@@ -768,11 +769,13 @@ class BypassModule(nn.Module):
             self,
             embed_dim: int,
             skip_rate: FloatLike = 0.0,
+            straight_through_rate: FloatLike = 0.0,
             scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
             scale_max: FloatLike = 1.0):
         super().__init__()
         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
         self.skip_rate = copy.deepcopy(skip_rate)
+        self.straight_through_rate = copy.deepcopy(straight_through_rate)
         self.scale_min = copy.deepcopy(scale_min)
         self.scale_max = copy.deepcopy(scale_max)

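A note on the new constructor argument: FloatLike values in this code can be plain Python floats or schedule objects, which is why the value is stored with copy.deepcopy here and read back with float(self.straight_through_rate) where it is used in the next hunk. A tiny sketch of that convention (ConstantSchedule is a made-up stand-in, not the real ScheduledFloat):

import copy


class ConstantSchedule:
    # Stand-in for a schedule object: anything defining __float__ can be stored
    # as a FloatLike and evaluated with float() at the point of use.
    def __init__(self, value: float):
        self.value = value

    def __float__(self) -> float:
        return self.value


rate = copy.deepcopy(ConstantSchedule(0.025))  # as in BypassModule.__init__
assert float(rate) == float(0.025) == 0.025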
@@ -794,6 +797,11 @@ class BypassModule(nn.Module):
                 ans = ans * mask
                 # now ans is of shape (batch_size, num_channels), and is zero for sequences
                 # on which we have randomly chosen to do layer-skipping.
+            straight_through_rate = float(self.straight_through_rate)
+            if straight_through_rate != 0.0:
+                mask = torch.rand((batch_size, 1), device=ans.device) < straight_through_rate
+                ans = torch.maximum(ans, mask.to(ans.dtype))
+
             return ans

     def forward(self,
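To make the new lines concrete, here is a standalone numerical sketch, not code from the commit: with probability straight_through_rate, torch.maximum overrides a whole sequence's bypass scale with 1.0, so roughly 2.5% of sequences per batch at the default rate use the layer's output directly instead of the learned per-channel mixture. The interpolation in the last lines is only an assumed typical use of the scale, since BypassModule.forward is not shown in this diff:

import torch

torch.manual_seed(0)
batch_size, num_channels = 8, 4
straight_through_rate = 0.5  # exaggerated from 0.025 so the effect is visible

ans = torch.full((batch_size, num_channels), 0.5)            # learned bypass scale
mask = torch.rand((batch_size, 1)) < straight_through_rate   # selects whole sequences
ans = torch.maximum(ans, mask.to(ans.dtype))                 # selected rows become 1.0
print(ans[:, 0])                                             # mix of 0.5 and 1.0 entries

# Assumed typical use of the scale: interpolate between layer input and output;
# rows whose scale was forced to 1.0 pass x_new straight through.
x_orig = torch.randn(batch_size, num_channels)
x_new = torch.randn(batch_size, num_channels)
out = x_orig + (x_new - x_orig) * ans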
@@ -826,7 +834,7 @@ class DownsampledZipformer2Encoder(nn.Module):
                                               downsample, dropout)
         self.encoder = encoder
         self.upsample = SimpleUpsample(dim, downsample)
-        self.out_combiner = BypassModule(dim)
+        self.out_combiner = BypassModule(dim, straight_through_rate=0.025)


     def forward(self,