Add (back) straight_through_rate, with rate 0.025; try to handle memory allocation failures in backprop better.

Daniel Povey 2023-04-30 15:19:34 +08:00
parent e4626a14b8
commit 6f5c4688ef
2 changed files with 68 additions and 50 deletions

View File

@@ -683,6 +683,8 @@ class BalancerFunction(torch.autograd.Function):
         x, = ctx.saved_tensors
         (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config
+        try:
             with torch.enable_grad():
                 with torch.cuda.amp.autocast(enabled=False):
                     x = x.to(torch.float32)
@@ -717,6 +719,8 @@ class BalancerFunction(torch.autograd.Function):
                     # (frame and dimension). later we can consider factored versions.
                     x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
                     x_grad = x_grad_mod.to(x_grad.dtype)
+        except Exception as e:
+            logging.info(f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue.")
         return x_grad, None, None, None, None, None, None
@@ -924,6 +928,8 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                  x_grad: Tensor):
         x_orig, = ctx.saved_tensors
         w = ctx.module
+        try:
             with torch.enable_grad():
                 with torch.cuda.amp.autocast(enabled=False):
                     x_detached = x_orig.to(torch.float32).detach()
@@ -946,6 +952,10 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                                         (penalty_grad.norm() + 1.0e-20))
                     penalty_grad = penalty_grad * scale
                 return x_grad + penalty_grad.to(x_grad.dtype), None
+        except Exception as e:
+            logging.info(f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue.")
+            return x_grad, None
 
 class Whiten(nn.Module):
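The two hunks above wrap the gradient-modifying part of each custom backward in try/except, so that a failure there (for example a CUDA out-of-memory error while re-running the computation under torch.enable_grad()) is logged and the incoming gradient is passed through unchanged instead of crashing training. Below is a minimal, self-contained sketch of that pattern; the class name SafePenaltyFunction and its auxiliary penalty are illustrative placeholders, not the actual Balancer/Whiten code.

# Sketch only: illustrates the try/except-in-backward pattern used in this commit.
import logging

import torch
from torch import Tensor


class SafePenaltyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        ctx.save_for_backward(x)
        return x  # identity in the forward pass; the work happens in backward

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        x, = ctx.saved_tensors
        try:
            with torch.enable_grad():
                x_detached = x.to(torch.float32).detach()
                x_detached.requires_grad = True
                # Stand-in for the real auxiliary penalty whose gradient is
                # added to x_grad; this is where an allocation could fail.
                penalty = (x_detached ** 2).mean()
                penalty.backward()
                return x_grad + x_detached.grad.to(x_grad.dtype)
        except Exception as e:
            logging.info(f"Caught exception in backward: {e}, "
                         f"size={list(x_grad.shape)}, will continue.")
            return x_grad

Calling SafePenaltyFunction.apply(x) behaves like the identity in the forward pass; in the backward pass it adds the penalty gradient when it can, and otherwise just returns x_grad.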

View File

@@ -427,9 +427,10 @@ class Zipformer2EncoderLayer(nn.Module):
         self.embed_dim = embed_dim
 
         # self.bypass implements layer skipping as well as bypass; see its default values.
-        self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate)
+        self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate,
+                                   straight_through_rate=0.025)
         # bypass_mid is bypass used in the middle of the layer.
-        self.bypass_mid = BypassModule(embed_dim)
+        self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0.025)
 
         # skip probability for dynamic modules (meaning: anything but feedforward).
@@ -768,11 +769,13 @@ class BypassModule(nn.Module):
             self,
             embed_dim: int,
             skip_rate: FloatLike = 0.0,
+            straight_through_rate: FloatLike = 0.0,
             scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
             scale_max: FloatLike = 1.0):
         super().__init__()
         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
         self.skip_rate = copy.deepcopy(skip_rate)
+        self.straight_through_rate = copy.deepcopy(straight_through_rate)
         self.scale_min = copy.deepcopy(scale_min)
         self.scale_max = copy.deepcopy(scale_max)
@@ -794,6 +797,11 @@ class BypassModule(nn.Module):
                 ans = ans * mask
                 # now ans is of shape (batch_size, num_channels), and is zero for sequences
                 # on which we have randomly chosen to do layer-skipping.
+            straight_through_rate = float(self.straight_through_rate)
+            if straight_through_rate != 0.0:
+                mask = torch.rand((batch_size, 1), device=ans.device) < straight_through_rate
+                ans = torch.maximum(ans, mask.to(ans.dtype))
             return ans
 
     def forward(self,
@@ -826,7 +834,7 @@ class DownsampledZipformer2Encoder(nn.Module):
                                          downsample, dropout)
         self.encoder = encoder
         self.upsample = SimpleUpsample(dim, downsample)
-        self.out_combiner = BypassModule(dim)
+        self.out_combiner = BypassModule(dim, straight_through_rate=0.025)
 
     def forward(self,
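For reference, the straight_through_rate added above can be read as: for each sequence in the batch, with probability straight_through_rate the (possibly skip-masked) bypass scale is forced to 1.0, so that sequence uses the module output at full weight during training. Below is a minimal sketch of just that masking step; the helper name apply_straight_through and the assumed (batch_size, num_channels) shape of ans are illustrative, taken from the diff rather than from any wider context.

# Sketch only: the straight-through masking applied to the bypass scales.
import torch


def apply_straight_through(ans: torch.Tensor, straight_through_rate: float) -> torch.Tensor:
    # ans: (batch_size, num_channels) bypass scales, already clamped to
    # [scale_min, scale_max] and zeroed for sequences chosen for layer-skipping.
    if straight_through_rate != 0.0:
        batch_size = ans.shape[0]
        mask = torch.rand((batch_size, 1), device=ans.device) < straight_through_rate
        # torch.maximum with a {0, 1} mask forces the scale to 1.0 for the
        # selected sequences and leaves the rest unchanged.
        ans = torch.maximum(ans, mask.to(ans.dtype))
    return ans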