Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Add (back) straight_through_rate, with rate 0.025; try to handle memory allocation failures in backprop better.
This commit is contained in:
parent e4626a14b8
commit 6f5c4688ef
@@ -683,40 +683,44 @@ class BalancerFunction(torch.autograd.Function):
         x, = ctx.saved_tensors
         (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config

-        with torch.enable_grad():
-            with torch.cuda.amp.autocast(enabled=False):
-                x = x.to(torch.float32)
-                x = x.detach()
-                x.requires_grad = True
-
-                mean_dims = [ i for i in range(x.ndim) if i != channel_dim ]
-                uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True)
-                mean = x.mean(dim=mean_dims, keepdim=True)
-                stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
-                rms = uncentered_var.clamp(min=1.0e-20).sqrt()
-
-                m = mean / stddev
-                # part of loss that relates to mean / stddev
-                m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
-
-                # put a much larger scale on the RMS-max-limit loss, so that if both it and the
-                # m_loss are violated we fix the RMS loss first.
-                rms_clamped = rms.clamp(min=min_rms, max=max_rms)
-                r_loss = (rms_clamped / rms).log().abs()
-
-                loss = (m_loss + r_loss)
-
-                loss.backward(gradient=torch.ones_like(loss))
-                loss_grad = x.grad
-                loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)
-
-                loss_grad = loss_grad * (grad_scale / loss_grad_rms)
-
-                x_grad_float = x_grad.to(torch.float32)
-                # scale each element of loss_grad by the absolute value of the corresponding
-                # element of x_grad, which we view as a noisy estimate of its magnitude for that
-                # (frame and dimension). later we can consider factored versions.
-                x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
-                x_grad = x_grad_mod.to(x_grad.dtype)
+        try:
+            with torch.enable_grad():
+                with torch.cuda.amp.autocast(enabled=False):
+                    x = x.to(torch.float32)
+                    x = x.detach()
+                    x.requires_grad = True
+
+                    mean_dims = [ i for i in range(x.ndim) if i != channel_dim ]
+                    uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True)
+                    mean = x.mean(dim=mean_dims, keepdim=True)
+                    stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
+                    rms = uncentered_var.clamp(min=1.0e-20).sqrt()
+
+                    m = mean / stddev
+                    # part of loss that relates to mean / stddev
+                    m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
+
+                    # put a much larger scale on the RMS-max-limit loss, so that if both it and the
+                    # m_loss are violated we fix the RMS loss first.
+                    rms_clamped = rms.clamp(min=min_rms, max=max_rms)
+                    r_loss = (rms_clamped / rms).log().abs()
+
+                    loss = (m_loss + r_loss)
+
+                    loss.backward(gradient=torch.ones_like(loss))
+                    loss_grad = x.grad
+                    loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20)
+
+                    loss_grad = loss_grad * (grad_scale / loss_grad_rms)
+
+                    x_grad_float = x_grad.to(torch.float32)
+                    # scale each element of loss_grad by the absolute value of the corresponding
+                    # element of x_grad, which we view as a noisy estimate of its magnitude for that
+                    # (frame and dimension). later we can consider factored versions.
+                    x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
+                    x_grad = x_grad_mod.to(x_grad.dtype)
+        except Exception as e:
+            logging.info(f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue.")

         return x_grad, None, None, None, None, None, None
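
Note on the hunk above: the Balancer's extra gradient is computed with autograd inside backward(), in float32, which can occasionally fail (most often a CUDA out-of-memory error when the temporary tensors are allocated). Wrapping that work in try/except means a failure only skips the regularizer for that batch and passes x_grad through unchanged. A minimal, self-contained sketch of the same pattern; the class name SafeAuxGradFunction and the toy auxiliary loss are illustrative only, not the icefall code:

import logging

import torch
from torch import Tensor


class SafeAuxGradFunction(torch.autograd.Function):
    """Identity in forward(); in backward() it tries to add a small
    auxiliary-loss gradient to x_grad, and falls back to returning
    x_grad unchanged if anything (e.g. CUDA OOM) goes wrong."""

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tensor:
        x, = ctx.saved_tensors
        try:
            # backward() normally runs with grad disabled, so re-enable it
            # for the auxiliary computation.
            with torch.enable_grad():
                x_detached = x.to(torch.float32).detach()
                x_detached.requires_grad = True
                # toy auxiliary loss; the real Balancer penalizes mean/RMS
                # values that fall outside configured limits.
                aux_loss = (x_detached ** 2).mean()
                aux_loss.backward()
                x_grad = x_grad + 0.01 * x_detached.grad.to(x_grad.dtype)
        except Exception as e:
            logging.info(f"Auxiliary backward failed: {e}; continuing without it.")
        return x_grad
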
@@ -924,28 +928,34 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                  x_grad: Tensor):
         x_orig, = ctx.saved_tensors
         w = ctx.module
-        with torch.enable_grad():
-            with torch.cuda.amp.autocast(enabled=False):
-                x_detached = x_orig.to(torch.float32).detach()
-                x_detached.requires_grad = True
-
-                metric = _whitening_metric(x_detached, w.num_groups)
-
-                if random.random() < 0.005 or __name__ == "__main__":
-                    logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, "
-                                 f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}")
-
-                if metric < float(w.whitening_limit):
-                    w.prob = w.min_prob
-                    return x_grad, None
-                else:
-                    w.prob = w.max_prob
-                    metric.backward()
-                    penalty_grad = x_detached.grad
-                    scale = w.grad_scale * (x_grad.to(torch.float32).norm() /
-                                            (penalty_grad.norm() + 1.0e-20))
-                    penalty_grad = penalty_grad * scale
-                    return x_grad + penalty_grad.to(x_grad.dtype), None
+        try:
+            with torch.enable_grad():
+                with torch.cuda.amp.autocast(enabled=False):
+                    x_detached = x_orig.to(torch.float32).detach()
+                    x_detached.requires_grad = True
+
+                    metric = _whitening_metric(x_detached, w.num_groups)
+
+                    if random.random() < 0.005 or __name__ == "__main__":
+                        logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, "
+                                     f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}")
+
+                    if metric < float(w.whitening_limit):
+                        w.prob = w.min_prob
+                        return x_grad, None
+                    else:
+                        w.prob = w.max_prob
+                        metric.backward()
+                        penalty_grad = x_detached.grad
+                        scale = w.grad_scale * (x_grad.to(torch.float32).norm() /
+                                                (penalty_grad.norm() + 1.0e-20))
+                        penalty_grad = penalty_grad * scale
+                        return x_grad + penalty_grad.to(x_grad.dtype), None
+        except Exception as e:
+            logging.info(f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue.")
+            return x_grad, None


 class Whiten(nn.Module):
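
The rescaling of penalty_grad in the hunk above (unchanged by this commit, but worth noting alongside the new fallback path) keeps the whitening penalty's contribution at a fixed fraction, grad_scale, of the incoming gradient's norm, so a badly scaled penalty cannot swamp the main gradient. A standalone sketch of just that step; the helper name add_scaled_penalty and the random tensors are illustrative:

import torch


def add_scaled_penalty(x_grad: torch.Tensor,
                       penalty_grad: torch.Tensor,
                       grad_scale: float = 0.01) -> torch.Tensor:
    # Rescale penalty_grad so its norm is grad_scale times ||x_grad||,
    # then add it to the main gradient (same formula as in the diff above).
    scale = grad_scale * (x_grad.to(torch.float32).norm() /
                          (penalty_grad.norm() + 1.0e-20))
    return x_grad + (penalty_grad * scale).to(x_grad.dtype)


x_grad = torch.randn(10, 16)
penalty_grad = torch.randn(10, 16) * 1000.0   # arbitrarily badly scaled
out = add_scaled_penalty(x_grad, penalty_grad)
# (out - x_grad) has norm ~= 0.01 * x_grad.norm(), whatever penalty_grad's scale was
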
@@ -427,9 +427,10 @@ class Zipformer2EncoderLayer(nn.Module):
         self.embed_dim = embed_dim

         # self.bypass implements layer skipping as well as bypass; see its default values.
-        self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate)
+        self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate,
+                                   straight_through_rate=0.025)
         # bypass_mid is bypass used in the middle of the layer.
-        self.bypass_mid = BypassModule(embed_dim)
+        self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0.025)


         # skip probability for dynamic modules (meaning: anything but feedforward).
@@ -768,11 +769,13 @@ class BypassModule(nn.Module):
             self,
             embed_dim: int,
             skip_rate: FloatLike = 0.0,
+            straight_through_rate: FloatLike = 0.0,
             scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
             scale_max: FloatLike = 1.0):
         super().__init__()
         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
         self.skip_rate = copy.deepcopy(skip_rate)
+        self.straight_through_rate = copy.deepcopy(straight_through_rate)
         self.scale_min = copy.deepcopy(scale_min)
         self.scale_max = copy.deepcopy(scale_max)

@@ -794,6 +797,11 @@ class BypassModule(nn.Module):
                 ans = ans * mask
                 # now ans is of shape (batch_size, num_channels), and is zero for sequences
                 # on which we have randomly chosen to do layer-skipping.
+            straight_through_rate = float(self.straight_through_rate)
+            if straight_through_rate != 0.0:
+                mask = torch.rand((batch_size, 1), device=ans.device) < straight_through_rate
+                ans = torch.maximum(ans, mask.to(ans.dtype))
+
             return ans

     def forward(self,
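
What straight_through_rate does is easiest to see in isolation: after the learned bypass scale (and any layer-skip mask) has been applied, a random ~2.5% of sequences have their scale forced up to 1.0 via torch.maximum, so for those sequences the module's new output is used at full strength. A small sketch; the tensor shapes and the 0.5 placeholder for the learned scale are illustrative only:

import torch

batch_size, num_channels = 1000, 4
straight_through_rate = 0.025

# stand-in for the learned bypass scale after limiting / skip-masking
ans = torch.full((batch_size, num_channels), 0.5)

# with probability straight_through_rate per sequence, force the scale to 1.0
mask = torch.rand((batch_size, 1)) < straight_through_rate
ans = torch.maximum(ans, mask.to(ans.dtype))

frac = (ans == 1.0).all(dim=1).float().mean()
print(f"fraction of sequences forced straight through: {frac:.3f}")  # roughly 0.025
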
@@ -826,7 +834,7 @@ class DownsampledZipformer2Encoder(nn.Module):
                                            downsample, dropout)
         self.encoder = encoder
         self.upsample = SimpleUpsample(dim, downsample)
-        self.out_combiner = BypassModule(dim)
+        self.out_combiner = BypassModule(dim, straight_through_rate=0.025)


     def forward(self,