mirror of https://github.com/k2-fsa/icefall.git

Add (back) straight_through_rate, with rate 0.025; try to handle memory allocation failures in backprop better.

commit 6f5c4688ef (parent e4626a14b8)
@@ -683,6 +683,8 @@ class BalancerFunction(torch.autograd.Function):
         x, = ctx.saved_tensors
         (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config

+        try:
             with torch.enable_grad():
                 with torch.cuda.amp.autocast(enabled=False):
                     x = x.to(torch.float32)
@@ -717,6 +719,8 @@ class BalancerFunction(torch.autograd.Function):
                     # (frame and dimension). later we can consider factored versions.
                     x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
                     x_grad = x_grad_mod.to(x_grad.dtype)
+        except Exception as e:
+            logging.info(f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue.")

         return x_grad, None, None, None, None, None, None

@@ -924,6 +928,8 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                  x_grad: Tensor):
         x_orig, = ctx.saved_tensors
         w = ctx.module

+        try:
             with torch.enable_grad():
                 with torch.cuda.amp.autocast(enabled=False):
                     x_detached = x_orig.to(torch.float32).detach()
@@ -946,6 +952,10 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                                                  (penalty_grad.norm() + 1.0e-20))
                     penalty_grad = penalty_grad * scale
             return x_grad + penalty_grad.to(x_grad.dtype), None
+        except Exception as e:
+            logging.info(f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue.")
+        return x_grad, None
+


 class Whiten(nn.Module):
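The Balancer and Whiten hunks above apply the same pattern for the second part of the commit message: the gradient-modifying work in backward() is wrapped in try/except, so that a failure there (most importantly a CUDA out-of-memory error during the extra autograd pass) is logged and the upstream gradient is returned unmodified instead of crashing training. A minimal, self-contained sketch of that pattern, assuming the class name and the toy penalty (which are illustrative, not part of icefall):

    import logging

    import torch
    from torch import Tensor


    class SafeBackwardExample(torch.autograd.Function):
        # Illustrative only: mirrors the fallback structure used in
        # BalancerFunction.backward and WhiteningPenaltyFunction.backward.
        @staticmethod
        def forward(ctx, x: Tensor) -> Tensor:
            ctx.save_for_backward(x)
            return x

        @staticmethod
        def backward(ctx, x_grad: Tensor) -> Tensor:
            x, = ctx.saved_tensors
            try:
                with torch.enable_grad():
                    with torch.cuda.amp.autocast(enabled=False):
                        # Auxiliary computation that could fail, e.g. by
                        # allocating too much GPU memory.
                        x_detached = x.to(torch.float32).detach()
                        x_detached.requires_grad = True
                        penalty = (x_detached ** 2).mean()
                        (penalty_grad,) = torch.autograd.grad(penalty, x_detached)
                        x_grad = x_grad + 0.01 * penalty_grad.to(x_grad.dtype)
            except Exception as e:
                # torch.cuda.OutOfMemoryError is a RuntimeError subclass, so it
                # lands here; training continues with the unmodified gradient.
                logging.info(f"Caught exception in backward: {e}, "
                             f"size={list(x_grad.shape)}, will continue.")
            return x_grad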
@@ -427,9 +427,10 @@ class Zipformer2EncoderLayer(nn.Module):
         self.embed_dim = embed_dim

         # self.bypass implements layer skipping as well as bypass; see its default values.
-        self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate)
+        self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate,
+                                   straight_through_rate=0.025)
         # bypass_mid is bypass used in the middle of the layer.
-        self.bypass_mid = BypassModule(embed_dim)
+        self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0.025)


         # skip probability for dynamic modules (meaning: anything but feedforward).
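A straight_through_rate of 0.025 means that, in expectation, roughly one sequence in forty per batch has its bypass scale forced up to 1.0 (see the BypassModule hunks below), so that sequence uses the layer output directly instead of the learned interpolation between layer input and output; like skip_rate, this randomization is presumably only active in training mode.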
@@ -768,11 +769,13 @@ class BypassModule(nn.Module):
             self,
             embed_dim: int,
             skip_rate: FloatLike = 0.0,
+            straight_through_rate: FloatLike = 0.0,
             scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
             scale_max: FloatLike = 1.0):
         super().__init__()
         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
         self.skip_rate = copy.deepcopy(skip_rate)
+        self.straight_through_rate = copy.deepcopy(straight_through_rate)
         self.scale_min = copy.deepcopy(scale_min)
         self.scale_max = copy.deepcopy(scale_max)

@@ -794,6 +797,11 @@ class BypassModule(nn.Module):
                 ans = ans * mask
                 # now ans is of shape (batch_size, num_channels), and is zero for sequences
                 # on which we have randomly chosen to do layer-skipping.
+            straight_through_rate = float(self.straight_through_rate)
+            if straight_through_rate != 0.0:
+                mask = torch.rand((batch_size, 1), device=ans.device) < straight_through_rate
+                ans = torch.maximum(ans, mask.to(ans.dtype))
+
             return ans

     def forward(self,
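To make the new branch concrete: BypassModule combines its two inputs roughly as src_orig + (src - src_orig) * scale, so a scale of 0 bypasses the module for that sequence and a scale of 1 passes its output straight through. The added lines draw a per-sequence Bernoulli mask and use torch.maximum to force the scale of the selected sequences up to 1.0. A standalone sketch of just that masking step, with illustrative names:

    import torch


    def apply_straight_through(ans: torch.Tensor, straight_through_rate: float) -> torch.Tensor:
        # ans: (batch_size, num_channels) bypass scales, typically in [0, 1].
        # With probability `straight_through_rate` per sequence, force every
        # channel's scale to at least 1.0 so that sequence keeps the module
        # output unchanged.
        if straight_through_rate != 0.0:
            batch_size = ans.shape[0]
            mask = torch.rand((batch_size, 1), device=ans.device) < straight_through_rate
            ans = torch.maximum(ans, mask.to(ans.dtype))
        return ans


    if __name__ == "__main__":
        scales = torch.full((8, 4), 0.5)
        # Exaggerated rate so the effect is visible in a tiny example;
        # the commit uses 0.025 (about 1 sequence in 40).
        print(apply_straight_through(scales, 0.5))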
@@ -826,7 +834,7 @@ class DownsampledZipformer2Encoder(nn.Module):
                                            downsample, dropout)
         self.encoder = encoder
         self.upsample = SimpleUpsample(dim, downsample)
-        self.out_combiner = BypassModule(dim)
+        self.out_combiner = BypassModule(dim, straight_through_rate=0.025)


     def forward(self,
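In DownsampledZipformer2Encoder the same mechanism is applied to out_combiner, which (in that module) combines the stack's input with its upsampled output; with straight_through_rate=0.025, an occasional training sequence presumably takes the upsampled encoder output directly rather than the learned combination.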