mirror of https://github.com/k2-fsa/icefall.git
Make Whiten module update its prob every time
commit 05bcfd3b07, parent c097c13720
@@ -1003,38 +1003,38 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx,
                 x: Tensor,
-                num_groups: int,
-                whitening_limit: float,
-                grad_scale: float,
-                name: Optional[str]) -> Tensor:
+                module: nn.Module) -> Tensor:
         ctx.save_for_backward(x)
-        ctx.num_groups = num_groups
-        ctx.whitening_limit = whitening_limit
-        ctx.grad_scale = grad_scale
-        ctx.name = name
+        ctx.module = module
         return x
 
     @staticmethod
     def backward(ctx,
                  x_grad: Tensor):
         x_orig, = ctx.saved_tensors
+        w = ctx.module
         with torch.enable_grad():
             with torch.cuda.amp.autocast(enabled=False):
                 x_detached = x_orig.to(torch.float32).detach()
                 x_detached.requires_grad = True
 
-                metric = _whitening_metric(x_detached, ctx.num_groups)
+                metric = _whitening_metric(x_detached, w.num_groups)
 
                 if random.random() < 0.005 or __name__ == "__main__":
-                    logging.info(f"Whitening: name={ctx.name}, num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
-                                 f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}")
+                    logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, "
+                                 f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}")
 
-                (metric - ctx.whitening_limit).relu().backward()
+                if metric < float(w.whitening_limit):
+                    w.prob = w.min_prob
+                    return x_grad, None
+                else:
+                    w.prob = w.max_prob
+                    metric.backward()
                 penalty_grad = x_detached.grad
-                scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() /
-                                          (penalty_grad.norm() + 1.0e-20))
+                scale = w.grad_scale * (x_grad.to(torch.float32).norm() /
+                                        (penalty_grad.norm() + 1.0e-20))
                 penalty_grad = penalty_grad * scale
-                return x_grad + penalty_grad.to(x_grad.dtype), None, None, None, None
+                return x_grad + penalty_grad.to(x_grad.dtype), None
 
 
 class Whiten(nn.Module):
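A minimal sketch of the prob-update rule this hunk adds to WhiteningPenaltyFunction.backward: once the whitening metric has been measured, the module's sampling probability is dropped to min_prob while activations are already under the limit and raised to max_prob otherwise, so the penalty gets checked on every penalized backward pass. ToyWhiten and update_prob below are hypothetical stand-ins, not icefall code.

```python
# A toy sketch of the new prob-update rule; ToyWhiten and update_prob are
# hypothetical stand-ins, not part of icefall.
from dataclasses import dataclass


@dataclass
class ToyWhiten:
    whitening_limit: float = 5.0
    min_prob: float = 0.025   # how often to check when already whitened
    max_prob: float = 0.25    # how often to check when over the limit
    prob: float = 0.25        # current probability of applying the penalty


def update_prob(module: ToyWhiten, metric: float) -> None:
    # Mirrors the branch added in WhiteningPenaltyFunction.backward():
    # below the limit -> sample the penalty rarely; above it -> sample it often.
    if metric < float(module.whitening_limit):
        module.prob = module.min_prob
    else:
        module.prob = module.max_prob


m = ToyWhiten()
update_prob(m, metric=3.2)   # well-whitened activations
assert m.prob == m.min_prob
update_prob(m, metric=7.8)   # metric above whitening_limit
assert m.prob == m.max_prob
```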
@@ -1101,21 +1101,7 @@ class Whiten(nn.Module):
         if not x.requires_grad or random.random() > self.prob or grad_scale == 0:
             return _no_op(x)
         else:
-            whitening_limit = float(self.whitening_limit)
-            if hasattr(self, 'min_prob') and random.random() < 0.25:
-                # occasionally switch between min_prob and max_prob, based on whether
-                # we are above or below the threshold.
-                if _whitening_metric(x.to(torch.float32), self.num_groups) > whitening_limit:
-                    # there would be a change to the grad.
-                    self.prob = self.max_prob
-                else:
-                    self.prob = self.min_prob
-
-            return WhiteningPenaltyFunction.apply(x,
-                                                  self.num_groups,
-                                                  whitening_limit,
-                                                  grad_scale,
-                                                  self.name)
+            return WhiteningPenaltyFunction.apply(x, self)
 
 
 class WithLoss(torch.autograd.Function):
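This simplification works because the autograd function now receives the whole module: backward reads num_groups, whitening_limit, grad_scale and name off it, and can also write the updated prob back onto it. A self-contained sketch of that pattern, using a hypothetical IdentityWithModule function and a SimpleNamespace in place of the real Whiten module:

```python
# Hypothetical illustration of passing a module through ctx so that backward()
# can read its hyperparameters and mutate its state; not icefall code.
import torch
from types import SimpleNamespace


class IdentityWithModule(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, module) -> torch.Tensor:
        ctx.save_for_backward(x)   # tensors go through save_for_backward
        ctx.module = module        # non-tensor state rides on ctx directly
        return x

    @staticmethod
    def backward(ctx, x_grad: torch.Tensor):
        w = ctx.module
        # backward can both read the module's settings and update them in place;
        # it returns one gradient per forward input, None for the module argument.
        w.prob = w.max_prob
        return x_grad, None


m = SimpleNamespace(whitening_limit=5.0, grad_scale=0.01,
                    min_prob=0.025, max_prob=0.25, prob=0.025)
x = torch.randn(4, 8, requires_grad=True)
y = IdentityWithModule.apply(x, m)
y.sum().backward()
print(m.prob)  # 0.25 -- the module's state changed during the backward pass
```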